diff --git a/.gitmodules b/.gitmodules index d16e9335b..332be6398 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "third-party/cutlass"] path = third-party/cutlass url = https://github.com/NVIDIA/cutlass.git +[submodule "third-party/fmt"] + path = third-party/fmt + url = https://github.com/fmtlib/fmt.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 658aa7bd3..6f12a9690 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,44 +1,33 @@ # NOTES: current just for CMake-based IDE (e.g. CLion) indexing, the real compilation is done via JIT -# TODO: add CUDA utils' library via CMake cmake_minimum_required(VERSION 3.10) project(deep_gemm LANGUAGES CXX CUDA) - -set(CMAKE_CXX_STANDARD 20) -set(CMAKE_CUDA_STANDARD 20) set(CMAKE_VERBOSE_MAKEFILE ON) -find_package(CUDAToolkit REQUIRED) -find_package(pybind11 REQUIRED) - -file(WRITE ${CMAKE_BINARY_DIR}/test_cuda.cu "extern \"C\" __global__ void testKernel() { }") -execute_process( - COMMAND ${CUDA_NVCC_EXECUTABLE} ${CMAKE_CUDA_FLAGS} -gencode arch=compute_90a,code=sm_90a -o ${CMAKE_BINARY_DIR}/test_cuda.o -c ${CMAKE_BINARY_DIR}/test_cuda.cu - RESULT_VARIABLE NVCC_RESULT - OUTPUT_VARIABLE NVCC_OUTPUT - ERROR_VARIABLE NVCC_ERROR_OUTPUT - WORKING_DIRECTORY ${CMAKE_BINARY_DIR} -) +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC -Wno-psabi") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi") +set(CUDA_SEPARABLE_COMPILATION ON) +list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG") +list(APPEND CUDA_NVCC_FLAGS "-O3") +list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage") -if (NVCC_RESULT EQUAL "0") - set(NVCC_SUPPORTS_SM90 TRUE) - message(STATUS "NVCC supports SM90") -else() - message(STATUS "NVCC does not support SM90") -endif() +set(USE_SYSTEM_NVTX on) +set(CUDA_ARCH_LIST "9.0" CACHE STRING "List of CUDA architectures to compile") +set(TORCH_CUDA_ARCH_LIST "${CUDA_ARCH_LIST}") -if (NVCC_SUPPORTS_SM90) - set(TORCH_CUDA_ARCH_LIST "8.6" CACHE STRING "Add arch tag 90a to NVCC" FORCE) - list(APPEND CUDA_NVCC_FLAGS "-gencode;arch=compute_90a,code=sm_90a") -endif() +find_package(CUDAToolkit REQUIRED) +find_package(pybind11 REQUIRED) find_package(Torch REQUIRED) -include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include) -include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS}) -link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CUDA_STANDARD 17) + +include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include) +include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS}) +link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs) -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC") -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 -fPIC -DNDEBUG") -set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3 -std=c++17 -DNDEBUG --ptxas-options=--register-usage-level=10") +# The main Python API entrance +pybind11_add_module(deep_gemm_cpp csrc/python_api.cpp) +target_link_libraries(deep_gemm_cpp PRIVATE ${TORCH_LIBRARIES} torch_python cuda) -cuda_add_library(example_gemm STATIC indexing/main.cu) +# Enable kernel code indexing with CMake-based IDEs +cuda_add_library(deep_gemm_indexing_cuda STATIC csrc/indexing/main.cu) diff --git a/README.md b/README.md index 8df722aaf..814099dfe 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,18 @@ # DeepGEMM -DeepGEMM is a library designed for clean and efficient FP8 General Matrix Multiplications (GEMMs) with fine-grained scaling, as proposed in [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3). It supports both normal and Mix-of-Experts (MoE) grouped GEMMs. Written in CUDA, the library has no compilation need during installation, by compiling all kernels at runtime using a lightweight Just-In-Time (JIT) module. +DeepGEMM is a library designed for clean and efficient General Matrix Multiplications (GEMMs). It supports FP8 and BF16 (working in progress) for both normal and Mix-of-Experts (MoE) grouped scenarios. Written in CUDA, the library has no kernel compilation need during installation, by compiling all kernels at runtime using a lightweight Just-In-Time (JIT) module. -Currently, DeepGEMM exclusively supports NVIDIA Hopper tensor cores. To address the imprecise FP8 tensor core accumulation, it employs CUDA-core two-level accumulation (promotion). While it leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), it avoids heavy reliance on their templates or algebras. Instead, the library is designed for simplicity, with only one core kernel function. This makes it a clean and accessible resource for learning Hopper FP8 matrix multiplication and optimization techniques. +DeepGEMM leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), it avoids heavy reliance on their templates or algebras. Instead, the library is designed for simplicity, with only a limited number of core kernel functions. This makes it a clean and accessible resource for learning NVIDIA GPU kernel optimization techniques. Despite its lightweight design, DeepGEMM's performance matches or exceeds expert-tuned libraries across various matrix shapes. ## News +- 2025.07.20: DeepGEMM now supports both SM90/SM100, and has a full refactor with a low-CPU-overhead JIT CPP module. + - NVRTC and post-compilation SASS optimization are all disabled + - NVRTC will be supported later + - As NVCC 12.9 will automatically do the FFMA interleaving, all post optimizations will be no longer supported + - Please see [#112](https://github.com/deepseek-ai/DeepGEMM/pull/112) for more details - 2025.05.14: DeepGEMM now offers weight gradient kernels for dense and MoE backward! See [#95](https://github.com/deepseek-ai/DeepGEMM/pull/95) for details. - 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may have performance loss with some cases). - 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details. @@ -16,57 +21,61 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert - [x] More correctness tests for grouped-contiguous layout - [x] Shared memory swizzling for output -- [ ] Larger block size on N (up to 256) - [x] MoE scheduler with TMA multicast compatibility - [x] Fix TMA multicast compatibility for indivisible shapes - [x] Skip useless computation on M - [x] NVRTC as a faster compiler -- [ ] Stolen JIT cache - [ ] Sanitizer for testing - [x] Weight gradient kernels for dense models - [x] Weight gradient kernels for MoE models - [ ] Better `get_best_configs` modeling -- [ ] Utility kernels for MoE models (maybe with [tile-lang](https://github.com/tile-ai/tilelang)) - [ ] CUDA PDL support -- [ ] More scaling granularity support via templates - [ ] Larger TMA multicast size for some shapes - [x] MMA template refactor with CUTLASS -- [ ] Optimizations for power efficiency - [x] Remove shape limitations on N and K - [ ] BF16 kernels - [ ] Split/stream-k optimizations +- [ ] Ampere kernels +- [ ] Polish docs ## Quick start ### Requirements -- Hopper architecture GPUs, `sm_90a` must be supported -- Python 3.8 or above -- CUDA 12.3 or above - - **But we highly recommend 12.8 or above for the best performance** -- PyTorch 2.1 or above -- CUTLASS 3.6 or above (could be cloned by Git submodule) +- NVIDIA SM90 or SM100 architecture GPU +- Python 3.8 or higher +- Compilers with C++20 support +- CUDA Toolkit: + - CUDA 12.3 or higher for SM90 + - **We highly recommend 12.9 or higher for the best performance** + - CUDA 12.9 or higher for SM100 +- PyTorch 2.1 or higher +- CUTLASS 4.0 or higher (could be cloned by Git submodule) +- `{fmt}` library (could be cloned by Git submodule) ### Development ```bash # Submodule must be cloned git clone --recursive git@github.com:deepseek-ai/DeepGEMM.git +cd DeepGEMM -# Make symbolic links for third-party (CUTLASS and CuTe) include directories -python setup.py develop +# Link some essential includes and build the CPP JIT module +cat develop.sh +./develop.sh -# Test JIT compilation -python tests/test_jit.py - -# Test all GEMM implements (normal, contiguous-grouped and masked-grouped) -python tests/test_core.py +# Test all GEMM implements +python tests/test_layout.py +python tests/test_bf16.py +python tests/test_fp8.py +python tests/test_lazy_init.py ``` ### Installation ```bash -python setup.py install +cat install.sh +./install.sh ``` Then, import `deep_gemm` in your Python project, and enjoy! @@ -75,118 +84,63 @@ Then, import `deep_gemm` in your Python project, and enjoy! #### Notices -This library exclusively contains GEMM kernels. It requires the LHS scaling factor to be TMA-aligned and transposed, and it only supports the NT format (non-transposed LHS and transposed RHS). For transposition or other FP8 casting operations, please implement or fuse them into prior kernels independently. While the library provides some simple PyTorch utility functions, these may result in slower performance, but our primary focus is on optimizing the GEMM kernels themselves. +This library provides optimized GEMM kernels for NVIDIA GPUs with a naming convention: `D = C + A @ B`. The input shape layout is NT (non-transposed A, transposed B). While the SM90 implementation supports only the NT memory layout (row-major, col-major), the SM100 implementation supports all memory layouts (NT, TN, NN, TT). For example, `fp8_gemm_nt` will do a `D = C + A @ B.T` + +For both architectures, the LHS scaling factor is required to have a TMA-aligned and transposed layout. And the data format for the scaling factor of SM90 and SM100 is different: + +- SM90 requires scaling factors in FP32 format. +- SM100 requires scaling factors in packed [UE8M0](https://docs.nvidia.com/cuda/parallel-thread-execution/#alternate-floating-point-data-formats) format, which packs 4 UE8M0 into a single `torch.int`. + +Please note that operations like input transposition or FP8 casting must be handled separately by the user, please implement or fuse them into prior kernels independently. While the library provides some simple PyTorch utility functions, these may result in slower performance, but our primary focus is on optimizing the GEMM kernels themselves. #### Normal dense GEMMs (non-grouped) -To perform a basic non-grouped FP8 GEMM, call the `deep_gemm.gemm_fp8_fp8_bf16_nt` function. For more details, please refer to the function documentation. +To perform a basic non-grouped FP8 GEMM, call the `fp8_gemm_{nt, nn, tn, tt}` function. For more details, please refer to the function documentation. #### Grouped GEMMs (contiguous layout) -Unlike traditional grouped GEMMs in CUTLASS, DeepGEMM groups only the M-axis, while N and K must remain fixed. This design is tailored for scenarios where experts in an MoE model share the same shape. +Unlike traditional grouped GEMMs in CUTLASS, DeepGEMM groups only the M-axis, while N and K must remain fixed. This design is tailored for scenarios where experts in an MoE model share the same shape. For training forward passes or inference prefilling, where each expert may process a varying number of tokens, we concatenate these tokens into a single tensor, referred to as the "contiguous" layout. Note that each expert segment must be aligned to the GEMM M block size (`get_mk_alignment_for_contiguous_layout()`). For more information, please refer to the `m_grouped_fp8_gemm_{nt, nn}_contiguous` function documentation. -For training forward passes or inference prefilling, where each expert may process a varying number of tokens, we concatenate these tokens into a single tensor, referred to as the "contiguous" layout. Note that each expert segment must be aligned to the GEMM M block size (`get_m_alignment_for_contiguous_layout()`). - -For more information, please refer to the `m_grouped_gemm_fp8_fp8_bf16_nt_contiguous` function documentation. +We also provide a K-axis-grouped API for MoE weight backward (with M and N must remain fixed), please refer to `k_grouped_fp8_gemm_tn_contiguous` for more information. #### Grouped GEMMs (masked layout) During the inference decoding phase, when CUDA graph is enabled and the CPU is unaware of the number of tokens each expert receives, we support masked grouped GEMMs. By providing a mask tensor, the kernel computes only the valid portions. -Use `m_grouped_gemm_fp8_fp8_bf16_nt_masked` for this purpose and consult the relevant documentation. An example usage is to use the output of low-latency kernels from [DeepEP](https://github.com/deepseek-ai/DeepEP) as input. +Use `m_grouped_fp8_gemm_nt_masked` for this purpose and consult the relevant documentation. An example usage is to use the output of low-latency kernels from [DeepEP](https://github.com/deepseek-ai/DeepEP) as input. #### Utilities The library provides some utility functions besides the above kernels: - `deep_gemm.set_num_sms`: set the maximum SM count to use -- `deep_gemm.get_num_sms`: get the current SM maximum count -- `deep_gemm.get_m_alignment_for_contiguous_layout`: get the group-level alignment requirement for grouped contiguous layout +- `deep_gemm.get_num_sms`: get the current SM maximum count (return the device SM count if not set) +- `deep_gemm.set_tc_util`: set an approximated tensor core utilization ratio +- `deep_gemm.get_tc_util`: get the current tensor core utilization ratio +- `deep_gemm.transform_sf_into_required_layout`: transform scaling factors into required layout - `deep_gemm.get_tma_aligned_size`: get the required TMA alignment size -- `deep_gemm.get_col_major_tma_aligned_tensor`: get a column-major TMA-aligned tensor +- `deep_gemm.get_mk_alignment_for_contiguous_layout`: get the group-level alignment requirement for grouped contiguous layout +- `deep_gemm.get_mn_major_tma_aligned_tensor`: get a MN-major TMA-aligned tensor +- `deep_gemm.get_mn_major_tma_aligned_packed_ue8m0_tensor`: get a MN-major TMA-aligned tensor (with packing FP32 into UE8M0) +- `deep_gemm.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor`: K-grouped GEMM packing kernel The library also provides some environment variables, which may be useful: - General - - `DG_JIT_DEBUG`: `0` or `1`, print more JIT debugging information, `0` by default + - `DG_JIT_DEBUG`: `0` or `1`, print more JIT debugging information, `0` by default - JIT cache related - - `DG_JIT_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default - - `DG_JIT_DISABLE_CACHE`: `0` or `1`, disable the use of cache directory, `0` by default + - `DG_JIT_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default - NVCC/NVRTC selections - - `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC, faster compilation but maybe have lower performance for some cases, `0` by default - - `DG_JIT_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `torch.utils.cpp_extension.CUDA_HOME` by default + - `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC, faster compilation but maybe have lower performance for some cases, `0` by default + - `DG_JIT_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `torch.utils.cpp_extension.CUDA_HOME` by default - Compiler options - - `DG_JIT_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), support for some old version GCC compiler, `20` by default - - `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS compiler output, `0` by default - - `DG_JIT_PRINT_REG_REUSE`: `0` or `1`, print FFMA-interleaving details, `0` by default - - `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print NVCC compilation command, `0` by default -- Post optimization - - `DG_JIT_DISABLE_FFMA_INTERLEAVE`: `0` or `1`, disable FFMA-interleaving optimization, `0` by default + - `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS compiler output, `0` by default + - `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print NVCC compilation command, `0` by default - Heuristic selection - - `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default -- Testing - - `DG_NSYS_PROFILING`: `0` or `1`, Nsight-system compatible testing, `0` by default + - `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation. -## Optimizations - -We indicate the techniques excluded from CUTLASS with 🐳. - -#### Persistent warp-specialization - -Following the CUTLASS design, the kernels in DeepGEMM are warp-specialized, enabling overlapping data movement, tensor-core MMA instructions, and CUDA-core promotion. A simplified figure illustrating this process is shown below: - -![design](figures/design.png) - -#### Hopper TMA features - -The [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/hopper-tuning-guide/index.html#tensor-memory-accelerator) (TMA) is a new hardware feature introduced by the Hopper architecture, designed for faster and asynchronous data movement. Specifically, we utilize TMA for: - -- TMA load for LHS, LHS scaling factors, and RHS matrices -- TMA store for the output matrix -- TMA multicast (automatically decide LHS or RHS to broadcast) -- TMA descriptor prefetching - -#### Common detail optimizations - -- Utilization of the [`stmatrix`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix) PTX instruction -- [Register count control](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg) tailored for different warpgroups -- Less bank conflicts via 3D TMA or swizzling -- Larger block sizes (up to 256x128 🐳) -- Overlapping as much as possible, e.g., overlapping TMA store and non-TMA RHS scaling factor load 🐳 - -#### A unified and optimized block scheduler - -- [One scheduler](deep_gemm/include/deep_gemm/scheduler.cuh) for all non-grouped and grouped kernels -- [Rasterization](https://github.com/NVIDIA/cutlass/blob/eefa171318b79cbe2e78514d4cce5cd0fe919d0c/media/docs/efficient_gemm.md#threadblock-rasterization) to enhance L2 cache reuse - -#### Fully JIT design 🐳 - -DeepGEMM employs a fully [Just-In-Time](deep_gemm/jit) (JIT) design, with no compilation required at installation. All kernels are compiled at runtime using a lightweight JIT implementation. This approach offers several advantages: - -- GEMM shapes, block sizes, and the number of pipeline stages are treated as compile-time constants - - Saving registers - - Compilers may do more optimizations -- Automatic selection of block sizes, number of warpgroups, optimal pipeline stages, and TMA cluster size - - But without auto-tuning, the optimal one is deterministically selected -- Full unrolling of the MMA pipelines, providing compilers with more optimization opportunities - - Very important for small shapes - - Refer to `launch_k_iterations` in [the kernel file](deep_gemm/include/deep_gemm/fp8_gemm.cuh) for details - -Overall, JIT significantly improves performance for small shapes, similar to the approach of the [Triton](https://github.com/triton-lang/triton/) compiler. - -#### Unaligned block sizes 🐳 - -For certain shapes, block sizes aligned to powers of 2 can lead to underutilized SMs. For instance, with `M=256, N=7168`, a typical block size assignment of `BLOCK_M=128, BLOCK_N=128` results in only `(256 / 128) * (7168 / 128) = 112` out of 132 SMs being utilized. To address this, we support unaligned block sizes like 112, enabling `(256 / 128) * (7168 / 112) = 128` SMs to work in such scenarios. Implementing this technique alongside fine-grained scaling requires careful optimization but ultimately delivers performance gains. - -#### FFMA SASS interleaving 🐳 - -We observe a performance improvement in [the CUTLASS FP8 kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/54_hopper_fp8_warp_specialized_gemm) between NVCC 12.2 and 12.3. By comparing the compiled SASS, we discover that one bit in [a series of `FADD` instructions](https://github.com/NVIDIA/cutlass/blob/eefa171318b79cbe2e78514d4cce5cd0fe919d0c/include/cutlass/gemm/collective/fp8_accumulation.hpp#L73) is flipped in an interleaving pattern. -After referencing some open-source [CUDA assembler](https://github.com/cloudcores/CuAssembler/blob/96a9f72baf00f40b9b299653fcef8d3e2b4a3d49/CuAsm/CuControlCode.py#L46) implementations, we identified that this bit controls `yield`, which may enhance warp-level parallelism (just a guess, yielding the current warp and let other warps work). - -To leverage this, we develop [a similar script](deep_gemm/jit/interleave_ffma.py) to modify the `FFMA` instructions in the compiled binary. Besides simply modifying the `yield` bit, we also flip the `reuse` bit (registers cannot be reused if the warp is yielded). This adjustment improves performance (10%+ in some cases) for fine-grained scaling FP8 GEMMs by creating more opportunities to overlap MMA instructions with promotion `FFMA` instructions. - ## Acknowledgement DeepGEMM is inspired by the [CUTLASS](https://github.com/nvidia/cutlass) project. Thanks and respect to the developers! @@ -194,15 +148,3 @@ DeepGEMM is inspired by the [CUTLASS](https://github.com/nvidia/cutlass) project ## License This code repository is released under [the MIT License](LICENSE). - -## Citation - -```bibtex -@misc{deepgemm2025, - title={DeepGEMM: clean and efficient FP8 GEMM kernels with fine-grained scaling}, - author={Chenggang Zhao and Liang Zhao and Jiashi Li and Zhean Xu}, - year={2025}, - publisher = {GitHub}, - howpublished = {\url{https://github.com/deepseek-ai/DeepGEMM}}, -} -``` diff --git a/build.sh b/build.sh new file mode 100755 index 000000000..abdfc4067 --- /dev/null +++ b/build.sh @@ -0,0 +1,12 @@ +# Change current directory into project root +original_dir=$(pwd) +script_dir=$(realpath "$(dirname "$0")") +cd "$script_dir" + +# Remove old dist file, build files, and install +rm -rf build dist +rm -rf *.egg-info +python setup.py bdist_wheel + +# Open users' original directory +cd "$original_dir" diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp new file mode 100644 index 000000000..a05a36d32 --- /dev/null +++ b/csrc/apis/gemm.hpp @@ -0,0 +1,576 @@ +#pragma once + +#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp" +#include "../jit_kernels/impls/sm90_bf16_gemm.hpp" +#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp" +#include "../jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp" +#include "../jit_kernels/impls/sm100_bf16_gemm.hpp" + +#include "layout.hpp" + +namespace deep_gemm::gemm { + +static void fp8_gemm_nt(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + std::optional> recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + // Shape must be `[M, K] @ [N, K].T` + const auto& major_a = get_major_type_ab(a.first); + const auto& major_b = get_major_type_ab(b.first); + if (fp8_requires_k_major()) { + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); + } + + // C/D must be N-major + check_major_type_cd(d); + + // Type and shape checks + const auto& [m , k ] = get_shape<2>(a.first); + const auto& [n , k_] = get_shape<2>(b.first); + const auto& [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(n > 0 and k > 0); + DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); + + // Check C as well + if (c.has_value()) { + check_major_type_cd(c.value()); + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat); + } + + // Do nothing if the problem is empty + if (m == 0) + return; + + // Transform SFA and SFB into compute-required layout + if (not recipe.has_value()) + recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); + const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast, false); + const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast, false); + + // Dispatch into different implements + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { + sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { + sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) { + sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); + } +} + +static void fp8_gemm_nn(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)}, + d, c, recipe, compiled_dims, disable_ue8m0_cast); +} + +static void fp8_gemm_tn(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, + {b.first.transpose(0, 1), b.second.transpose(0, 1)}, + d, c, recipe, compiled_dims, disable_ue8m0_cast); +} + +static void fp8_gemm_tt(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b, + d, c, recipe, compiled_dims, disable_ue8m0_cast); +} + +static void m_grouped_fp8_gemm_nt_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& m_indices, + std::optional> recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + // Shape must be `[M, K] @ [G, N, K].mT` + const auto& major_a = get_major_type_ab(a.first); + const auto& major_b = get_major_type_ab(b.first); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + if (fp8_requires_k_major()) + DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(m_indices.is_contiguous()); + + // Type and shape checks + const auto& [m, k] = get_shape<2>(a.first); + const auto& [num_groups, n, k_] = get_shape<3>(b.first); + const auto& [m_, n_] = get_shape<2>(d); + const auto& m__ = static_cast(m_indices.numel()); + DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_); + DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt); + + // D must be N-major + check_major_type_cd(d); + + // Do nothing if empty + if (m == 0) + return; + + // Transform SFA and SFB into compute-required layout + if (not recipe.has_value()) + recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); + const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast, false); + const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast, false); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { + sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices, + num_groups, m, n, k, major_a, major_b, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { + sm100_m_grouped_fp8_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, m_indices, + num_groups, m, n, k, major_a, major_b, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) { + sm100_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices, + num_groups, m, n, k, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); + } +} + +static void m_grouped_fp8_gemm_nt_contiguous_per_tensor(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& m_indices, + std::optional> recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + // Shape must be `[M, K] @ [G, N, K].mT` + const auto& major_a = get_major_type_ab(a.first); + const auto& major_b = get_major_type_ab(b.first); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + if (fp8_requires_k_major()) + DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(m_indices.is_contiguous()); + + // Type and shape checks + const auto& [m, k] = get_shape<2>(a.first); + const auto& [num_groups, n, k_] = get_shape<3>(b.first); + const auto& [m_, n_] = get_shape<2>(d); + const auto& m__ = static_cast(m_indices.numel()); + DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_); + DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt); + + // D must be N-major + check_major_type_cd(d); + + // Do nothing if empty + if (m == 0) + return; + + // Transform SFA and SFB into compute-required layout + if (not recipe.has_value()) + recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); + const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast, true); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9 and sfb.scalar_type() == torch::kFloat) { + sm90_m_grouped_fp8_gemm_contiguous_per_tensor_1d2d(a.first, b.first, sfb, d, m_indices, + num_groups, m, n, k, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); + } +} + +static void m_grouped_fp8_gemm_nn_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& m_indices, + const std::optional>& recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + m_grouped_fp8_gemm_nt_contiguous(a, {b.first.transpose(1, 2), b.second.transpose(1, 2)}, + d, m_indices, recipe, compiled_dims, disable_ue8m0_cast); +} + +static void m_grouped_fp8_gemm_nt_masked(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& expected_m, + std::optional> recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + // Shape must be `[G, M, K] @ [G, N, K].mT` + const auto& major_a = get_major_type_ab(a.first); + const auto& major_b = get_major_type_ab(b.first); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(masked_m.is_contiguous()); + + // Type and shape checks + const auto& [num_groups, m, k] = get_shape<3>(a.first); + const auto& [num_groups_, n, k_] = get_shape<3>(b.first); + const auto& [num_groups__, m_, n_] = get_shape<3>(d); + const auto& num_groups___ = static_cast(masked_m.numel()); + DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt); + + // D must be N-major + check_major_type_cd(d); + + // Transform scaling factors + if (not recipe.has_value()) + recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); + const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), num_groups, true, disable_ue8m0_cast, false); + const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast, false); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { + sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m, + num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { + sm100_m_grouped_fp8_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m, + num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) { + sm100_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m, + num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); + } +} + +static void m_grouped_fp8_gemm_nt_masked_per_tensor(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& expected_m, + std::optional> recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + // Shape must be `[G, M, K] @ [G, N, K].mT` + const auto& major_a = get_major_type_ab(a.first); + const auto& major_b = get_major_type_ab(b.first); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(masked_m.is_contiguous()); + + // Type and shape checks + const auto& [num_groups, m, k] = get_shape<3>(a.first); + const auto& [num_groups_, n, k_] = get_shape<3>(b.first); + const auto& [num_groups__, m_, n_] = get_shape<3>(d); + const auto& num_groups___ = static_cast(masked_m.numel()); + DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt); + + // D must be N-major + check_major_type_cd(d); + + // Transform scaling factors + if (not recipe.has_value()) + recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); + const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast, true); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9 and sfb.scalar_type() == torch::kFloat) { + sm90_m_grouped_fp8_gemm_masked_per_tensor_1d2d(a.first, b.first, sfb, d, masked_m, + num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); + } +} + + +static void k_grouped_fp8_gemm_tn_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::vector& ks, + const torch::Tensor& ks_tensor, + const std::optional& c, + const std::tuple& recipe, + const std::string& compiled_dims) { + // Must be 1D1D kernel + DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128)); + + // Contiguity checks + DG_HOST_ASSERT(a.first.is_contiguous()); + DG_HOST_ASSERT(b.first.is_contiguous()); + DG_HOST_ASSERT(d.is_contiguous()); + if (c.has_value()) { + DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat); + DG_HOST_ASSERT(c.value().is_contiguous()); + } + + // Do nothing if empty + if (std::accumulate(ks.begin(), ks.end(), 0) == 0) + return; + + // Transform SF with padding + const auto& [_, m] = get_shape<2>(a.first); + const auto& [__, n] = get_shape<2>(b.first); + const auto& sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe); + const auto& sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 10) { + fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, + cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void bf16_gemm_nt(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::optional& c, + const std::string& compiled_dims) { + // Shape must be `[M, K] @ [N, K].T` + const auto& major_a = get_major_type_ab(a); + const auto& major_b = get_major_type_ab(b); + + // C/D must be N-major + check_major_type_cd(d); + + // Type and shape checks + const auto& [m , k ] = get_shape<2>(a); + const auto& [n , k_] = get_shape<2>(b); + const auto& [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(n > 0 and k > 0); + DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); + + // Check C as well + if (c.has_value()) { + check_major_type_cd(c.value()); + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat); + } + + // Do nothing if the problem is empty + if (m == 0) + return; + + // Dispatch into different implements + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_bf16_gemm(a, b, c, d, m, n, k, major_a, major_b, compiled_dims); + } else if (arch_major == 10) { + sm100_bf16_gemm(a, b, c, d, m, n, k, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void bf16_gemm_nn(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::optional& c, + const std::string& compiled_dims) { + bf16_gemm_nt(a, b.transpose(0, 1), d, c, compiled_dims); +} + +static void bf16_gemm_tn(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::optional& c, + const std::string& compiled_dims) { + bf16_gemm_nt(a.transpose(0, 1), b.transpose(0, 1), d, c, compiled_dims); +} + +static void bf16_gemm_tt(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::optional& c, + const std::string& compiled_dims) { + bf16_gemm_nt(a.transpose(0, 1), b, d, c, compiled_dims); +} + +static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const torch::Tensor& m_indices, + const std::string& compiled_dims) { + // Shape must be `[M, K] @ [G, N, K].mT` + const auto& major_a = get_major_type_ab(a); + const auto& major_b = get_major_type_ab(b); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(m_indices.is_contiguous()); + + // Type and shape checks + const auto& [m, k] = get_shape<2>(a); + const auto& [num_groups, n, k_] = get_shape<3>(b); + const auto& [m_, n_] = get_shape<2>(d); + const auto& m__ = static_cast(m_indices.numel()); + DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_); + DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt); + + // D must be N-major + check_major_type_cd(d); + + // Do nothing if empty + if (m == 0) + return; + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_m_grouped_bf16_gemm_contiguous(a, b, d, m_indices, + num_groups, m, n, k, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const torch::Tensor& masked_m, + const int& expected_m, const std::string& compiled_dims) { + // Shape must be `[G, M, K] @ [G, N, K].mT` + const auto& major_a = get_major_type_ab(a); + const auto& major_b = get_major_type_ab(b); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(masked_m.is_contiguous()); + + // Type and shape checks + const auto& [num_groups, m, k] = get_shape<3>(a); + const auto& [num_groups_, n, k_] = get_shape<3>(b); + const auto& [num_groups__, m_, n_] = get_shape<3>(d); + const auto& num_groups___ = static_cast(masked_m.numel()); + DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt); + + // D must be N-major + check_major_type_cd(d); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_bf16_m_grouped_gemm_masked(a, b, d, masked_m, + num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void register_apis(pybind11::module_& m) { + // FP8 GEMMs + m.def("fp8_gemm_nt", &fp8_gemm_nt, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + m.def("fp8_gemm_nn", &fp8_gemm_nn, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + m.def("fp8_gemm_tn", &fp8_gemm_tn, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("compiled_dims") = "mn", + py::arg("disable_ue8m0_cast") = false); + m.def("fp8_gemm_tt", &fp8_gemm_tt, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("compiled_dims") = "mn", + py::arg("disable_ue8m0_cast") = false); + m.def("m_grouped_fp8_gemm_nt_contiguous", &m_grouped_fp8_gemm_nt_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"), + py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + m.def("m_grouped_fp8_gemm_nn_contiguous", &m_grouped_fp8_gemm_nn_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"), + py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + m.def("m_grouped_fp8_gemm_nt_masked", &m_grouped_fp8_gemm_nt_masked, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), + py::arg("expected_m"), py::arg("recipe") = std::nullopt, + py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false); + m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"), + py::arg("ks_tensor"), py::arg("c") = std::nullopt, + py::arg("recipe") = std::make_tuple(1, 1, 128), + py::arg("compiled_dims") = "mn"); + + // BF16 GEMMs + m.def("bf16_gemm_nt", &bf16_gemm_nt, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, + py::arg("compiled_dims") = "nk"); + m.def("bf16_gemm_nn", &bf16_gemm_nn, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, + py::arg("compiled_dims") = "nk"); + m.def("bf16_gemm_tn", &bf16_gemm_tn, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, + py::arg("compiled_dims") = "mn"); + m.def("bf16_gemm_tt", &bf16_gemm_tt, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, + py::arg("compiled_dims") = "mn"); + m.def("m_grouped_bf16_gemm_nt_contiguous", &m_grouped_bf16_gemm_nt_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"), + py::arg("compiled_dims") = "nk"); + m.def("m_grouped_bf16_gemm_nt_masked", &m_grouped_bf16_gemm_nt_masked, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), + py::arg("expected_m"), py::arg("compiled_dims") = "nk"); + + // Per Tensor GEMMs + m.def("m_grouped_fp8_gemm_nt_contiguous_per_tensor", &m_grouped_fp8_gemm_nt_contiguous_per_tensor, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"), + py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + m.def("m_grouped_fp8_gemm_nt_masked_per_tensor", &m_grouped_fp8_gemm_nt_masked_per_tensor, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), + py::arg("expected_m"), py::arg("recipe") = std::nullopt, + py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false); +} + +} // namespace deep_gemm::gemm diff --git a/csrc/apis/layout.hpp b/csrc/apis/layout.hpp new file mode 100644 index 000000000..9d2afc5cf --- /dev/null +++ b/csrc/apis/layout.hpp @@ -0,0 +1,93 @@ +#pragma once + +#include "../utils/layout.hpp" +#include "../jit_kernels/impls/smxx_layout.hpp" + +namespace deep_gemm::layout { + +static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf, + const int& mn, const int& k, + const std::tuple& recipe, + const std::optional& num_groups, + const bool& is_sfa, + const bool& disable_ue8m0_cast, + const bool& is_per_tensor) { + const auto& gran_mn = is_sfa ? std::get<0>(recipe) : std::get<1>(recipe); + const auto& gran_k = std::get<2>(recipe); + const auto& arch_major = device_runtime->get_arch_major(); + + // Pre-transform checks + if (is_per_tensor) + check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, false, std::nullopt, true); + else + check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, false, std::nullopt, false); + + // (FP32, 1, 128) on SM90: transform to TMA-aligned and MN-major + if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast)) + return get_mn_major_tma_aligned_tensor(sf); + + // (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major + if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and arch_major == 10) { + DG_HOST_ASSERT(not disable_ue8m0_cast); + return get_mn_major_tma_aligned_packed_ue8m0_tensor(sf); + } + + // (FP32, 128, 128) on SM90: no need to transform, check shape and contiguous + if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast) and is_per_tensor == false) + return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat); + + // (FP32, 128, 128) per tensor on SM90: no need to transform, check shape and contiguous + if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast) and is_per_tensor) + return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat, true); + + // (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major + if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and arch_major == 10) { + DG_HOST_ASSERT(not disable_ue8m0_cast); + const auto& broadcasted = sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(128)); + return get_mn_major_tma_aligned_packed_ue8m0_tensor(broadcasted); + } + + // (INT, 1, 128) on SM100: transform to TMA-aligned and MN-major + if (sf.scalar_type() == torch::kInt and gran_mn == 1 and gran_k == 128 and arch_major == 10) + return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt); + + DG_HOST_UNREACHABLE("Unknown SF transformation"); +} + +static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Tensor& sf, + const std::vector& ks, + const torch::Tensor& ks_tensor, + const std::tuple& recipe) { + DG_HOST_ASSERT(sf.dim() == 2); + DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128)); + const auto& arch_major = device_runtime->get_arch_major(); + + // FP32 on SM90 + if (sf.scalar_type() == torch::kFloat and arch_major == 9) + DG_HOST_UNREACHABLE("Unimplemented"); + + // FP32 on SM100 + if (sf.scalar_type() == torch::kFloat and arch_major == 10) + return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks); + + // INT on SM100 + if (sf.scalar_type() == torch::kFloat and arch_major == 10) + DG_HOST_UNREACHABLE("Unimplemented"); + + DG_HOST_UNREACHABLE("Unknown cases"); +} + +static void register_apis(pybind11::module_& m) { + m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout, + py::arg("sf"), py::arg("mn"), py::arg("k"), py::arg("recipe"), + py::arg("num_groups") = std::nullopt, py::arg("is_sfa") = false, + py::arg("disable_ue8m0_cast") = false, py::arg("is_per_tensor") = false); + + m.def("get_tma_aligned_size", &get_tma_aligned_size); + m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout); + m.def("get_mn_major_tma_aligned_tensor", &get_mn_major_tma_aligned_tensor); + m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor", &get_mn_major_tma_aligned_packed_ue8m0_tensor); + m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor); +} + +} // namespace deep_gemm::layout diff --git a/csrc/apis/runtime.hpp b/csrc/apis/runtime.hpp new file mode 100644 index 000000000..9ef420785 --- /dev/null +++ b/csrc/apis/runtime.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include "../jit/compiler.hpp" +#include "../jit/device_runtime.hpp" + +namespace deep_gemm::runtime { + +static void register_apis(pybind11::module_& m) { + m.def("set_num_sms", [&](const int& new_num_sms) { + device_runtime->set_num_sms(new_num_sms); + }); + m.def("get_num_sms", [&]() { + return device_runtime->get_num_sms(); + }); + m.def("set_tc_util", [&](const int& new_tc_util) { + device_runtime->set_tc_util(new_tc_util); + }); + m.def("get_tc_util", [&]() { + return device_runtime->get_tc_util(); + }); + + m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_python) { + Compiler::prepare_init(library_root_path, cuda_home_path_by_python); + KernelRuntime::prepare_init(cuda_home_path_by_python); + }); +} + +} // namespace deep_gemm::runtime diff --git a/csrc/indexing/main.cu b/csrc/indexing/main.cu new file mode 100644 index 000000000..a05b59c85 --- /dev/null +++ b/csrc/indexing/main.cu @@ -0,0 +1,13 @@ +#include +#include +#include +#include +#include +#include +#include + +using namespace deep_gemm; + +int main() { + return 0; +} diff --git a/csrc/jit/cache.hpp b/csrc/jit/cache.hpp new file mode 100644 index 000000000..1e8659fd3 --- /dev/null +++ b/csrc/jit/cache.hpp @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include + +#include "kernel_runtime.hpp" + +namespace deep_gemm { + +class KernelRuntimeCache { + std::unordered_map> cache; + +public: + // TODO: consider cache capacity + KernelRuntimeCache() = default; + + std::shared_ptr get(const std::filesystem::path& dir_path) { + // Hit the runtime cache + if (const auto& iterator = cache.find(dir_path); iterator != cache.end()) + return iterator->second; + + if (KernelRuntime::check_validity(dir_path)) + return cache[dir_path] = std::make_shared(dir_path); + return nullptr; + } +}; + +static auto kernel_runtime_cache = std::make_shared(); + +} // namespace deep_gemm diff --git a/csrc/jit/compiler.hpp b/csrc/jit/compiler.hpp new file mode 100644 index 000000000..09c308739 --- /dev/null +++ b/csrc/jit/compiler.hpp @@ -0,0 +1,277 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "../utils/exception.hpp" +#include "../utils/format.hpp" +#include "../utils/hash.hpp" +#include "../utils/lazy_init.hpp" +#include "../utils/system.hpp" +#include "cache.hpp" +#include "device_runtime.hpp" + +namespace deep_gemm { + +class Compiler { +public: + static std::filesystem::path library_root_path; + static std::filesystem::path library_include_path; + static std::filesystem::path cuda_home; + static std::string library_version; + + static std::string get_library_version() { + std::stringstream ss; + for (const auto& f: collect_files(library_include_path / "deep_gemm")) { + std::ifstream in(f, std::ios::binary); + ss << in.rdbuf(); + } + return get_hex_digest(ss.str()); + } + + static void prepare_init(const std::string& library_root_path, + const std::string& cuda_home_path_by_python) { + Compiler::library_root_path = library_root_path; + Compiler::library_include_path = Compiler::library_root_path / "include"; + Compiler::cuda_home = cuda_home_path_by_python; + Compiler::library_version = get_library_version(); + } + + std::string signature, flags; + std::filesystem::path cache_dir_path; + + Compiler() { + // Check `prepare_init` + DG_HOST_ASSERT(not library_root_path.empty()); + DG_HOST_ASSERT(not library_include_path.empty()); + DG_HOST_ASSERT(not cuda_home.empty()); + DG_HOST_ASSERT(not library_version.empty()); + + // Cache settings + cache_dir_path = std::filesystem::path(get_env("HOME")) / ".deep_gemm"; + if (const auto& env_cache_dir_path = get_env("DG_JIT_CACHE_DIR"); not env_cache_dir_path.empty()) + cache_dir_path = env_cache_dir_path; + + // The compiler flags applied to all derived compilers + signature = "unknown-compiler"; + flags = fmt::format("-std=c++{} --diag-suppress=39,161,174,177,186,940 " + "--ptxas-options=--register-usage-level=10", + get_env("DG_JIT_CPP_STANDARD", 20)); + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0)) + flags += " --ptxas-options=--verbose"; + if (get_env("DG_JIT_WITH_LINEINFO", 0)) + flags += " -Xcompiler -rdynamic -lineinfo"; + } + + virtual ~Compiler() = default; + + std::filesystem::path make_tmp_dir() const { + return make_dirs(cache_dir_path / "tmp"); + } + + std::filesystem::path get_tmp_file_path() const { + return make_tmp_dir() / get_uuid(); + } + + void put(const std::filesystem::path& path, const std::string& data) const { + const auto tmp_file_path = get_tmp_file_path(); + + // Write into the temporary file + std::ofstream out(tmp_file_path, std::ios::binary); + DG_HOST_ASSERT(out.write(data.data(), data.size())); + out.close(); + + // Atomically replace + std::filesystem::rename(tmp_file_path, path); + } + + std::shared_ptr build(const std::string& name, const std::string& code) const { + const auto kernel_signature = fmt::format("{}$${}$${}$${}$${}", name, library_version, signature, flags, code); + const auto dir_path = cache_dir_path / "cache" / fmt::format("kernel.{}.{}", name, get_hex_digest(kernel_signature)); + + // Hit the runtime cache + if (const auto& runtime = kernel_runtime_cache->get(dir_path); runtime != nullptr) + return runtime; + + // Create the kernel directory + make_dirs(dir_path); + + // Compile into a temporary CUBIN + const auto tmp_cubin_path = get_tmp_file_path(); + compile(code, dir_path, tmp_cubin_path); + + // Replace into the cache directory + make_dirs(dir_path); + std::filesystem::rename(tmp_cubin_path, dir_path / "kernel.cubin"); + + // Put into the runtime cache + const auto& runtime = kernel_runtime_cache->get(dir_path); + DG_HOST_ASSERT(runtime != nullptr); + return runtime; + } + + virtual void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const = 0; +}; + +DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_root_path); +DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_include_path); +DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuda_home); +DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_version); + +class NVCCCompiler final: public Compiler { + std::filesystem::path nvcc_path; + + std::pair get_nvcc_version() const { + DG_HOST_ASSERT(std::filesystem::exists(nvcc_path)); + + // Call the version command + const auto& command = std::string(nvcc_path) + " --version"; + const auto& [return_code, output] = call_external_command(command); + DG_HOST_ASSERT(return_code == 0); + + // The version should be at least 12.3, for the best performance with 12.9 + int major, minor; + std::smatch match; + DG_HOST_ASSERT(std::regex_search(output, match, std::regex(R"(release (\d+\.\d+))"))); + std::sscanf(match[1].str().c_str(), "%d.%d", &major, &minor); + DG_HOST_ASSERT((major > 12 or (major == 12 and minor >= 3)) and "NVCC version should be >= 12.3"); + if (major < 12 or (major == 12 and minor < 9)) + printf("Warning: please use at least NVCC 12.9 for the best DeepGEMM performance"); + return {major, minor}; + } + +public: + NVCCCompiler() { + // Override the compiler signature + nvcc_path = cuda_home / "bin" / "nvcc"; + if (const auto& env_nvcc_path = get_env("DG_JIT_NVCC_COMPILER"); not env_nvcc_path.empty()) + nvcc_path = env_nvcc_path; + const auto& [nvcc_major, nvcc_minor] = get_nvcc_version(); + signature = fmt::format("NVCC{}.{}", nvcc_major, nvcc_minor); + + // The override the compiler flags + flags = fmt::format("{} -I{} --gpu-architecture=sm_{} " + "--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi " + "-cubin -O3 --expt-relaxed-constexpr --expt-extended-lambda", + flags, library_include_path.c_str(), device_runtime->get_arch()); + } + + void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override { + // Write the code into the cache directory + const auto& code_path = dir_path / "kernel.cu"; + put(code_path, code); + + // Compile + const auto& command = fmt::format("{} {} -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags); + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) + printf("Running NVCC command: %s\n", command.c_str()); + const auto& [return_code, output] = call_external_command(command); + if (return_code != 0) { + printf("NVCC compilation failed: %s\n", output.c_str()); + DG_HOST_ASSERT(false and "NVCC compilation failed"); + } + + // Print PTXAS log + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0)) + printf("%s", output.c_str()); + } +}; + +class NVRTCCompiler final: public Compiler { +public: + NVRTCCompiler() { + // Override the compiler signature + int major, minor; + DG_NVRTC_CHECK(nvrtcVersion(&major, &minor)); + signature = fmt::format("NVRTC{}.{}", major, minor); + + // Build include directories list + std::string include_dirs; + include_dirs += fmt::format("-I{} ", library_include_path.string()); + include_dirs += fmt::format("-I{} ", (cuda_home / "include").string()); + + // Add PCH support for version 12.8 and above + // NOTES: PCH is vital for compilation speed + std::string pch_flags; + if (major > 12 or (major == 12 and minor >= 8)) { + pch_flags = "--pch "; + if (get_env("DG_JIT_DEBUG", 0)) + pch_flags += "--pch-verbose=true "; + } + + // Override the compiler flags + flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {}", + flags, include_dirs, device_runtime->get_arch(), pch_flags); + } + + void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override { + // Write the code into the cache directory + const auto& code_path = dir_path / "kernel.cu"; + put(code_path, code); + + // Parse compilation options + std::istringstream iss(flags); + std::vector options; + std::string option; + while (iss >> option) + options.push_back(option); + + // Convert to C-style string array for NVRTC + std::vector option_cstrs; + for (const auto& opt: options) + option_cstrs.push_back(opt.c_str()); + + // Print compiler command if requested + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) { + printf("Compiling JIT runtime with NVRTC options: "); + for (const auto& opt: options) + printf("%s ", opt.c_str()); + printf("\n"); + } + + // Create NVRTC program and compile + nvrtcProgram program; + DG_NVRTC_CHECK(nvrtcCreateProgram(&program, code.c_str(), "kernel.cu", 0, nullptr, nullptr)); + const auto& compile_result = nvrtcCompileProgram(program, static_cast(option_cstrs.size()), option_cstrs.data()); + + // Get and print compiler log + size_t log_size; + DG_NVRTC_CHECK(nvrtcGetProgramLogSize(program, &log_size)); + if (get_env("DG_JIT_DEBUG", 0) or compile_result != NVRTC_SUCCESS) { + if (compile_result != NVRTC_SUCCESS) + DG_HOST_ASSERT(log_size > 1); + if (log_size > 1) { + std::string compilation_log(log_size, '\0'); + DG_NVRTC_CHECK(nvrtcGetProgramLog(program, compilation_log.data())); + printf("NVRTC log: %s\n", compilation_log.c_str()); + } + } + + // Get CUBIN size and data + size_t cubin_size; + DG_NVRTC_CHECK(nvrtcGetCUBINSize(program, &cubin_size)); + std::string cubin_data(cubin_size, '\0'); + DG_NVRTC_CHECK(nvrtcGetCUBIN(program, cubin_data.data())); + + // Write into the file system + put(cubin_path, cubin_data); + + // Cleanup + DG_NVRTC_CHECK(nvrtcDestroyProgram(&program)); + } +}; + +static auto compiler = LazyInit([]() -> std::shared_ptr { + if (get_env("DG_JIT_USE_NVRTC", 0)) { + return std::make_shared(); + } else { + return std::make_shared(); + } +}); + +} // namespace deep_gemm diff --git a/csrc/jit/device_runtime.hpp b/csrc/jit/device_runtime.hpp new file mode 100644 index 000000000..310942d9b --- /dev/null +++ b/csrc/jit/device_runtime.hpp @@ -0,0 +1,62 @@ +#pragma once + +#include + +#include "../utils/exception.hpp" +#include "../utils/lazy_init.hpp" + +namespace deep_gemm { + +class DeviceRuntime { + int num_sms = 0, tc_util = 0; + std::shared_ptr cached_prop; + +public: + explicit DeviceRuntime() = default; + + std::shared_ptr get_prop() { + if (cached_prop == nullptr) + cached_prop = std::make_shared(*at::cuda::getCurrentDeviceProperties()); + return cached_prop; + } + + std::pair get_arch_pair() { + const auto prop = get_prop(); + return {prop->major, prop->minor}; + } + + std::string get_arch() { + const auto& [major, minor] = get_arch_pair(); + if (major == 10 and minor != 1) + return "100f"; + return std::to_string(major * 10 + minor) + "a"; + } + + int get_arch_major() { + return get_arch_pair().first; + } + + void set_num_sms(const int& new_num_sms) { + DG_HOST_ASSERT(0 <= new_num_sms and new_num_sms <= get_prop()->multiProcessorCount); + num_sms = new_num_sms; + } + + int get_num_sms() { + if (num_sms == 0) + num_sms = get_prop()->multiProcessorCount; + return num_sms; + } + + void set_tc_util(const int& new_tc_util) { + DG_HOST_ASSERT(0 <= new_tc_util and new_tc_util <= 100); + tc_util = new_tc_util; + } + + int get_tc_util() const { + return tc_util == 0 ? 100 : tc_util; + } +}; + +static auto device_runtime = LazyInit([](){ return std::make_shared(); }); + +} // namespace deep_gemm diff --git a/csrc/jit/handle.hpp b/csrc/jit/handle.hpp new file mode 100644 index 000000000..1875d54ca --- /dev/null +++ b/csrc/jit/handle.hpp @@ -0,0 +1,135 @@ +#pragma once + +#include +#include +#include + +#include "../utils/exception.hpp" + +namespace deep_gemm { + +#if CUDART_VERSION >= 12080 and not defined(DG_JIT_USE_DRIVER_API) + +// Use CUDA runtime API +using LibraryHandle = cudaLibrary_t; +using KernelHandle = cudaKernel_t; +using LaunchConfigHandle = cudaLaunchConfig_t; +using LaunchAttrHandle = cudaLaunchAttribute; + +#define DG_CUDA_UNIFIED_CHECK DG_CUDA_RUNTIME_CHECK + +static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const std::string& func_name, + LibraryHandle *library_opt = nullptr) { + LibraryHandle library; + KernelHandle kernel{}; + DG_CUDA_RUNTIME_CHECK(cudaLibraryLoadFromFile(&library, cubin_path.c_str(), nullptr, nullptr, 0, nullptr, nullptr, 0)); + DG_CUDA_RUNTIME_CHECK(cudaLibraryGetKernel(&kernel, library, func_name.c_str())); + + if (library_opt != nullptr) + *library_opt = library; + return kernel; +} + +static void unload_library(const LibraryHandle& library) { + const auto& error = cudaLibraryUnload(library); + DG_HOST_ASSERT(error == cudaSuccess or error == cudaErrorCudartUnloading); +} + +static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, + const cudaStream_t& stream, const int& smem_size, + const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) { + if (smem_size > 0) + DG_CUDA_RUNTIME_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + LaunchConfigHandle config; + config.gridDim = grid_dim; + config.blockDim = block_dim; + config.dynamicSmemBytes = smem_size; + config.stream = stream; + config.numAttrs = 0; + config.attrs = nullptr; + + // NOTES: must use `static` or the `attr` will be deconstructed + static LaunchAttrHandle attr; + if (cluster_dim > 1) { + attr.id = cudaLaunchAttributeClusterDimension; + attr.val.clusterDim = {static_cast(cluster_dim), 1, 1}; + config.attrs = &attr; + config.numAttrs = 1; + } + return config; +} + +template +static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle& config, ActTypes&&... args) { + void *ptr_args[] = { &args... }; + return cudaLaunchKernelExC(&config, kernel, ptr_args); +} + +#else + +// Use CUDA driver API +using LibraryHandle = CUmodule; +using KernelHandle = CUfunction; +using LaunchConfigHandle = CUlaunchConfig; +using LaunchAttrHandle = CUlaunchAttribute; + +#define DG_CUDA_UNIFIED_CHECK DG_CUDA_DRIVER_CHECK + +static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const std::string& func_name, + LibraryHandle *library_opt = nullptr) { + LibraryHandle library; + KernelHandle kernel; + DG_CUDA_DRIVER_CHECK(cuModuleLoad(&library, cubin_path.c_str())); + DG_CUDA_DRIVER_CHECK(cuModuleGetFunction(&kernel, library, func_name.c_str())); + + if (library_opt != nullptr) + *library_opt = library; + return kernel; +} + +static void unload_library(const LibraryHandle& library) { + const auto& error = cuModuleUnload(library); + DG_HOST_ASSERT(error == CUDA_SUCCESS or error == CUDA_ERROR_DEINITIALIZED); +} + +static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, + const cudaStream_t& stream, const int& smem_size, + const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) { + if (smem_size > 0) + DG_CUDA_DRIVER_CHECK(cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size)); + + LaunchConfigHandle config; + config.gridDimX = grid_dim.x; + config.gridDimY = grid_dim.y; + config.gridDimZ = grid_dim.z; + config.blockDimX = block_dim.x; + config.blockDimY = block_dim.y; + config.blockDimZ = block_dim.z; + config.sharedMemBytes = smem_size; + config.hStream = stream; + config.numAttrs = 0; + config.attrs = nullptr; + + // NOTES: must use `static` or the `attr` will be deconstructed + static LaunchAttrHandle attr; + if (cluster_dim > 1) { + attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + attr.value.clusterDim.x = cluster_dim; + attr.value.clusterDim.y = 1; + attr.value.clusterDim.z = 1; + config.attrs = &attr; + config.numAttrs = 1; + } + return config; +} + +template +static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle& config, ActTypes&&... args) { + void *ptr_args[] = { &args... }; + return cuLaunchKernelEx(&config, kernel, ptr_args, nullptr); +} + +#endif + +} // namespace deep_gemm diff --git a/csrc/jit/kernel_runtime.hpp b/csrc/jit/kernel_runtime.hpp new file mode 100644 index 000000000..42b7b4cb5 --- /dev/null +++ b/csrc/jit/kernel_runtime.hpp @@ -0,0 +1,117 @@ +#pragma once + +#include "../utils/exception.hpp" +#include "../utils/format.hpp" +#include "../utils/system.hpp" +#include "device_runtime.hpp" +#include "handle.hpp" + +namespace deep_gemm { + +struct LaunchArgs { + std::pair grid_dim; + int num_threads; + int smem_size; + int cluster_dim; + + LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1): + grid_dim({grid_dim_x, 1}), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {} + + LaunchArgs(const std::pair& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1): + grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {} +}; + +class KernelRuntime final { +public: + static std::filesystem::path cuda_home; + + LibraryHandle library; + KernelHandle kernel; + + explicit KernelRuntime(const std::filesystem::path& dir_path) { + // Check `prepare_init` + DG_HOST_ASSERT(not cuda_home.empty()); + + // NOLINT(*-pro-type-member-init) + const auto& cuobjdump_path = cuda_home / "bin" / "cuobjdump"; + const auto& cubin_path = dir_path / "kernel.cubin"; + if (get_env("DG_JIT_DEBUG")) + printf("Loading CUBIN: %s\n", cubin_path.c_str()); + + // Find the only symbol + // TODO: use kernel enumeration for newer drivers + const std::vector illegal_names = {"vprintf", "__instantiate_kernel", "__internal", "__assertfail"}; + const auto& [exit_code, symbols] = call_external_command(fmt::format("{} -symbols {}", cuobjdump_path.c_str(), cubin_path.c_str())); + DG_HOST_ASSERT(exit_code == 0); + std::istringstream iss(symbols); + std::vector symbol_names; + for (std::string line; std::getline(iss, line); ) { + if (line.find("STT_FUNC") == 0 and std::none_of(illegal_names.begin(), illegal_names.end(), + [&](const auto& name) { return line.find(name) != std::string::npos; })) { + const auto& last_space = line.rfind(' '); + symbol_names.push_back(line.substr(last_space + 1)); + } + } + if (get_env("DG_JIT_DEBUG")) { + printf("Symbol names: "); + for (const auto& symbol: symbol_names) + printf("%s, ", symbol.c_str()); + printf("\n"); + } + + // Load from the library + DG_HOST_ASSERT(symbol_names.size() == 1); + kernel = load_kernel(cubin_path, symbol_names[0], &library); + } + + static void prepare_init(const std::string& cuda_home_path_by_python) { + cuda_home = cuda_home_path_by_python; + } + + static bool check_validity(const std::filesystem::path& dir_path) { + return std::filesystem::exists(dir_path / "kernel.cu") and + std::filesystem::exists(dir_path / "kernel.cubin"); + } + + ~KernelRuntime() noexcept(false) { + unload_library(library); + } +}; + +DG_DECLARE_STATIC_VAR_IN_CLASS(KernelRuntime, cuda_home); + +template +class LaunchRuntime { +public: + template + static std::string generate(const Args& args) { + const auto& code = Derived::generate_impl(args); + if (get_env("DG_JIT_DEBUG", 0)) + printf("Generated kernel code: %s\n", code.c_str()); + return code; + } + + template + static void launch(const std::shared_ptr& kernel_runtime, const Args& args) { + const auto& kernel = kernel_runtime->kernel; + const auto& stream = at::cuda::getCurrentCUDAStream(); + const LaunchArgs& launch_args = args.launch_args; + + const dim3& grid_dim = {static_cast(launch_args.grid_dim.first), + static_cast(launch_args.grid_dim.second), + 1}; + const dim3& block_dim = {static_cast(launch_args.num_threads), 1, 1}; + auto config = construct_launch_config(kernel, stream, launch_args.smem_size, + grid_dim, block_dim, launch_args.cluster_dim); + + // Launch in the derived class + if (get_env("DG_JIT_DEBUG")) { + printf("Launch kernel with {%d, %d} x %d, shared memory: %d bytes, cluster: %d, stream: %ld\n", + launch_args.grid_dim.first, launch_args.grid_dim.second, launch_args.num_threads, + launch_args.smem_size, launch_args.cluster_dim, stream.id()); + } + Derived::launch_impl(kernel, config, args); + } +}; + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp new file mode 100644 index 000000000..ecf237367 --- /dev/null +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -0,0 +1,307 @@ +#pragma once + +#include "../../utils/math.hpp" +#include "../../utils/layout.hpp" + +namespace deep_gemm { + +struct MulticastConfig { + int num_multicast; + bool is_multicast_on_a; + + MulticastConfig(const int& num_multicast, const bool& is_multicast_on_a): + num_multicast(num_multicast), is_multicast_on_a(is_multicast_on_a) { + DG_HOST_ASSERT(1 <= num_multicast and num_multicast <= 2); + } +}; + +struct SharedMemoryConfig { + int smem_size; + int swizzle_a_mode; + int swizzle_b_mode; + int swizzle_cd_mode; +}; + +struct ThreadConfig { + int num_threads; + + // SM90 + int num_tma_threads; + int num_math_threads; + + // SM100 + int num_non_epilogue_threads; + int num_epilogue_threads; + + static ThreadConfig sm90(const int& num_tma_threads, + const int& num_math_threads) { + auto config = ThreadConfig(); + config.num_threads = num_tma_threads + num_math_threads; + config.num_tma_threads = num_tma_threads; + config.num_math_threads = num_math_threads; + return config; + } + + static ThreadConfig sm100(const int& num_non_epilogue_threads, + const int& num_epilogue_threads) { + auto config = ThreadConfig(); + config.num_threads = num_non_epilogue_threads + num_epilogue_threads; + config.num_non_epilogue_threads = num_non_epilogue_threads; + config.num_epilogue_threads = num_epilogue_threads; + return config; + } +}; + +struct GemmConfig { + // Templated configs + GemmType gemm_type; + KernelType kernel_type; + at::ScalarType ab_dtype, cd_dtype; + cute::UMMA::Major major_a; + cute::UMMA::Major major_b; + bool with_accumulation; + int block_m, block_n, block_k; + int num_stages, num_last_stages; + + // Templated device configs + int num_sms; + int tc_util; + + // Structured configs + MulticastConfig multicast_config; + SharedMemoryConfig smem_config; + ThreadConfig thread_config; +}; + +static bool is_multicast_legal(const int& shape_dim, const int& block_dim, + const int& num_multicast, const int& num_sms, + const bool& require_divisible) { + const bool& divisible = ceil_div(shape_dim, block_dim) % num_multicast == 0 or not require_divisible; + return divisible and num_sms % num_multicast == 0; +} + +static int get_swizzle_mode(const int& block_size, const int& elem_size) { + // `> 0` means interleaving + // 16B actually means non-swizzling (but interleaving) + for (const int& mode: {128, 64, 32, 16}) { + if ((block_size * elem_size) % mode == 0) + return mode; + } + DG_HOST_UNREACHABLE("Unreachable"); +} + +template +static SharedMemoryConfig get_smem_config(const KernelType& kernel_type, + const int& m, const int& n, const int& k, + const int& block_m, const int& block_n, const int& block_k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + const int& num_stages, const MulticastConfig& multicast_config, bool is_per_tensor = false) { + const int& ab_elem_size = static_cast(c10::elementSize(ab_dtype)); + const int& cd_elem_size = static_cast(c10::elementSize(cd_dtype)); + + const int& load_block_m = ArchSpec::get_ab_load_block_m(multicast_config, block_m); + const int& load_block_n = ArchSpec::get_ab_load_block_n(multicast_config, block_n); + const int& swizzle_a_mode = get_swizzle_mode(major_a == cute::UMMA::Major::K ? block_k : load_block_m, ab_elem_size); + const int& swizzle_b_mode = get_swizzle_mode(major_b == cute::UMMA::Major::K ? block_k : load_block_n, ab_elem_size); + const int& swizzle_cd_mode = get_swizzle_mode(block_n, cd_elem_size); + + // Different archs have different epilogue pipelines + const int& smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, swizzle_cd_mode, cd_dtype); + + // A/B shared memory + const int& smem_a_per_stage = load_block_m * block_k * ab_elem_size; + const int& smem_b_per_stage = load_block_n * block_k * ab_elem_size; + + // SF shared memory + const auto& [smem_sfa_per_stage, smem_sfb_per_stage] = + ArchSpec::get_sf_smem_size_per_stage(kernel_type, block_m, block_n, block_k, ab_dtype, cd_dtype, is_per_tensor); + const int& smem_extra_sfb = ArchSpec::get_extra_sfb_smem_size(m, n, k, block_m, block_n, block_k, is_per_tensor); + + // M-barriers and tensor memory pointers + const int& smem_barrier = ArchSpec::get_barrier_smem_size(num_stages); + const int& smem_tmem_ptr = ArchSpec::get_tmem_ptr_smem_size(); + + // Sum them up + int smem_size = 0; + smem_size += smem_cd; + smem_size += num_stages * smem_a_per_stage; + smem_size += num_stages * smem_b_per_stage; + smem_size += num_stages * smem_sfa_per_stage; + smem_size += num_stages * smem_sfb_per_stage; + smem_size += smem_extra_sfb; + smem_size += smem_barrier; + smem_size += smem_tmem_ptr; + + return SharedMemoryConfig { + .smem_size = smem_size, + .swizzle_a_mode = swizzle_a_mode, + .swizzle_b_mode = swizzle_b_mode, + .swizzle_cd_mode = swizzle_cd_mode, + }; +} + +template +static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& kernel_type, + const int& m, const int& n, const int& k, const int& num_groups, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + const bool& with_accumulation, const int& num_sms, bool is_per_tensor = false) { + DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn or ab_dtype == torch::kBFloat16); + DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat); + + // Select M/N block sizes + // TODO: support `% 16 == 8` block size on SM90 + const auto& block_ms = gemm_type == GemmType::MGroupedContiguous ? + std::vector{get_mk_alignment_for_contiguous_layout()} : std::vector{64, 128, 256}; + std::vector block_ns; + for (int i = 16; i <= 256; i += 16) + block_ns.push_back(i); + + // K block size is selected in a fixed manner + const auto& block_k = 128 / static_cast(c10::elementSize(ab_dtype)); + + // Some util functions + const auto& get_num_blocks = [=](const int& block_m, const int& block_n) { + return ceil_div(m, block_m) * ceil_div(n, block_n) * num_groups; + }; + const auto& get_num_waves = [=](const int& block_m, const int& block_n) { + return ceil_div(get_num_blocks(block_m, block_n), num_sms); + }; + const auto& get_last_wave_util = [=](const int& block_m, const int& block_n) { + const auto& num_last_blocks = get_num_blocks(block_m, block_n) % num_sms; + return num_last_blocks == 0 ? num_sms : num_last_blocks; + }; + + // Decide block sizes by waves + int best_block_m = 0, best_block_n = 0; + int best_num_waves = 0, best_last_util = 0; + for (const auto& block_m: block_ms) { + for (const auto& block_n: block_ns) { + const int& num_waves = get_num_waves(block_m, block_n); + const auto& last_util = get_last_wave_util(block_m, block_n); + if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, ab_dtype, cd_dtype, block_m, block_n, block_k)) + continue; + + bool success = false; + if (best_block_m == 0 or best_block_n == 0 or num_waves < best_num_waves) { + success = true; + } else if (num_waves == best_num_waves) { + // Check last wave utilization + success = last_util > best_last_util; + if (last_util == best_last_util) { + // Case 1: same `block_m`, smaller `block_n` (wasted) + success |= block_m == best_block_m and block_n < best_block_n; + // Case 2: same `block_n`, smaller `block_m` (wasted) + success |= block_n == best_block_n and block_m < best_block_m; + // Case 3: different for both `block_m` and `block_n`, larger `block_n` is better + // NOTES: don't pick `block_m/block_n` larger than shape `m/n` in this case + success |= block_m != best_block_m and block_n > best_block_n + and block_n <= n and block_m <= m; + } + } + + // Replace with the new config if successful + if (success) { + best_block_m = block_m, best_block_n = block_n; + best_num_waves = num_waves, best_last_util = last_util; + } + } + } + DG_HOST_ASSERT(best_block_m > 0 and best_block_n > 0); + + // Decide the number of TMA multicasts and whether broadcast on A + MulticastConfig best_multicast_config = {1, true}; + const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality( + gemm_type, m, n, best_block_m, best_block_n, num_sms); + const bool is_legal[2] = {is_legal_on_a, is_legal_on_b}; + bool order[2] = {false, true}; + if (best_block_m > best_block_n) + std::swap(order[0], order[1]); + for (const bool& is_multicast_on_a: order) { + if (m >= 512 and is_legal[static_cast(is_multicast_on_a)]) { + best_multicast_config = {2, is_multicast_on_a}; + break; + } + } + + // Always pick the largest number of stage + constexpr int smem_capacity = ArchSpec::smem_capacity; + int best_num_stages = 0; + SharedMemoryConfig best_smem_config; + for (int num_stages = std::min(12, ceil_div(k, block_k)); num_stages > 0; -- num_stages) { + if (not ArchSpec::is_num_stages_legal(ab_dtype, cd_dtype, num_stages, best_block_m, best_block_n, block_k)) + continue; + + best_smem_config = get_smem_config(kernel_type, + m, n, k, + best_block_m, best_block_n, block_k, + major_a, major_b, + ab_dtype, cd_dtype, + num_stages, best_multicast_config, is_per_tensor); + if (best_smem_config.smem_size <= smem_capacity) { + best_num_stages = num_stages; + break; + } + } + DG_HOST_ASSERT(best_num_stages != 0); + + // Recompute the minimal number of SMs required + // NOTES: less L2 cache usage and less GPU frequency drop + int num_min_sms = num_sms; + if (ArchSpec::should_minimize_num_sms()) { + num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, best_num_waves); + num_min_sms = align(num_min_sms, best_multicast_config.num_multicast); + DG_HOST_ASSERT(num_min_sms <= num_sms); + } + + const auto& config = GemmConfig { + .gemm_type = gemm_type, + .kernel_type = kernel_type, + .ab_dtype = ab_dtype, + .cd_dtype = cd_dtype, + .major_a = major_a, + .major_b = major_b, + .with_accumulation = with_accumulation, + .block_m = best_block_m, + .block_n = best_block_n, + .block_k = block_k, + .num_stages = best_num_stages, + .num_last_stages = ceil_div(k, block_k) % best_num_stages, + .num_sms = num_min_sms, + .tc_util = device_runtime->get_tc_util(), + .multicast_config = best_multicast_config, + // ReSharper disable once CppLocalVariableMightNotBeInitialized + .smem_config = best_smem_config, + .thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n) + }; + + // Only SM100 BF16 kernels support tensor core control + if (config.tc_util < 100) + DG_HOST_ASSERT(device_runtime->get_arch_major() == 10 and ab_dtype == torch::kBFloat16); + + // Print configs for the first time + if (get_env("DG_JIT_DEBUG") or get_env("DG_PRINT_CONFIGS")) { + auto key = std::make_tuple(gemm_type, kernel_type, m, n, k, num_groups, major_a, major_b, + ab_dtype, cd_dtype, with_accumulation, num_sms); + static std::set printed; + if (printed.count(key) == 0) { + printf("GEMM type: %d, kernel type: %d, M: %d, N: %d, K: %d, groups: %d, " + "A major: %d, B major: %d, AB dtype: %s, CD dtype: %s, accumulation: %d, " + "SM limit: %d -> block M: %d, block N: %d, block K: %d, stages: %d, last stages: %d, " + "SMs: %d, multicast: %d, multicast on A: %d, shared memory: %d bytes, swizzle A: %d, " + "swizzle B: %d, swizzle CD: %d, SMs: %d, threads: %d, TC util: %d%%\n", + static_cast(gemm_type), static_cast(kernel_type), m, n, k, num_groups, + static_cast(major_a), static_cast(major_b), c10::toString(ab_dtype), c10::toString(cd_dtype), + static_cast(with_accumulation), num_sms, best_block_m, best_block_n, block_k, + best_num_stages, config.num_last_stages, num_min_sms, best_multicast_config.num_multicast, + static_cast(best_multicast_config.is_multicast_on_a), + best_smem_config.smem_size, best_smem_config.swizzle_a_mode, best_smem_config.swizzle_b_mode, + best_smem_config.swizzle_cd_mode, config.num_sms, config.thread_config.num_threads, config.tc_util); + printed.insert(key); + } + } + return config; +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/heuristics/sm100.hpp b/csrc/jit_kernels/heuristics/sm100.hpp new file mode 100644 index 000000000..064ffa1aa --- /dev/null +++ b/csrc/jit_kernels/heuristics/sm100.hpp @@ -0,0 +1,151 @@ +#pragma once + +#include +// Reuse some types in the JIT modules +#include + +#include "common.hpp" +#include "../../utils/exception.hpp" + +namespace deep_gemm { + +struct SM100ArchSpec { + static constexpr int smem_capacity = 232448; + + static int get_ab_load_block_m(const MulticastConfig& config, const int& block_m) { + return block_m / (config.is_multicast_on_a ? config.num_multicast : 1); + } + + static int get_ab_load_block_n(const MulticastConfig& config, const int& block_n) { + return block_n / (config.is_multicast_on_a ? 1 : config.num_multicast); + } + + static int get_cd_store_block_m(const int& block_m) { + constexpr int layout_ad_m = 128; + return std::min(block_m, layout_ad_m); + } + + static int get_cd_store_block_n(const int& block_n) { + return block_n; + } + + static std::pair get_sf_uttcp_aligned_block_sizes( + const int& block_m, const int& block_n, const at::ScalarType& ab_dtype) { + constexpr int num_utccp_aligned_elems = 128; + DG_HOST_ASSERT(block_m % num_utccp_aligned_elems == 0); + switch (ab_dtype) { + case torch::kBFloat16: return {0, 0}; + case torch::kFloat8_e4m3fn: return {align(block_m, num_utccp_aligned_elems), align(block_n, num_utccp_aligned_elems)}; + default: DG_HOST_UNREACHABLE("Unknown dtype"); + } + } + + static bool is_block_size_legal(const KernelType& kernel_type, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + const int& block_m, const int& block_n, const int& block_k) { + // TODO: consider more carefully for BF16 GEMMs + // 2SM BF16 UMMA does not support `N % 32 != 0` + if (ab_dtype == torch::kBFloat16 and block_n % 32 != 0) + return false; + + // Layout A/D does not support `block_m == 64` and `block_n % 16 != 0` + if (block_m == 64 or block_n % 16 != 0) + return false; + + // Performance is lower with 1D1D and `block_m == 256` + if (kernel_type == KernelType::Kernel1D1D and major_b == cute::UMMA::Major::K and block_m != 128) + return false; + + // 1D2D kernels' maximum block N is 128 + // 1D2D kernels require more friendly block Ns + if (kernel_type == KernelType::Kernel1D2D and (block_n > 128 or 128 % block_n != 0)) + return false; + + // Check tensor memory validity + int sf_block_m = 0, sf_block_n = 0; + if (kernel_type == KernelType::Kernel1D1D) { + const auto& [sf_block_m_, sf_block_n_] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, ab_dtype); + sf_block_m = sf_block_m_, sf_block_n = sf_block_n_; + } + if (((2 * block_n) + (sf_block_m / 32) + (sf_block_n / 32)) > 512) + return false; + + // NOTES: when B is MN-major, we restrict `block_n` to multiples of 64, + // since TMA performance degrades when `swizzle_b <= 32B` (i.e., when `block_ns % 64 != 0`), even with 3D TMA + return major_b == cute::UMMA::Major::K or (block_n * c10::elementSize(ab_dtype)) % 64 == 0; + } + + static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + const int& num_stages, + const int& block_m, const int& block_n, const int& block_k) { + return true; + } + + static bool should_minimize_num_sms() { + return false; + } + + static std::pair get_multicast_legality(const GemmType& gemm_type, + const int& m, const int& n, const int& block_m, const int& block_n, + const int& num_sms) { + // TODO: support other layouts + return { + is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous), + false, + }; + } + + static ThreadConfig get_thread_config(const KernelType& kernel_type, + const int& block_m, const int& block_n) { + return ThreadConfig::sm100(128, kernel_type == KernelType::Kernel1D2D ? block_m : 128); + } + + static int get_smem_cd_size(const KernelType& kernel_type, + const int& block_m, const int& block_n, + const int& swizzle_cd_mode, + const at::ScalarType& cd_dtype) { + constexpr static int layout_ad_m = 128; + return (kernel_type != KernelType::Kernel1D2D ? std::min(block_m, layout_ad_m) : block_m) * swizzle_cd_mode * 2; + } + + static std::pair get_sf_smem_size_per_stage(const KernelType& kernel_type, + const int& block_m, const int& block_n, const int& block_k, + const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + bool is_per_tensor = false) { + if (ab_dtype == torch::kBFloat16) + return {0, 0}; + + int smem_sfa_per_stage = 0; + int smem_sfb_per_stage = 0; + if (kernel_type == KernelType::Kernel1D1D) { + const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, ab_dtype); + smem_sfa_per_stage = sf_block_m * 4; + smem_sfb_per_stage = sf_block_n * 4; + } else { + smem_sfa_per_stage = block_m * 4; + smem_sfb_per_stage = 0; + } + return {smem_sfa_per_stage, smem_sfb_per_stage}; + } + + static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k, + const int& block_m, const int& block_n, const int& block_k, + bool is_per_tensor = false) { + return 0; + } + + static int get_barrier_smem_size(const int& num_stages) { + // TODO: remove SF barriers for BF16 GEMMs + // TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers + // NOTES: 1D2D kernel will not use the with-SF full barriers + // NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages + return num_stages * 8 * 3 + 2 * 8 * 2; + } + + static int get_tmem_ptr_smem_size() { + return 4; + } +}; + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/heuristics/sm90.hpp b/csrc/jit_kernels/heuristics/sm90.hpp new file mode 100644 index 000000000..ef3cf012c --- /dev/null +++ b/csrc/jit_kernels/heuristics/sm90.hpp @@ -0,0 +1,121 @@ +#pragma once + +#include +// Reuse some types in the JIT modules +#include + +#include "common.hpp" + +namespace deep_gemm { + +struct SM90ArchSpec { + static constexpr int smem_capacity = 232448; + + static int get_ab_load_block_m(const MulticastConfig& multicast_config, const int& block_m) { + return block_m; + } + + static int get_ab_load_block_n(const MulticastConfig& multicast_config, const int& block_n) { + return block_n; + } + + static int get_cd_store_block_m(const int& block_m) { + return block_m; + } + + static int get_cd_store_block_n(const int& block_n) { + return block_n; + } + + static bool is_block_size_legal(const KernelType& kernel_type, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + const int& block_m, const int& block_n, const int& block_k) { + // FP32 output does not support `block_m == 256` + if (cd_dtype == at::kFloat and block_m == 256) + return false; + + // TODO: more general block N selection + // Must be some fixed block N selections + if (block_n > 128 and kernel_type == KernelType::Kernel1D1D and (block_n != 136 and block_n != 152)) + return false; + + // Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k` + // Or too many register spills + if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192)) + return false; + + // Avoid bank conflicts for FP32 output + if (cd_dtype == torch::kFloat and block_n % 16 == 0) + return false; + + // The block sizes cannot be too large (for enough registers), so at least one dim less than 128 + return block_m <= 128 or block_n <= 128; + } + + static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + const int& num_stages, + const int& block_m, const int& block_n, const int& block_k) { + // Unrolling both stages and `num_former_iters` will cause large code size + if (ab_dtype == torch::kFloat8_e4m3fn and block_k % block_n != 0 and block_k / std::gcd(block_n, block_k) <= 4) + return num_stages <= 4; + return true; + } + + static bool should_minimize_num_sms() { + return true; + } + + static std::pair get_multicast_legality(const GemmType& gemm_type, + const int& m, const int& n, const int& block_m, const int& block_n, + const int& num_sms) { + return { + is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked), + is_multicast_legal(m, block_m, 2, num_sms, false) and gemm_type != GemmType::MGroupedMasked, + }; + } + + static ThreadConfig get_thread_config(const KernelType& kernel_type, + const int& block_m, const int& block_n) { + return ThreadConfig::sm90(128, (block_m == 64 ? 1 : 2) * 128); + } + + static int get_smem_cd_size(const KernelType& kernel_type, + const int& block_m, const int& block_n, + const int& swizzle_cd_mode, const at::ScalarType& cd_dtype) { + return block_m * block_n * static_cast(c10::elementSize(cd_dtype)); + } + + static std::pair get_sf_smem_size_per_stage(const KernelType& kernel_type, + const int& block_m, const int& block_n, const int& block_k, + const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + bool is_per_tensor = false) { + if (ab_dtype == torch::kBFloat16) + return {0, 0}; + + int smem_sfa_per_stage = is_per_tensor ? 0 : block_m * static_cast(sizeof(float)); + int smem_sfb_per_stage = 0; + // TODO: figure out here + if (kernel_type == KernelType::Kernel1D1D) + smem_sfb_per_stage = align(block_n * 4, block_k); + smem_sfb_per_stage = is_per_tensor ? 1 * 4 : smem_sfb_per_stage; + return {smem_sfa_per_stage, smem_sfb_per_stage}; + } + + static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k, + const int& block_m, const int& block_n, const int& block_k, bool is_per_tensor = false) { + const auto& use_uniform_sfb = (block_k % block_n == 0 || is_per_tensor) ? 1 : 2; + return align(ceil_div(k, block_k) * static_cast(sizeof(float)) * use_uniform_sfb, 8); + } + + static int get_barrier_smem_size(const int& num_stages) { + // For 1D1D kernels, there is an extra barrier for accumulation + return (num_stages + 1) * 8 * 2; + } + + static int get_tmem_ptr_smem_size() { + return 0; + } +}; + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/runtime_utils.hpp b/csrc/jit_kernels/impls/runtime_utils.hpp new file mode 100644 index 000000000..ed9c5305f --- /dev/null +++ b/csrc/jit_kernels/impls/runtime_utils.hpp @@ -0,0 +1,173 @@ +#pragma once + +#include +#include + +#include "../../utils/math.hpp" +#include "../../utils/exception.hpp" + +namespace deep_gemm { + +static std::pair get_inner_outer_dims(const cute::UMMA::Major& major, const int& k, const int& mn) { + return major == cute::UMMA::Major::K ? std::make_pair(k, mn) : std::make_pair(mn, k); +} + +static int get_non_contiguous_dim(const cute::UMMA::Major& major) { + return major == cute::UMMA::Major::K ? -2 : -1; +} + +static int get_compiled_dim(const int& dim, const char& name, const std::string& compiled_dims) { + for (const char& c: compiled_dims) { + if (name == c) + return dim; + } + return 0; +} + +static std::string to_string(const cute::UMMA::Major& major) { + switch (major) { + case cute::UMMA::Major::K: return "cute::UMMA::Major::K"; + case cute::UMMA::Major::MN: return "cute::UMMA::Major::MN"; + } + DG_HOST_UNREACHABLE("Unknown major"); +} + +static std::string to_string(const GemmType& type) { + switch (type) { + case GemmType::Normal: return "GemmType::Normal"; + case GemmType::MGroupedContiguous: return "GemmType::MGroupedContiguous"; + case GemmType::MGroupedMasked: return "GemmType::MGroupedMasked"; + case GemmType::KGroupedContiguous: return "GemmType::KGroupedContiguous"; + } + DG_HOST_UNREACHABLE("Unknown GEMM type"); +} + +static std::string to_string(const at::ScalarType& dtype) { + switch (dtype) { + case torch::kInt: return "int"; + case torch::kFloat: return "float"; + case torch::kBFloat16: return "cutlass::bfloat16_t"; + default: DG_HOST_UNREACHABLE("Unsupported dtype"); + } +} + +static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype) { + switch (dtype) { + case torch::kInt: return CU_TENSOR_MAP_DATA_TYPE_INT32; + case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8; + default: DG_HOST_UNREACHABLE("Unsupported dtype"); + } +} + +static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode) { + switch (mode) { + case 0: return CU_TENSOR_MAP_SWIZZLE_NONE; + case 16: return CU_TENSOR_MAP_SWIZZLE_NONE; + case 32: return CU_TENSOR_MAP_SWIZZLE_32B; + case 64: return CU_TENSOR_MAP_SWIZZLE_64B; + case 128: return CU_TENSOR_MAP_SWIZZLE_128B; + default: DG_HOST_UNREACHABLE("Unsupported swizzling mode"); + } +} + +static CUtensorMap make_tma_2d_desc(const torch::Tensor& t, + int gmem_inner_dim, int gmem_outer_dim, + int smem_inner_dim, int smem_outer_dim, + const int& gmem_outer_stride, + const int& swizzle_mode) { + const auto& elem_size = static_cast(t.element_size()); + if (swizzle_mode != 0) + smem_inner_dim = swizzle_mode / elem_size; + + CUtensorMap tensor_map; + const cuuint64_t gmem_dims[2] = {static_cast(gmem_inner_dim), static_cast(gmem_outer_dim)}; + const cuuint32_t smem_dims[2] = {static_cast(smem_inner_dim), static_cast(smem_outer_dim)}; + const cuuint64_t gmem_strides[1] = {static_cast(gmem_outer_stride * elem_size), }; + const cuuint32_t elem_strides[2] = {1, 1}; + if (get_env("DG_JIT_DEBUG")) { + printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d, elem size: %d\n", + gmem_inner_dim, gmem_outer_dim, smem_inner_dim, smem_outer_dim, + gmem_outer_stride, swizzle_mode, elem_size); + } + DG_CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled( + &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type()), + 2, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides, + CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode), + CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + return tensor_map; +} + +static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major, + const torch::Tensor& t, + const int& shape_m, const int& shape_k, + const int& block_m, const int& block_k, + const int& outer_stride, + const int& num_groups, + const int& swizzle_mode) { + if (num_groups > 1) + DG_HOST_ASSERT(major == cute::UMMA::Major::K); + const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_m * num_groups); + const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_m); + return make_tma_2d_desc(t, + gmem_inner_dim, gmem_outer_dim, + smem_inner_dim, smem_outer_dim, + outer_stride, + swizzle_mode); +} + +static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major, + const torch::Tensor& t, + const int& shape_n, const int& shape_k, + const int& block_n, const int& block_k, + const int& outer_stride, + const int& num_groups, + const int& swizzle_mode) { + const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_n); + const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_n); + + // `num_groups` is always applied into the outer dimensions + return make_tma_2d_desc(t, + gmem_inner_dim, gmem_outer_dim * num_groups, + smem_inner_dim, smem_outer_dim, + outer_stride, + swizzle_mode); +} + +static CUtensorMap make_tma_cd_desc(const torch::Tensor& t, + const int& shape_m, const int& shape_n, + const int& block_m, const int& block_n, + const int& outer_stride, + const int& num_groups, + const int& swizzle_mode) { + + // Swizzling requires the inner box dim to be less or equal than `kSwizzleCDMode` + // bytes, so `BLOCK_N * sizeof(T) / kSwizzleCDMode` TMA stores are required + return make_tma_2d_desc(t, + shape_n, shape_m * num_groups, + block_n, block_m, + outer_stride, + swizzle_mode); +} + +static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major, + const torch::Tensor& t, + int shape_mn, int shape_k, + const int& block_mn, const int& block_k, + const int& num_groups, + const int& swizzle_mode) { + DG_HOST_ASSERT(major == cute::UMMA::Major::MN); + + // TODO: maybe swizzle SF as well + DG_HOST_ASSERT(swizzle_mode == 0); + + shape_mn = get_tma_aligned_size(shape_mn, static_cast(t.element_size())); + return make_tma_2d_desc(t, + shape_mn, ceil_div(shape_k, block_k * (t.scalar_type() == torch::kFloat ? 1 : 4)) * num_groups, + block_mn, 1, + shape_mn, + swizzle_mode); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp b/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp new file mode 100644 index 000000000..033a7b753 --- /dev/null +++ b/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp @@ -0,0 +1,143 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm100.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM100BF16GemmRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k, num_groups; + const std::string& compiled_dims; + + GemmConfig gemm_config; + LaunchArgs launch_args; + + void* grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_c; + CUtensorMap tensor_map_d; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_bf16_gemm_impl< + {}, {}, + {}, {}, {}, + {}, {}, {}, + {}, + {}, {}, {}, + {}, {}, + {}, {}, + {}, {}, + {}, + {}, {}, {}, + {} + >); +}}; +)", + to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b), + get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), + args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.num_groups, + args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode, + args.gemm_config.num_stages, args.gemm_config.num_last_stages, + args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads, + args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, + args.gemm_config.num_sms, + to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype), + args.gemm_config.tc_util); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.grouped_layout, args.m, args.n, args.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_c, args.tensor_map_d)); + } +}; + +static void sm100_bf16_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + // TODO: test other Ks + DG_HOST_ASSERT(k % 64 == 0); + const auto& config = get_best_config( + GemmType::Normal, KernelType::KernelNoSF, + m, n, k, 1, major_a, major_b, + torch::kBFloat16, d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + const auto& cd = c.value_or(d); + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_c = make_tma_cd_desc(cd, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(cd.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + + // Duplicate the accumulator if necessary + if (c.has_value()) { + if (c->data_ptr() == d.data_ptr()) { + DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides()); + } else { + // ReSharper disable once CppExpressionWithoutSideEffects + d.copy_(c.value()); + } + } + + // Launch + const SM100BF16GemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = 1, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_c = tensor_map_c, + .tensor_map_d = tensor_map_d + }; + const auto& code = SM100BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm100_bf16_gemm", code); + SM100BF16GemmRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp new file mode 100644 index 000000000..67272d9c7 --- /dev/null +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp @@ -0,0 +1,344 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm100.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM100FP8Gemm1D1DRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k, num_groups; + const std::string& compiled_dims; + + GemmConfig gemm_config; + LaunchArgs launch_args; + + void* grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_sfa; + CUtensorMap tensor_map_sfb; + CUtensorMap tensor_map_c; + CUtensorMap tensor_map_d; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_fp8_gemm_1d1d_impl< + {}, {}, + {}, {}, {}, + {}, {}, {}, + {}, + {}, {}, {}, + {}, {}, + {}, {}, + {}, {}, + {}, + {}, {}, {} + >); +}}; +)", + to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b), + get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), + args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.num_groups, + args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode, + args.gemm_config.num_stages, args.gemm_config.num_last_stages, + args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads, + args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, + args.gemm_config.num_sms, + to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.grouped_layout, args.m, args.n, args.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_sfa, args.tensor_map_sfb, + args.tensor_map_c, args.tensor_map_d)); + } +}; + +static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& aligned_k = align(k, 128); + const auto& config = get_best_config( + GemmType::Normal, KernelType::Kernel1D1D, + m, n, k, 1, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + const auto& cd = c.value_or(d); + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_c = make_tma_cd_desc(cd, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(cd.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, 1, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.block_n, config.block_k, 1, 0); + + // Duplicate the accumulator if necessary + if (c.has_value()) { + if (c->data_ptr() == d.data_ptr()) { + DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides()); + } else { + // ReSharper disable once CppExpressionWithoutSideEffects + d.copy_(c.value()); + } + } + + // Launch + const SM100FP8Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = aligned_k, + .num_groups = 1, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_c = tensor_map_c, + .tensor_map_d = tensor_map_d + }; + const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code); + SM100FP8Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& m_indices, + const int& num_groups, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& aligned_k = align(k, 128); + const auto& config = get_best_config( + GemmType::MGroupedContiguous, KernelType::Kernel1D1D, + m, n, k, 1, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), false, + device_runtime->get_num_sms()); + + // Create tensor descriptors + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, 1, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.block_n, config.block_k, num_groups, 0); + + // Launch kernel + const SM100FP8Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = aligned_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = m_indices.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_c = tensor_map_d, + .tensor_map_d = tensor_map_d + }; + const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d1d", code); + SM100FP8Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& aligned_k = align(k, 128); + const auto& config = get_best_config( + GemmType::MGroupedMasked, KernelType::Kernel1D1D, + expected_m, n, k, num_groups, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), false, + device_runtime->get_num_sms()); + + // Create tensor descriptors + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), num_groups, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, num_groups, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.block_n, config.block_k, num_groups, 0); + + // Launch kernel + const SM100FP8Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = aligned_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_c = tensor_map_d, + .tensor_map_d = tensor_map_d + }; + const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d1d", code); + SM100FP8Gemm1D1DRuntime::launch(runtime, args); +} + +static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, + const std::vector& ks, const torch::Tensor& ks_tensor, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN); + + int sum_k = 0, sum_sf_k = 0; + for (const auto& k: ks) { + sum_k += k, sum_sf_k += ceil_div(k, 512); + DG_HOST_ASSERT(k % 128 == 0); + } + const auto& num_groups = static_cast(ks.size()); + + // Get config using max K for better performance + const auto& max_k = *std::max_element(ks.begin(), ks.end()); + const auto& config = get_best_config( + GemmType::KGroupedContiguous, KernelType::Kernel1D1D, + m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN, + torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Create tensor descriptors + const auto& cd = c.value_or(d); + const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(0)), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(0)), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(1)), num_groups, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_c = make_tma_cd_desc(cd, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(cd.stride(1)), num_groups, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 512, + config.block_m, config.block_k, 1, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 512, + config.block_n, config.block_k, 1, 0); + + // Duplicate the accumulator if necessary + if (c.has_value()) { + DG_HOST_ASSERT(c->data_ptr() == d.data_ptr()); + DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides()); + } + + // Launch kernel + const SM100FP8Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = sum_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = ks_tensor.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_c = tensor_map_c, + .tensor_map_d = tensor_map_d + }; + const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_fp8_k_grouped_gemm_1d1d", code); + SM100FP8Gemm1D1DRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp new file mode 100644 index 000000000..727d1b747 --- /dev/null +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp @@ -0,0 +1,237 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm100.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM100FP8Gemm1D2DRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k, num_groups; + const std::string& compiled_dims; + + GemmConfig gemm_config; + LaunchArgs launch_args; + + void *sfb, *grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + CUtensorMap tensor_map_sfa; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_fp8_gemm_1d2d_impl< + {}, {}, + {}, {}, {}, + {}, {}, {}, + {}, + {}, {}, {}, + {}, {}, + {}, {}, + {}, {}, + {}, + {}, {} + >); +}}; +)", + to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b), + get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), + args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.num_groups, + args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode, + args.gemm_config.num_stages, args.gemm_config.num_last_stages, + args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads, + args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, + args.gemm_config.num_sms, + to_string(args.gemm_config.gemm_type), to_string(args.gemm_config.cd_dtype)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.sfb, args.grouped_layout, + args.m, args.n, args.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_d, args.tensor_map_sfa)); + } +}; + +static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(not c.has_value()); + + const auto& aligned_k = align(k, 128); + const auto& config = get_best_config( + GemmType::Normal, KernelType::Kernel1D2D, + m, n, k, 1, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, 1, 0); + + // Launch + const SM100FP8Gemm1D2DRuntime::Args& args = { + .m = m, .n = n, .k = aligned_k, + .num_groups = 1, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .sfb = sfb.data_ptr(), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto& code = SM100FP8Gemm1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_fp8_gemm_1d2d", code); + SM100FP8Gemm1D2DRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& m_indices, + const int& num_groups, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& aligned_k = align(k, 128); + const auto& config = get_best_config( + GemmType::MGroupedContiguous, KernelType::Kernel1D2D, + m, n, k, 1, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), false, + device_runtime->get_num_sms()); + + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, 1, 0); + + // Launch + const SM100FP8Gemm1D2DRuntime::Args& args = { + .m = m, .n = n, .k = aligned_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .sfb = sfb.data_ptr(), + .grouped_layout = m_indices.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto& code = SM100FP8Gemm1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d2d", code); + SM100FP8Gemm1D2DRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& aligned_k = align(k, 128); + const auto& config = get_best_config( + GemmType::MGroupedMasked, KernelType::Kernel1D2D, + expected_m, n, k, num_groups, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), false, + device_runtime->get_num_sms()); + + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), num_groups, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, num_groups, 0); + + // Launch + const SM100FP8Gemm1D2DRuntime::Args& args = { + .m = m, .n = n, .k = aligned_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .sfb = sfb.data_ptr(), + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto& code = SM100FP8Gemm1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d2d", code); + SM100FP8Gemm1D2DRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp new file mode 100644 index 000000000..ea29883cf --- /dev/null +++ b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp @@ -0,0 +1,229 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM90BF16GemmRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k, num_groups; + const std::string& compiled_dims; + + GemmConfig gemm_config; + LaunchArgs launch_args; + + void *grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_bf16_gemm_impl< + {}, {}, {}, + {}, + {}, {}, {}, + {}, + {}, {}, + {}, {}, + {}, {}, + {}, {} + >); +}}; +)", + // TODO: add CD dtype + get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), + args.num_groups, + args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.gemm_config.smem_config.swizzle_cd_mode, + args.gemm_config.num_stages, args.gemm_config.num_last_stages, + args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, + args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, + args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.grouped_layout, + args.m, args.n, args.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_d)); + } +}; + +static void sm90_bf16_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(k % 64 == 0); + + const auto& config = get_best_config( + GemmType::Normal, KernelType::KernelNoSF, + m, n, k, 1, major_a, major_b, + torch::kBFloat16, d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Requires no TMA splits + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + + // Launch + const SM90BF16GemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = 1, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + }; + const auto& code = SM90BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm90_bf16_gemm", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + +static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& m_indices, + const int& num_groups, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(k % 64 == 0); + + const auto& config = get_best_config( + GemmType::MGroupedContiguous, KernelType::KernelNoSF, + m, n, k, 1, major_a, major_b, + torch::kBFloat16, d.scalar_type(), false, + device_runtime->get_num_sms()); + + // Requires no TMA splits + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + + // Launch + const SM90BF16GemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = m_indices.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + }; + const auto& code = SM90BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm90_m_grouped_bf16_gemm_contiguous", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + +static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(k % 64 == 0); + + const auto& config = get_best_config( + GemmType::MGroupedMasked, KernelType::KernelNoSF, + expected_m, n, k, num_groups, major_a, major_b, + torch::kBFloat16, d.scalar_type(), false, + device_runtime->get_num_sms()); + + // Requires no TMA splits + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), num_groups, + config.smem_config.swizzle_cd_mode); + + // Launch + const SM90BF16GemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + }; + const auto& code = SM90BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm90_bf16_m_grouped_gemm_masked", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp new file mode 100644 index 000000000..efb263eae --- /dev/null +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -0,0 +1,417 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k, num_groups; + const std::string& compiled_dims; + + GemmConfig gemm_config; + LaunchArgs launch_args; + + void *sfb, *grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + CUtensorMap tensor_map_sfa; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_fp8_gemm_1d2d_impl< + {}, {}, {}, + {}, + {}, {}, {}, + {}, + {}, {}, + {}, {}, + {}, {}, + {}, {} + >); +}}; +)", + // TODO: add CD dtype + get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), + args.num_groups, + args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.gemm_config.smem_config.swizzle_cd_mode, + args.gemm_config.num_stages, args.gemm_config.num_last_stages, + args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, + args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, + args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.sfb, args.grouped_layout, + args.m, args.n, args.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_d, args.tensor_map_sfa)); + } +}; + +class SM90FP8GemmPerTensor1D2DRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k, num_groups; + const std::string& compiled_dims; + + GemmConfig gemm_config; + LaunchArgs launch_args; + + void *sfb, *grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_fp8_gemm_per_tensor_1d2d_impl< + {}, {}, {}, + {}, + {}, {}, {}, + {}, + {}, {}, + {}, {}, + {}, {}, + {}, {} + >); +}}; +)", + // TODO: add CD dtype + get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), + args.num_groups, + args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.gemm_config.smem_config.swizzle_cd_mode, + args.gemm_config.num_stages, args.gemm_config.num_last_stages, + args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, + args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, + args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.sfb, args.grouped_layout, + args.m, args.n, args.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_d)); + } +}; + +static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto& aligned_k = align(k, 128); + const auto& config = get_best_config( + GemmType::Normal, KernelType::Kernel1D2D, + m, n, k, 1, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Requires no TMA splits + DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); + DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, 1, 0); + + // Launch + const SM90FP8Gemm1D2DRuntime::Args& args = { + .m = m, .n = n, .k = aligned_k, + .num_groups = 1, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .sfb = sfb.data_ptr(), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm90_fp8_gemm_1d2d", code); + SM90FP8Gemm1D2DRuntime::launch(runtime, args); +} + +static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& m_indices, + const int& num_groups, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto& aligned_k = align(k, 128); + const auto& config = get_best_config( + GemmType::MGroupedContiguous, KernelType::Kernel1D2D, + m, n, k, 1, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), false, + device_runtime->get_num_sms()); + + // Requires no TMA splits + DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); + DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, 1, 0); + + // Launch + const SM90FP8Gemm1D2DRuntime::Args& args = { + .m = m, .n = n, .k = aligned_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .sfb = sfb.data_ptr(), + .grouped_layout = m_indices.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm90_m_grouped_fp8_gemm_contiguous_1d2d", code); + SM90FP8Gemm1D2DRuntime::launch(runtime, args); +} + +static void sm90_m_grouped_fp8_gemm_contiguous_per_tensor_1d2d(const torch::Tensor& a, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& m_indices, + const int& num_groups, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto& aligned_k = align(k, 128); + const auto& config = get_best_config( + GemmType::MGroupedContiguous, KernelType::Kernel1D2D, + m, n, k, 1, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), false, + device_runtime->get_num_sms(), true); + + // Requires no TMA splits + DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); + DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + + // Launch + const SM90FP8GemmPerTensor1D2DRuntime::Args& args = { + .m = m, .n = n, .k = aligned_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .sfb = sfb.data_ptr(), + .grouped_layout = m_indices.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + }; + const auto& code = SM90FP8GemmPerTensor1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm90_m_grouped_fp8_gemm_contiguous_per_tensor_1d2d", code); + SM90FP8GemmPerTensor1D2DRuntime::launch(runtime, args); +} + +static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& aligned_k = align(k, 128); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto& config = get_best_config( + GemmType::MGroupedMasked, KernelType::Kernel1D2D, + expected_m, n, k, num_groups, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), false, + device_runtime->get_num_sms()); + + // Requires no TMA splits + DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); + DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), num_groups, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, num_groups, 0); + + // Launch + const SM90FP8Gemm1D2DRuntime::Args& args = { + .m = m, .n = n, .k = aligned_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .sfb = sfb.data_ptr(), + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code); + SM90FP8Gemm1D2DRuntime::launch(runtime, args); +} + +static void sm90_m_grouped_fp8_gemm_masked_per_tensor_1d2d(const torch::Tensor& a, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& aligned_k = align(k, 128); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto& config = get_best_config( + GemmType::MGroupedMasked, KernelType::Kernel1D2D, + expected_m, n, k, num_groups, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), false, + device_runtime->get_num_sms(), true); + + // Requires no TMA splits + DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); + DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), num_groups, + config.smem_config.swizzle_cd_mode); + + // Launch + const SM90FP8GemmPerTensor1D2DRuntime::Args& args = { + .m = m, .n = n, .k = aligned_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .sfb = sfb.data_ptr(), + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + }; + const auto& code = SM90FP8GemmPerTensor1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_per_tensor_1d2d", code); + SM90FP8GemmPerTensor1D2DRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/smxx_layout.hpp b/csrc/jit_kernels/impls/smxx_layout.hpp new file mode 100644 index 000000000..d8a60de98 --- /dev/null +++ b/csrc/jit_kernels/impls/smxx_layout.hpp @@ -0,0 +1,263 @@ +#pragma once + +#include + +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../../utils/layout.hpp" + +namespace deep_gemm { + +class TransposeFP32Runtime final: public LaunchRuntime { +public: + struct Args { + int mn, sf_k; + int block_mn; + void *sf, *out; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&transpose_fp32< + {}, {}, {} + >); +}}; +)", args.launch_args.num_threads, args.block_mn, args.sf_k); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.sf, args.out, static_cast(args.mn))); + } +}; + +class TransposeAndPackFP32IntoUE8M0Runtime final: public LaunchRuntime { +public: + struct Args { + int mn, sf_k; + int block_mn; + void *sf, *out; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&transpose_and_pack_fp32_into_ue8m0< + {}, {}, {} + >); +}}; +)", args.launch_args.num_threads, args.block_mn, args.sf_k); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.sf, args.out, static_cast(args.mn))); + } +}; + +class PackFP32IntoUE8M0Runtime final: public LaunchRuntime { +public: + struct Args { + int num_groups, mn, sf_k, packed_sf_k; + int block_mn, block_packed_sf_k; + void *sf, *out, *ks; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&pack_fp32_into_ue8m0< + {}, {}, {}, {} + >); +}}; +)", args.num_groups, args.launch_args.num_threads, args.block_mn, args.block_packed_sf_k); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.sf, args.out, args.ks, args.mn, args.sf_k, args.packed_sf_k)); + } +}; + +static std::tuple preprocess_sf(const torch::Tensor& sf) { + // NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA + const auto& dim = sf.dim(); + DG_HOST_ASSERT(dim == 2 or dim == 3); + DG_HOST_ASSERT(sf.scalar_type() == torch::kFloat); + const auto& batched_sf = dim == 2 ? sf.unsqueeze(0) : sf; + + const auto& [num_groups, mn, sf_k] = get_shape<3>(batched_sf); + const auto& tma_aligned_mn = get_tma_aligned_size(mn, static_cast(sf.element_size())); + return {dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf}; +} + +static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) { + const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf); + + // The last kernel already gives a column-major TMA aligned layout + if ((batched_sf.stride(0) == tma_aligned_mn * sf_k or dim == 2) and batched_sf.stride(1) == 1 and batched_sf.stride(2) == tma_aligned_mn) + return (dim == 2) ? batched_sf.squeeze(0) : batched_sf; + + const auto& out = torch::empty_strided({num_groups, mn, sf_k}, + {tma_aligned_mn * sf_k, 1, tma_aligned_mn}, + batched_sf.options()); + + if (not batched_sf.is_contiguous()) { + // Fallback to PyTorch's slow copy if not contiguous + // ReSharper disable once CppExpressionWithoutSideEffects + out.copy_(batched_sf); + } else { + constexpr int block_mn = 64; + constexpr int num_threads = 512; + const auto& smem_size = block_mn * (sf_k + (1 - (sf_k % 2))) * static_cast(sizeof(float)); + const TransposeFP32Runtime::Args& args = { + .mn = mn, + .sf_k = sf_k, + .block_mn = block_mn, + .sf = batched_sf.data_ptr(), + .out = out.data_ptr(), + .launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, smem_size) + }; + + const auto& code = TransposeFP32Runtime::generate(args); + const auto& runtime = compiler->build("transpose_fp32", code); + TransposeFP32Runtime::launch(runtime, args); + } + return (dim == 2) ? out.squeeze(0) : out; +} + +static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(const torch::Tensor& sf) { + const auto& sf_reshaped = (sf.dim() == 2) ? sf.unsqueeze(0) : sf; + + // First, convert into UE8M0 `uint8_t` + const auto& ue8m0_tensor = sf_reshaped.view(torch::kInt32).bitwise_right_shift(23).to(torch::kUInt8); + + // Second, make padded packed tensors + const auto& [num_groups, mn, k] = get_shape<3>(sf_reshaped); + const auto& aligned_mn = get_tma_aligned_size(mn, 4); + const auto& aligned_k = align(k, 4); + + const auto& options = torch::TensorOptions().device(sf.device()).dtype(torch::kUInt8); + auto padded = torch::zeros({num_groups, aligned_mn, aligned_k}, options); + // ReSharper disable once CppExpressionWithoutSideEffects + padded.slice(1, 0, mn).slice(2, 0, k).copy_(ue8m0_tensor); + padded = padded.view(-1).view(torch::kInt32).view({num_groups, aligned_mn, aligned_k / 4}); + + // Finally, transpose + auto out = torch::empty_strided({num_groups, aligned_mn, aligned_k / 4}, + {aligned_mn * (aligned_k / 4), 1, aligned_mn}, + at::TensorOptions().device(sf.device()).dtype(torch::kInt32)); + out = out.copy_(padded).slice(1, 0, mn); + return (sf.dim() == 2) ? out.squeeze(0) : out; +} + +static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf) { + const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf); + const auto& packed_sf_k = ceil_div(sf_k, 4); + const auto& out = torch::empty_strided({num_groups, mn, packed_sf_k}, + {packed_sf_k * tma_aligned_mn, 1, tma_aligned_mn}, + at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt)); + // Launch the kernel + if (batched_sf.is_contiguous()) { + if ((mn * sf_k) % 4 != 0 and num_groups > 1) + return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf); + + constexpr int block_mn = 48; + constexpr int num_threads = 512; + const TransposeAndPackFP32IntoUE8M0Runtime::Args& args = { + .mn = mn, + .sf_k = sf_k, + .block_mn = block_mn, + .sf = batched_sf.data_ptr(), + .out = out.data_ptr(), + .launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, block_mn * sf_k * 4) + }; + + const auto& code = TransposeAndPackFP32IntoUE8M0Runtime::generate(args); + const auto& runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code); + TransposeAndPackFP32IntoUE8M0Runtime::launch(runtime, args); + } else { + if (mn % 4 != 0 or num_groups > 1) + return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf); + DG_HOST_ASSERT(batched_sf.stride(1) == 1 and batched_sf.stride(2) == mn); + + constexpr int block_mn = 128; + constexpr int block_packed_sf_k = 16; + constexpr int num_threads = 512; + const PackFP32IntoUE8M0Runtime::Args& args = { + .num_groups = 1, + .mn = mn, + .sf_k = sf_k, + .packed_sf_k = packed_sf_k, + .block_mn = block_mn, + .block_packed_sf_k = block_packed_sf_k, + .sf = batched_sf.data_ptr(), + .out = out.data_ptr(), + .ks = nullptr, + .launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads) + }; + + const auto& code = PackFP32IntoUE8M0Runtime::generate(args); + const auto& runtime = compiler->build("pack_fp32_into_ue8m0", code); + PackFP32IntoUE8M0Runtime::launch(runtime, args); + } + return (dim == 2) ? out.squeeze(0) : out; +} + +static torch::Tensor get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf, + const torch::Tensor& ks_tensor, + const std::vector& ks) { + const auto& [sf_k, mn] = get_shape<2>(sf); + const auto& num_groups = static_cast(ks.size()); + + int ref_sf_k = 0, packed_sf_k = 0; + for (const auto& k: ks) + ref_sf_k += ceil_div(k, 128), packed_sf_k += ceil_div(k, 512); + DG_HOST_ASSERT(sf.is_contiguous()); + DG_HOST_ASSERT(ref_sf_k == sf_k); + DG_HOST_ASSERT(num_groups <= 128 and mn % 4 == 0); + + const auto& out = torch::empty({packed_sf_k, mn}, at::TensorOptions().device(sf.device()).dtype(torch::kInt)); + + constexpr int block_mn = 128; + constexpr int block_packed_sf_k = 16; + constexpr int num_threads = 512; + const PackFP32IntoUE8M0Runtime::Args& args = { + .num_groups = num_groups, + .mn = mn, + .sf_k = sf_k, + .packed_sf_k = packed_sf_k, + .block_mn = block_mn, + .block_packed_sf_k = block_packed_sf_k, + .sf = sf.data_ptr(), + .out = out.data_ptr(), + .ks = ks_tensor.data_ptr(), + .launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads) + }; + + const auto& code = PackFP32IntoUE8M0Runtime::generate(args); + const auto& runtime = compiler->build("pack_fp32_into_ue8m0", code); + PackFP32IntoUE8M0Runtime::launch(runtime, args); + return out; +} + +} // namespace deep_gemm diff --git a/csrc/python_api.cpp b/csrc/python_api.cpp new file mode 100644 index 000000000..d4b210a22 --- /dev/null +++ b/csrc/python_api.cpp @@ -0,0 +1,19 @@ +#include +#include + +#include "apis/gemm.hpp" +#include "apis/layout.hpp" +#include "apis/runtime.hpp" + +#ifndef TORCH_EXTENSION_NAME +#define TORCH_EXTENSION_NAME deep_gemm_cpp +#endif + +// ReSharper disable once CppParameterMayBeConstPtrOrRef +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "DeepGEMM C++ library"; + + deep_gemm::gemm::register_apis(m); + deep_gemm::layout::register_apis(m); + deep_gemm::runtime::register_apis(m); +} diff --git a/csrc/utils/exception.hpp b/csrc/utils/exception.hpp new file mode 100644 index 000000000..57cc51377 --- /dev/null +++ b/csrc/utils/exception.hpp @@ -0,0 +1,75 @@ +#pragma once + +#include +#include +#include + +namespace deep_gemm { + +class DGException final : public std::exception { + std::string message = {}; + +public: + explicit DGException(const char *name, const char* file, const int line, const std::string& error) { + message = std::string(name) + " error (" + file + ":" + std::to_string(line) + "): " + error; + } + + const char *what() const noexcept override { + return message.c_str(); + } +}; + +#ifndef DG_STATIC_ASSERT +#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__) +#endif + +#ifndef DG_HOST_ASSERT +#define DG_HOST_ASSERT(cond) \ +do { \ + if (not (cond)) { \ + throw DGException("Assertion", __FILE__, __LINE__, #cond); \ + } \ +} while (0) +#endif + +#ifndef DG_HOST_UNREACHABLE +#define DG_HOST_UNREACHABLE(reason) (throw DGException("Assertion", __FILE__, __LINE__, reason)) +#endif + +#ifndef DG_NVRTC_CHECK +#define DG_NVRTC_CHECK(cmd) \ +do { \ + const auto& e = (cmd); \ + if (e != NVRTC_SUCCESS) { \ + throw DGException("NVRTC", __FILE__, __LINE__, nvrtcGetErrorString(e)); \ + } \ +} while (0) +#endif + +#ifndef DG_CUDA_DRIVER_CHECK +#define DG_CUDA_DRIVER_CHECK(cmd) \ +do { \ + const auto& e = (cmd); \ + if (e != CUDA_SUCCESS) { \ + std::stringstream ss; \ + const char *name, *info; \ + cuGetErrorName(e, &name), cuGetErrorString(e, &info); \ + ss << static_cast(e) << " (" << name << ", " << info << ")"; \ + throw DGException("CUDA driver", __FILE__, __LINE__, ss.str()); \ + } \ +} while (0) +#endif + +#ifndef DG_CUDA_RUNTIME_CHECK +#define DG_CUDA_RUNTIME_CHECK(cmd) \ +do { \ + const auto& e = (cmd); \ + if (e != cudaSuccess) { \ + std::stringstream ss; \ + ss << static_cast(e) << " (" << cudaGetErrorName(e) << ", " << cudaGetErrorString(e) << ")"; \ + throw DGException("CUDA runtime", __FILE__, __LINE__, ss.str()); \ + } \ +} while (0) +#endif + +} // namespace deep_gemm diff --git a/csrc/utils/format.hpp b/csrc/utils/format.hpp new file mode 100644 index 000000000..bf617372b --- /dev/null +++ b/csrc/utils/format.hpp @@ -0,0 +1,6 @@ +#pragma once + +// Just a wrapper for the `fmt` headers +#define FMT_HEADER_ONLY +#include +#include diff --git a/csrc/utils/hash.hpp b/csrc/utils/hash.hpp new file mode 100644 index 000000000..fad1231f6 --- /dev/null +++ b/csrc/utils/hash.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include + +namespace deep_gemm { + +static uint64_t fnv1a(const std::string& data, const uint64_t& seed) { + uint64_t h = seed; + const uint64_t& prime = 0x100000001b3ull; + for (const char& c: data) { + h ^= static_cast(c); + h *= prime; + } + return h; +} + +static std::string get_hex_digest(const std::string& data) { + const auto& state_0 = fnv1a(data, 0xc6a4a7935bd1e995ull); + const auto& state_1 = fnv1a(data, 0x9e3779b97f4a7c15ull); + + // Split-mix 64 + const auto& split_mix = [](uint64_t z) { + z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9ull; + z = (z ^ (z >> 27)) * 0x94d049bb133111ebull; + return z ^ (z >> 31); + }; + + std::ostringstream oss; + oss << std::hex << std::setfill('0') + << std::setw(16) << split_mix(state_0) + << std::setw(16) << split_mix(state_1); + return oss.str(); +} + +} // namespace deep_gemm diff --git a/csrc/utils/layout.hpp b/csrc/utils/layout.hpp new file mode 100644 index 000000000..a5dcd5062 --- /dev/null +++ b/csrc/utils/layout.hpp @@ -0,0 +1,106 @@ +#pragma once + +#include +#include + +#include "math.hpp" +#include "exception.hpp" +#include "../jit/device_runtime.hpp" + +namespace deep_gemm { + +// Major-ness stuffs +static void major_check(const torch::Tensor& t) { + const auto dim = t.dim(); + DG_HOST_ASSERT(dim == 2 or dim == 3); + if (dim == 3) + DG_HOST_ASSERT(t.stride(0) == t.size(-2) * t.size(-1)); + DG_HOST_ASSERT(t.stride(-2) == 1 or t.stride(-1) == 1); +} + +static cute::UMMA::Major get_major_type_ab(const torch::Tensor& t) { + major_check(t); + return t.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN; +} + +static void check_major_type_cd(const torch::Tensor& t) { + // NOTES: the library only supports row-major output layouts + major_check(t); + DG_HOST_ASSERT(t.stride(-1) == 1); +} + +static bool fp8_requires_k_major() { + return device_runtime->get_arch_major() == 9; +} + +// Tensor utils +template +static auto get_shape(const torch::Tensor& t) { + return [&t] (std::index_sequence) { + return std::make_tuple(static_cast(t.sizes()[Is])...); + }(std::make_index_sequence()); +} + +// Recipe +static std::tuple +get_default_recipe(const torch::ScalarType& sfa_dtype, const torch::ScalarType& sfb_dtype) { + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + DG_HOST_ASSERT(sfa_dtype == torch::kFloat and sfb_dtype == torch::kFloat); + return {1, 128, 128}; + } else if (arch_major == 10) { + DG_HOST_ASSERT(sfb_dtype == torch::kFloat or sfb_dtype == torch::kInt); + return sfb_dtype == torch::kFloat ? + std::make_tuple(1, 128, 128): // Legacy format or 1D2D kernels + std::make_tuple(1, 1, 128); // 1D1D kernels + } + DG_HOST_UNREACHABLE("Unknown recipe"); +} + +// SF layouts +static torch::Tensor check_sf_layout(const torch::Tensor& sf, + const int& mn, const int& k, + const int& gran_mn, const int& gran_k, + const std::optional& num_groups, + const bool& tma_stride_check = false, + const bool& contiguous_check = false, + const std::optional& type_check = std::nullopt, + const bool& is_per_tensor = false) { + // Type check + if (type_check.has_value()) + DG_HOST_ASSERT(sf.scalar_type() == type_check.value()); + + // Always do shape checks + const auto& sf_dtype = sf.scalar_type(); + DG_HOST_ASSERT(sf_dtype == torch::kFloat or sf_dtype == torch::kInt); + if (is_per_tensor) + DG_HOST_ASSERT(sf.dim() == static_cast(1)); + else + DG_HOST_ASSERT(sf.dim() == static_cast(num_groups.has_value()) + 2); + if (num_groups.has_value() && !is_per_tensor) + DG_HOST_ASSERT(sf.size(-3) == num_groups.value()); + if (!is_per_tensor) { + DG_HOST_ASSERT(sf.size(-2) == ceil_div(mn, gran_mn)); + DG_HOST_ASSERT(sf.size(-1) == ceil_div(k, gran_k * (sf_dtype == torch::kFloat ? 1 : 4))); + } + + // TMA stride checks: TMA aligned and MN-major + if (tma_stride_check && !is_per_tensor) { + if (num_groups.has_value()) + DG_HOST_ASSERT(sf.stride(-3) == sf.stride(-1) * sf.size(-1)); + DG_HOST_ASSERT(sf.stride(-2) == 1); + DG_HOST_ASSERT(sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size())); + } + + // Hopper SFB must be contiguous + if (contiguous_check) + DG_HOST_ASSERT(sf.is_contiguous()); + return sf; +} + +// Value matrix layout +static int get_mk_alignment_for_contiguous_layout() { + return 128; +} + +} // namespace deep_gemm diff --git a/csrc/utils/lazy_init.hpp b/csrc/utils/lazy_init.hpp new file mode 100644 index 000000000..386b1b453 --- /dev/null +++ b/csrc/utils/lazy_init.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include +#include + +#define DG_DECLARE_STATIC_VAR_IN_CLASS(cls, name) decltype(cls::name) cls::name + +namespace deep_gemm { + +template +class LazyInit { +public: + explicit LazyInit(std::function()> factory) + : factory(std::move(factory)) {} + + T* operator -> () { + if (ptr == nullptr) + ptr = factory(); + return ptr.get(); + } + +private: + std::shared_ptr ptr; + std::function()> factory; +}; + +} // namespace deep_gemm diff --git a/csrc/utils/math.hpp b/csrc/utils/math.hpp new file mode 100644 index 000000000..264d2d104 --- /dev/null +++ b/csrc/utils/math.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include + +#include "exception.hpp" + +namespace deep_gemm { + +template +static T ceil_div(const T& a, const T& b) { + return (a + b - 1) / b; +} + +template +static constexpr T align(const T& a, const T& b) { + return ceil_div(a, b) * b; +} + +static int get_tma_aligned_size(const int& x, const int& element_size) { + constexpr int kNumTMAAlignmentBytes = 16; + DG_HOST_ASSERT(kNumTMAAlignmentBytes % element_size == 0); + return align(x, kNumTMAAlignmentBytes / element_size); +} + +} // namespace deep_gemm diff --git a/csrc/utils/system.hpp b/csrc/utils/system.hpp new file mode 100644 index 000000000..91dee1223 --- /dev/null +++ b/csrc/utils/system.hpp @@ -0,0 +1,90 @@ +#pragma once + +#include +#include +#include +#include + +#include "exception.hpp" + +namespace deep_gemm { + +// ReSharper disable once CppNotAllPathsReturnValue +template +static dtype_t get_env(const std::string& name, const dtype_t& default_value = dtype_t()) { + const auto& c_str = std::getenv(name.c_str()); + if (c_str == nullptr) + return default_value; + + // Read the env and convert to the desired type + if constexpr (std::is_same_v) { + return std::string(c_str); + } else if constexpr (std::is_same_v) { + int value; + std::sscanf(c_str, "%d", &value); + return value; + } else { + DG_HOST_ASSERT(false and "Unexpected type"); + } +} + +static std::tuple call_external_command(std::string command) { + command = command + " 2>&1"; + const auto& deleter = [](FILE* f) { if (f) pclose(f); }; + std::unique_ptr pipe(popen(command.c_str(), "r"), deleter); + DG_HOST_ASSERT(pipe != nullptr); + + std::array buffer; + std::string output; + while (fgets(buffer.data(), buffer.size(), pipe.get())) + output += buffer.data(); + const auto& exit_code = WEXITSTATUS(pclose(pipe.release())); + return {exit_code, output}; +} + +static std::vector collect_files(const std::filesystem::path& root) { + std::vector files; + std::function impl; + impl = [&](const std::filesystem::path& dir) { + for (const auto& entry: std::filesystem::directory_iterator(dir)) { + if (entry.is_directory()) { + impl(entry.path()); + } else if (entry.is_regular_file() and entry.path().extension() == ".cuh") { + files.emplace_back(entry.path()); + } + } + }; + impl(root); + + // Be consistent + std::sort(files.begin(), files.end()); + return files; +} + +static std::filesystem::path make_dirs(const std::filesystem::path& path) { + // OK if existed + std::error_code capture; + const bool& created = std::filesystem::create_directories(path, capture); + DG_HOST_ASSERT(created or capture.value() == 0); + if (created and get_env("DG_JIT_DEBUG")) + printf("Create directory: %s\n", path.c_str()); + return path; +} + +static std::string get_uuid() { + static std::random_device rd; + static std::mt19937 gen([]() { + return rd() ^ std::chrono::steady_clock::now().time_since_epoch().count(); + }()); + static std::uniform_int_distribution dist; + + std::stringstream ss; + ss << getpid() << "-" + << std::hex << std::setfill('0') + << std::setw(8) << dist(gen) << "-" + << std::setw(8) << dist(gen) << "-" + << std::setw(8) << dist(gen); + return ss.str(); +} + +} // deep_gemm diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index 8e6b29965..1200d22a3 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -1,15 +1,78 @@ -import torch - -from . import jit -from .jit_kernels import ( - gemm_fp8_fp8_bf16_nt, - m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, - m_grouped_gemm_fp8_fp8_bf16_nt_masked, - wgrad_gemm_fp8_fp8_fp32_nt, - k_grouped_wgrad_gemm_fp8_fp8_fp32_nt, - ceil_div, - set_num_sms, get_num_sms, - get_col_major_tma_aligned_tensor, - get_m_alignment_for_contiguous_layout +import os +import subprocess + +# Set some default environment provided at setup +try: + # noinspection PyUnresolvedReferences + from .envs import persistent_envs + for key, value in persistent_envs.items(): + if key not in os.environ: + os.environ[key] = value +except ImportError: + pass + +# Configs +import deep_gemm_cpp +from deep_gemm_cpp import ( + set_num_sms, + get_num_sms, + set_tc_util, + get_tc_util, +) + +# Kernels +from deep_gemm_cpp import ( + # FP8 GEMMs + fp8_gemm_nt, fp8_gemm_nn, + fp8_gemm_tn, fp8_gemm_tt, + m_grouped_fp8_gemm_nt_contiguous, + m_grouped_fp8_gemm_nn_contiguous, + m_grouped_fp8_gemm_nt_masked, + k_grouped_fp8_gemm_tn_contiguous, + # BF16 GEMMs + bf16_gemm_nt, bf16_gemm_nn, + bf16_gemm_tn, bf16_gemm_tt, + m_grouped_bf16_gemm_nt_contiguous, + m_grouped_bf16_gemm_nt_masked, + # Per Tensor GEMMs + m_grouped_fp8_gemm_nt_contiguous_per_tensor, + m_grouped_fp8_gemm_nt_masked_per_tensor, + # Layout kernels + transform_sf_into_required_layout +) + +# Some alias for legacy supports +# TODO: remove these later +fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_gemm_nt_masked +bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked +fp8_m_grouped_gemm_nt_masked_per_tensor = m_grouped_fp8_gemm_nt_masked_per_tensor + +# Some utils +from . import testing +from . import utils +from .utils import * + + +# Initialize CPP modules +def _find_cuda_home() -> str: + # TODO: reuse PyTorch API later + # For some PyTorch versions, the original `_find_cuda_home` will initialize CUDA, which is incompatible with process forks + cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') + if cuda_home is None: + # noinspection PyBroadException + try: + with open(os.devnull, 'w') as devnull: + nvcc = subprocess.check_output(['which', 'nvcc'], stderr=devnull).decode().rstrip('\r\n') + cuda_home = os.path.dirname(os.path.dirname(nvcc)) + except Exception: + cuda_home = '/usr/local/cuda' + if not os.path.exists(cuda_home): + cuda_home = None + assert cuda_home is not None + return cuda_home + + +deep_gemm_cpp.init( + os.path.dirname(os.path.abspath(__file__)), # Library root directory path + _find_cuda_home() # CUDA home ) -from .utils import bench, bench_kineto, calc_diff diff --git a/deep_gemm/include/deep_gemm/common/cute_tie.cuh b/deep_gemm/include/deep_gemm/common/cute_tie.cuh new file mode 100644 index 000000000..cd2aace7a --- /dev/null +++ b/deep_gemm/include/deep_gemm/common/cute_tie.cuh @@ -0,0 +1,48 @@ +#pragma once + +namespace cute { + +struct ignore_t { + template + constexpr const ignore_t& operator=(T&&) const noexcept { + return *this; + } +}; + +inline constexpr ignore_t ignore{}; + +} // namespace cute + +#define CUTE_TIE_CONCAT_IMPL(A, B) A##B +#define CUTE_TIE_CONCAT(A, B) CUTE_TIE_CONCAT_IMPL(A, B) + +#define CUTE_TIE_GET_NTH_ARG(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) N +#define CUTE_TIE_COUNT_ARGS(...) \ + CUTE_TIE_GET_NTH_ARG(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0) + +#define CUTE_TIE_OP_DECL(I, TUPLE, VAR) auto VAR = ::cute::get(TUPLE) +#define CUTE_TIE_OP_ASSIGN(I, TUPLE, VAR) VAR = ::cute::get(TUPLE) + +#define CUTE_TIE_APPLY_OP_1(OP, T, V1) OP(0, T, V1); +#define CUTE_TIE_APPLY_OP_2(OP, T, V1, V2) OP(0, T, V1); OP(1, T, V2); +#define CUTE_TIE_APPLY_OP_3(OP, T, V1, V2, V3) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); +#define CUTE_TIE_APPLY_OP_4(OP, T, V1, V2, V3, V4) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4); +#define CUTE_TIE_APPLY_OP_5(OP, T, V1, V2, V3, V4, V5) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4); OP(4, T, V5); + +#define CUTE_TIE_DECL(TUPLE_EXPR, ...) \ + auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \ + CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \ + CUTE_TIE_OP_DECL, \ + CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \ + __VA_ARGS__ \ + ) + +#define CUTE_TIE(TUPLE_EXPR, ...) \ + do { \ + auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \ + CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \ + CUTE_TIE_OP_ASSIGN, \ + CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \ + __VA_ARGS__ \ + ); \ + } while (0) diff --git a/deep_gemm/include/deep_gemm/common/scheduler.cuh b/deep_gemm/include/deep_gemm/common/scheduler.cuh new file mode 100644 index 000000000..2324a9bf5 --- /dev/null +++ b/deep_gemm/include/deep_gemm/common/scheduler.cuh @@ -0,0 +1,229 @@ +#pragma once + +#include +#include + +namespace deep_gemm { + +enum class KGroupedIndexType { + MN, + K, + SF_K, +}; + +template +static constexpr uint32_t get_num_1d_blocks_per_group() { + // Select the best from candidates + uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits::max(); + for (const auto& candidate: {8u, 16u}) { + const auto& usage = kIsMulticastOnA ? + candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N + candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M + if (usage < min_usage) + min_usage = usage, num_best_blocks = candidate; + } + + return num_best_blocks; +} + +#pragma clang diagnostic push +#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" +template ()> +struct Scheduler { + int current_iter = -1; + + // Block configs + uint32_t num_blocks; + uint32_t num_m_blocks; + uint32_t num_n_blocks; + + // For SM90 multicast checks + uint32_t num_blocks_in_group; + bool is_peer_cta_alive = true; + + // For grouped GEMM + int* grouped_layout; + uint32_t current_group_idx; + // Only used for masked layout + uint32_t current_m_cumsum; + // Only used for k-grouped layout + uint32_t current_shape_k, current_num_valid_groups, current_k_cumsum, current_sf_k_cumsum; + + // ReSharper disable once CppPossiblyUninitializedMember + __device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, + int* grouped_layout = nullptr) { + num_m_blocks = ceil_div(shape_m, BLOCK_M); + num_n_blocks = ceil_div(shape_n, BLOCK_N); + if constexpr (kGemmType == GemmType::Normal) { + num_blocks = num_m_blocks * num_n_blocks; + } else if (kGemmType == GemmType::MGroupedContiguous) { + num_blocks = num_m_blocks * num_n_blocks; + this->grouped_layout = grouped_layout; + } else if (kGemmType == GemmType::MGroupedMasked) { + current_group_idx = current_m_cumsum = 0; + this->grouped_layout = grouped_layout; + } else if (kGemmType == GemmType::KGroupedContiguous) { + current_group_idx = current_num_valid_groups = 0; + current_k_cumsum = current_sf_k_cumsum = 0; + current_shape_k = __ldg(grouped_layout + current_group_idx); + this->grouped_layout = grouped_layout; + } + } + + __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { + DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumMulticast == 0, "Invalid group size"); + + // Swizzle for better L2 usages + const auto& primary_num_blocks = kIsMulticastOnA ? num_n_blocks : num_m_blocks; + const auto& secondary_num_blocks = kIsMulticastOnA ? num_m_blocks : num_n_blocks; + const auto& num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup; + const auto& group_idx = block_idx / num_blocks_per_group; + auto first_block_idx = group_idx * kNum1DBlocksPerGroup; + auto in_group_idx = block_idx % num_blocks_per_group; + num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx); + + // Fix unaligned TMA multicast + // NOTES: for SM90 only, as SM90 can dynamically disable TMA multicast + // while SM100 uses 2-CTA, which can not be dynamically disabled +#if __CUDA_ARCH__ < 1000 + if (kNumMulticast > 1 and num_blocks_in_group % 2 != 0) { + if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) { + num_blocks_in_group = num_blocks_in_group ^ 1; + } else { + in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks; + first_block_idx += num_blocks_in_group ^ 1; + num_blocks_in_group = 1; + } + } +#endif + + // Convert to final M/N block indices + // `kIsMulticastOnA == true` leads to groups on N + if constexpr (kIsMulticastOnA) { + m_block_idx = in_group_idx / num_blocks_in_group; + n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + } else { + m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + n_block_idx = in_group_idx / num_blocks_in_group; + } + } + + template + __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size, + const uint32_t& block_idx, const uint32_t& m_block_idx = 0) { + if constexpr (kGemmType == GemmType::Normal) { + return block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + const auto offset = kWithGroupOffset ? cute::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + const auto offset = kWithGroupOffset ? current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + auto offset = 0; + if constexpr (kWithGroupOffset) { + if constexpr (kIndexType == KGroupedIndexType::MN) + offset = current_group_idx * shape_dim; + else if constexpr (kIndexType == KGroupedIndexType::K) + offset = current_k_cumsum; + else if constexpr (kIndexType == KGroupedIndexType::SF_K) + offset = current_sf_k_cumsum; + } + return offset + block_idx * block_size; + } + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + const auto next_block_idx = (++ current_iter) * kNumSMs + blockIdx.x; + + if constexpr (kGemmType == GemmType::MGroupedMasked) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + num_m_blocks = ceil_div(static_cast(__ldg(grouped_layout + current_group_idx)), BLOCK_M); + const auto current_m_block_cumsum = current_m_cumsum + num_m_blocks; + if (next_block_idx < current_m_block_cumsum * num_n_blocks) + break; + + // Move to check the next group + current_group_idx ++, current_m_cumsum = current_m_block_cumsum; + } + + get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx); + } else if (kGemmType == GemmType::KGroupedContiguous) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + if (current_shape_k > 0 and next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks) + break; + + // Move to check the next group + if (current_shape_k > 0) { + current_k_cumsum += current_shape_k; + current_sf_k_cumsum += ceil_div(current_shape_k, 512u); + current_num_valid_groups ++; + } + if ((++ current_group_idx) != kNumGroups) + current_shape_k = __ldg(grouped_layout + current_group_idx); + } + + get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_m_blocks * num_n_blocks, m_block_idx, n_block_idx); + } else { + if (next_block_idx >= num_blocks) + return false; + + // For SM90 only + // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned + is_peer_cta_alive = num_n_blocks % kNumMulticast == 0 or // Always aligned on N (constant bypass) + num_m_blocks % kNumMulticast == 0 or // Always aligned on M (constant bypass) + (next_block_idx ^ 1) < num_blocks; // Peer CTA in bound + get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx); + } + return true; + } + + // For SM90 only + __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const { + if (num_blocks_in_group == 1) + return false; + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked) { + return true; + } else { + DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type"); + if constexpr (kIsMulticastOnA) { + return true; + } else { + const auto& group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M); + const auto& peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M); + return group_idx == peer_group_idx; + } + } + } + + // For SM90 only + // ReSharper disable once CppNotAllPathsReturnValue + __device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const { + if constexpr (kGemmType == GemmType::Normal) { + return true; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + current_group_idx); + } + } +}; + +#pragma clang diagnostic pop + +} // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/common/sm100_utils.cuh b/deep_gemm/include/deep_gemm/common/sm100_utils.cuh new file mode 100644 index 000000000..b208302f2 --- /dev/null +++ b/deep_gemm/include/deep_gemm/common/sm100_utils.cuh @@ -0,0 +1,169 @@ +#pragma once + +#include +#include +#include + +#include + +namespace deep_gemm::sm100 { + +template +constexpr uint32_t get_inner_block_atom_size() { + return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); +} + +template +__device__ __forceinline__ void +tma_copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr, + dtype_t* smem_ptr, const uint32_t& inner_idx, const int32_t& outer_idx) { + DG_STATIC_ASSERT(1 <= kNumMulticast and kNumMulticast <= 2, "Invalid multicast config"); + DG_STATIC_ASSERT(static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL) == + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint"); + + // 2-CTA function will send signals to the leader CTA only + const auto copy_func = kNumMulticast == 1 ? cute::SM90_TMA_LOAD_2D::copy : cute::SM100_TMA_2SM_LOAD_2D::copy; + + // Issue multiple TMAs + constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size(); + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + copy_func(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } +} + +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_smem_desc(cute::UMMA::LayoutType layout, void* smem_ptr, + uint32_t stride_byte_offset, uint32_t leading_byte_offset) { + cute::UMMA::SmemDescriptor desc; + + // Set the version for SM100 + desc.version_ = 1; + + // Legacy mode + desc.lbo_mode_ = 0; + + // Layout + desc.layout_type_ = static_cast(layout); + + // Start address + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); + + // Base offset + desc.base_offset_ = 0; + + // SBO and LBO + desc.stride_byte_offset_ = stride_byte_offset >> 4; + desc.leading_byte_offset_ = leading_byte_offset >> 4; + + return desc; +} + +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_sf_desc(void* smem_ptr) { + // NOTES: the UTCCP layout is K-major by default + // Atom size: 8 x 128 bits + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // Since the UTCCP we used is 128b-wide (only 1 atom on K), so LBO can be zero + return make_smem_desc(cute::UMMA::LayoutType::SWIZZLE_NONE, smem_ptr, 8 * 16, 0); +} + +__device__ __forceinline__ +void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_ptr) { + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); +} + +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::UMMA::LayoutType to_umma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 16) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B; + if constexpr (kSwizzleMode == 64) return cute::UMMA::LayoutType::SWIZZLE_64B; + if constexpr (kSwizzleMode == 128) return cute::UMMA::LayoutType::SWIZZLE_128B; +} + +template +__device__ __forceinline__ +constexpr uint32_t get_umma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size(); +} + +template +__device__ __forceinline__ +uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) { + return base + (((offset + k_idx * get_umma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); +} + +template +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_umma_desc_stride_k(); + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128 + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = 8 * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(to_umma_layout_type(), + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = 8 * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + swap(stride_byte_offset, leading_byte_offset); + return make_smem_desc(to_umma_layout_type(), + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } +} + +__device__ __forceinline__ +uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sf_id) { + desc.a_sf_id_ = sf_id, desc.b_sf_id_ = sf_id; + return static_cast(static_cast(desc)) << 32; +} + +template +__device__ constexpr uint32_t get_num_aligned_tmem_cols() { + DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns"); + if (kNumCols <= 32) return 32; + if (kNumCols <= 64) return 64; + if (kNumCols <= 128) return 128; + if (kNumCols <= 256) return 256; + return 512; +} + +__device__ __forceinline__ void tcgen05_before_thread_sync() { + asm volatile("tcgen05.fence::before_thread_sync;"); +} + +__device__ __forceinline__ void tcgen05_after_thread_sync() { + asm volatile("tcgen05.fence::after_thread_sync;"); +} + +} // namespace `deep_gemm::sm100` diff --git a/deep_gemm/include/deep_gemm/common/sm90_utils.cuh b/deep_gemm/include/deep_gemm/common/sm90_utils.cuh new file mode 100644 index 000000000..e590b4797 --- /dev/null +++ b/deep_gemm/include/deep_gemm/common/sm90_utils.cuh @@ -0,0 +1,226 @@ +#pragma once + +#include +#include + +namespace deep_gemm::sm90 { + +template +struct FP8MMA { + + template + __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct FP8MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 168) return MMA_64x168x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 176) return MMA_64x176x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 184) return MMA_64x184x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 200) return MMA_64x200x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 208) return MMA_64x208x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 216) return MMA_64x216x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 224) return MMA_64x224x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 232) return MMA_64x232x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 240) return MMA_64x240x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 248) return MMA_64x248x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 256) return MMA_64x256x32_F32E4M3E4M3_SS_TN(); + } + + static constexpr auto select_type() { + return FP8MMA(); + } + + using type = decltype(select_type()); +}; + +template +struct BF16MMA { + + template + __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 16; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct BF16MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS(); + if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS(); + if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS(); + if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS(); + if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS(); + if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS(); + if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS(); + if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS(); + if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS(); + if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS(); + if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS(); + if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS(); + if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS(); + if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS(); + if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS(); + if constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS(); + if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS(); + if constexpr (N == 152) return MMA_64x152x16_F32BF16BF16_SS(); + if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS(); + if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS(); + if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS(); + if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS(); + if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS(); + if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS(); + if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS(); + if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS(); + if constexpr (N == 224) return MMA_64x224x16_F32BF16BF16_SS(); + if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS(); + if constexpr (N == 240) return MMA_64x240x16_F32BF16BF16_SS(); + if constexpr (N == 248) return MMA_64x248x16_F32BF16BF16_SS(); + if constexpr (N == 256) return MMA_64x256x16_F32BF16BF16_SS(); + } + + static constexpr auto select_type() { + return BF16MMA(); + } + + using type = decltype(select_type()); +}; + + +template +struct SM90_U32x2_STSM_N { + __device__ __forceinline__ static void + copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { + const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; + asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "l"(smem_dst), "r"(src[0]), "r"(src[1])); + } +}; + +__forceinline__ __device__ void warpgroup_arrive() { + asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); +} + +__forceinline__ __device__ void warpgroup_commit_batch() { + asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +} + +__forceinline__ __device__ void warpgroup_fence_operand(float& reg) { + asm volatile("" : "+f"(reg) :: "memory"); +} + +template +__forceinline__ __device__ void warpgroup_wait() { + DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]"); + asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); +} + +// TODO: replace with CUTLASS solution +union GmmaDescriptor { + __host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {} + + __host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {} + + __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {} + + __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {} + + __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept { + desc_ = t.desc_; + return *this; + } + + __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept { + desc_ = t.desc_; + return *this; + } + + uint64_t desc_; + uint32_t reg32_[2]; + uint16_t reg16_[4]; + + struct { + uint16_t start_address_: 14, : 2; + uint16_t leading_byte_offset_: 14, : 2; + uint16_t stride_byte_offset_: 14, : 2; + uint8_t : 1, base_offset_: 3, : 4; + uint8_t : 6, layout_type_: 2; + } bitfield; + + // Decay to an `uint64_t` + __host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; } +}; + +template +__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, const int& layout_type, + const int& leading_byte_offset = 0, + const int& stride_byte_offset = 1024) { + GmmaDescriptor desc; + const auto& uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + desc.bitfield.start_address_ = uint_ptr >> 4; + desc.bitfield.layout_type_ = layout_type; + desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4; + desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4; + desc.bitfield.base_offset_ = 0; + return desc; +} + +__device__ __forceinline__ void +tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr, + const uint32_t& crd_0, const uint32_t& crd_1, const uint32_t& num_tma_multicast = 1) { + constexpr auto cache_hint = static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL); + if (num_tma_multicast == 1) { + cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1); + } else if (cute::block_rank_in_cluster() == 0) { + cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << num_tma_multicast) - 1, cache_hint, smem_ptr, crd_0, crd_1); + } +} + +} // namespace `deep_gemm::sm90` diff --git a/deep_gemm/include/deep_gemm/common/types.hpp b/deep_gemm/include/deep_gemm/common/types.hpp new file mode 100644 index 000000000..23e73424f --- /dev/null +++ b/deep_gemm/include/deep_gemm/common/types.hpp @@ -0,0 +1,18 @@ +#pragma once + +namespace deep_gemm { + +enum class GemmType { + Normal = 0, + MGroupedContiguous = 1, + MGroupedMasked = 2, + KGroupedContiguous = 3, +}; + +enum class KernelType { + Kernel1D1D = 0, + Kernel1D2D = 1, + KernelNoSF = 2 +}; + +} // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/common/utils.cuh b/deep_gemm/include/deep_gemm/common/utils.cuh new file mode 100644 index 000000000..fc84b696d --- /dev/null +++ b/deep_gemm/include/deep_gemm/common/utils.cuh @@ -0,0 +1,165 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "cute_tie.cuh" + +#ifdef __CLION_IDE__ + +__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { + asm volatile("trap;"); +} + +#define printf host_device_printf +#endif + +#ifndef DG_DEVICE_ASSERT +#define DG_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) { \ + printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ +} while (0) +#endif + +#ifndef DG_TRAP_ONLY_DEVICE_ASSERT +#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) \ + asm("trap;"); \ +} while (0) +#endif + +#ifndef DG_STATIC_ASSERT +#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__) +#endif + +namespace deep_gemm { + +template +struct PatternVisitor { + FuncT func; + + __device__ __host__ + explicit PatternVisitor(FuncT&& func): func(std::forward(func)) {} + + __device__ __host__ + auto operator [](const uint32_t& i) { + return func(i); + } +}; + +template +__device__ __host__ T ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +__device__ __host__ constexpr T constexpr_ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +__device__ __host__ T align(T a, T b) { + return ceil_div(a, b) * b; +} + +template +__device__ __host__ constexpr T constexpr_align(T a, T b) { + return constexpr_ceil_div(a, b) * b; +} + +template +__device__ __host__ constexpr T constexpr_gcd(T a, T b) { + return b == 0 ? a : constexpr_gcd(b, a % b); +} + +template +__forceinline__ __device__ void swap(T& a, T& b) { + T temp = a; + a = b; + b = temp; +} + +__forceinline__ __device__ uint32_t get_sm_idx() { + uint32_t sm_idx; + asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx)); + return sm_idx; +} + +__forceinline__ __device__ uint32_t get_lane_idx() { + uint32_t lane_id; + asm ("mov.u32 %0, %laneid;" : "=r"(lane_id)); + return lane_id; +} + +__device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) { + uint32_t ret; + asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ float4 ld_shared(const float4* ptr) { + float4 ret; + asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ uint4 ld_shared(const uint4* ptr) { + uint4 ret; + asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ float ld_shared(const float* ptr) { + float ret; + asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ void st_shared(const float* ptr, float val) { + asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val)); +} + +__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) { + asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val)); +} + +__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) { + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(ptr), "r"(x), "r"(y), "r"(z), "r"(w)); +} + +template +__device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) { + auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast(&x), *reinterpret_cast(&y)}); + return *reinterpret_cast(&bf16x2); +} + +__device__ __forceinline__ void prefetch_l1(void *ptr) { + asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr)); +} + +template +struct Vectorized { + static auto zeros() { + // TODO: add `ulonglong4` for SM100 once `__ldg` support this + if constexpr (kNumBytes > 0 and kNumBytes % 16 == 0) { + return make_uint4(0, 0, 0, 0); + } else if constexpr (kNumBytes > 0 and kNumBytes % 8 == 0) { + return make_uint2(0, 0); + } else if constexpr (kNumBytes > 0 and kNumBytes % 4 == 0) { + return 0; + } else { + DG_STATIC_ASSERT(kNumBytes > 0 and kNumBytes % 4 == 0, "Invalid vectorization"); + } + } + + using vec_t = decltype(zeros()); +}; + +} // namespace `deep_gemm` diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh deleted file mode 100644 index 5c11cd3dc..000000000 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ /dev/null @@ -1,444 +0,0 @@ -#pragma once - -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wunknown-attributes" - -#include -#include - -#include -#include -#include - -#include "mma_utils.cuh" -#include "scheduler.cuh" -#include "tma_utils.cuh" -#include "utils.cuh" - -namespace deep_gemm { - -template -__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, uint32_t num_former_iters) { - if (num_former_iters == kNumFormerIters) { - inner_launch_k_iterations(func, cute::Int{}); - return; - } - - if constexpr (kNumFormerIters + kGap <= kEnd) - outer_launch_k_iterations(inner_launch_k_iterations, func, num_former_iters); -} - -template -__global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) -fp8_gemm_kernel(float* scales_b, int* grouped_layout, - uint32_t shape_m, - const __grid_constant__ CUtensorMap tensor_map_a, - const __grid_constant__ CUtensorMap tensor_map_b, - const __grid_constant__ CUtensorMap tensor_map_scales_a, - const __grid_constant__ CUtensorMap tensor_map_d) { -#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) - // Scaling checks - DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); - DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); - - // Types - using WGMMA = typename FP8MMASelector::type; - using Barrier = cutlass::arch::ClusterTransactionBarrier; - DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); - - // Shared memory - static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); - static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * (BLOCK_N + BLOCK_N_PADDING) * sizeof(__nv_bfloat16); - static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); - static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); - static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float); - static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K); - static constexpr uint32_t SMEM_SCALES_B_SIZE = ceil_div(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)) * sizeof(Barrier); - - // Configs - constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; - constexpr uint32_t kNumThreads = get_num_threads_per_sm(BLOCK_M); - constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads; - constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages); - const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const uint32_t lane_idx = get_lane_id(); - - // Prefetch TMA descriptors at the very beginning - if (threadIdx.x == kNumMathThreads) { - // NOTES: `reinterpret_cast` must be here, or NVRTC will fail - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_a)); - - // `tensor_map_d` is only used in swizzling mode - // For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode - if constexpr (kSwizzleDMode > 0) - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); - } - __syncwarp(); - - // Align to 1024 bytes for swizzle-128B - extern __shared__ __align__(1024) uint8_t smem_buffer[]; - DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); - - // Data on shared memory - auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); - __nv_fp8_e4m3* smem_a[kNumStages]; - __nv_fp8_e4m3* smem_b[kNumStages]; - float* smem_scales_a[kNumStages]; - float* smem_scales_b; - - // TMA Barrier for both divisible and non-divisible cases - Barrier* full_barriers[kNumStages]; - Barrier* empty_barriers[kNumStages]; - - // Fill shared memory pointers - #pragma unroll - for (uint32_t i = 0; i < kNumStages; ++ i) { - smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); - smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); - smem_scales_a[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE); - } - smem_scales_b = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE)); - - // Fill barriers - auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_scales_b) + SMEM_SCALES_B_SIZE); - #pragma unroll - for (uint32_t i = 0; i < kNumStages; ++ i) { - full_barriers[i] = barrier_start_ptr + i; - empty_barriers[i] = barrier_start_ptr + kNumStages + i; - } - - // Initialize barriers - DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); - if (threadIdx.x == kNumMathThreads) { - // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, - // even with TMA multicast disabled, we want to make the behavior aligned - #pragma unroll - for (uint32_t i = 0; i < kNumStages; ++ i) { - full_barriers[i]->init(1); - empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); - } - - // Make initialized barrier visible in async proxy - cutlass::arch::fence_view_async_shared(); - (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void(); - } - - // Synchronize all threads to make barrier visible in normal memory model - (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); - - // For pipeline unrolling - struct DivisibleK {}; - struct NotDivisibleK {}; - struct SkipComputation {}; - struct NotSkipComputation {}; - auto launch_k_iterations = [](const auto& func, bool skip_computation, uint32_t num_former_iters) { - constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; - constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8; - constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0; - - // NOTES: for too-many branches (> 5), we disable this optimization - // Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value - outer_launch_k_iterations<0, kGap, kEnd>([=](const auto& func, auto num_former_iters_type) { - if (skip_computation) { - for (uint32_t k_iter = 0; k_iter < kNumIterations; ++ k_iter) - func(k_iter, DivisibleK{}, SkipComputation{}, num_former_iters_type); - } else if (SHAPE_K % kFullKOfAllStages == 0) { - for (uint32_t k_iter = 0; k_iter < kNumIterations; ++ k_iter) - func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type); - } else { - for (uint32_t k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter) - func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type); - func(kNumIterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type); - } - }, func, kShouldOptimize ? num_former_iters : 0); - }; - - // Register reconfigurations - constexpr uint32_t kNumTMARegisters = 40; - constexpr uint32_t kNumMathRegisters = 232; - - // Block scheduler - uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, grouped_layout); - - if (threadIdx.x >= kNumMathThreads) { - // TMA warp-group for loading data - cutlass::arch::warpgroup_reg_dealloc(); - - // NOTES: only one thread (or warp) will be used - if (threadIdx.x == kNumMathThreads) { - // Persistently schedule over blocks - while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) { - constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; - - // Assign TMA multicast number into A and B - // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. - const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); - const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; - const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; - DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); - - // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all - // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant - #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Wait consumer release - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); - - // Issue TMA A - auto& full_barrier = *full_barriers[s]; - uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; - tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), - smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), - num_tma_multicast_a); - tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), - smem_scales_a[s], m_block_idx * BLOCK_M, - scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K), - num_tma_multicast_a); - - // Issue TMA B - tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), - smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx), - num_tma_multicast_b); - full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); - } - - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); - full_barriers[s]->arrive(); - } - }, false, 0); - } - - // To safely deconstruct distributed shared barriers, we need another round of empty waits - if constexpr (kNumTMAMulticast > 1) { - #pragma unroll - for (uint32_t s = 0; s < kNumStages; ++ s) - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); - } - } - } else { - // Math warp-groups for WGMMA - cutlass::arch::warpgroup_reg_alloc(); - - // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); - const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; - - // Persistently schedule over blocks - while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - // Decide the number of scales B to load - DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); - uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; - if constexpr (not kMustUseUniformedScaleB) { - num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; - num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8; - } - uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2); - - // Load B scales with math warp-groups - // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks - if (threadIdx.x >= 32) { - auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); - auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; - #pragma unroll - for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) - st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); - } - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Accumulation for WGMMA or CUDA promotion - constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M); - DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); - float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; - - // Empty barrier arrival - auto empty_barrier_arrive = [&](uint32_t s) { - if constexpr (kNumTMAMulticast == 1) { - lane_idx == 0 ? empty_barriers[s]->arrive() : void(); - } else { - auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); - lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); - } - }; - - // Launch MMAs - launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto skip_type, auto _) { - constexpr bool kSkipComputation = std::is_same_v; - constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 : - (kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K); - - #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Read B scales - float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; - // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks - if constexpr (not kMustUseUniformedScaleB) - scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); - - // Wait TMA arrivals - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - - // TODO: remove some useless computation for unaligned Ms - #pragma unroll - for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { - auto m_offset = local_idx * WAVE_BLOCK_M; - - // Read A scales - // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset); - auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset); - - // Commit WGMMA instructions - #pragma unroll - for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); - #pragma unroll - for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - warpgroup_commit_batch(); - #pragma unroll - for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); - - // Notify barrier arrival at the last warpgroup wave - if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) - empty_barrier_arrive(s); - - // Promote with scales - // NOTES: making it as predicates is very important for performance, comparing to two loops - float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; - float scale_0_1, scale_1_1; - if constexpr (not kMustUseUniformedScaleB) - scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; - - auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; - #pragma unroll - for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant - bool predicate = kMustUseUniformedScaleB or i < num_former_iters; - shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; - shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; - shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; - shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; - } - } - } - - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - empty_barrier_arrive(s); - } - }, not scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M), num_former_iters); - - // TMA checks - constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); - constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); - constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; - DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); - DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, - "Unaligned TMA store or too many TMA store instructions"); - DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); - DG_STATIC_ASSERT(static_cast(kSwizzleDMode > 0) + static_cast(BLOCK_N_PADDING > 0) <= 1, - "Swizzling and padding are not compatible"); - - // Wait last TMA store to be finished - if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) - cute::tma_store_wait<0>(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Write back to shared memory using STSM and issue TMA stores - DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); - #pragma unroll - for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { - auto m_offset = local_idx * WAVE_BLOCK_M; - auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; - #pragma unroll - for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - // Swizzle or padding into the correct address - uint8_t* smem_ptr = nullptr; - if constexpr (kSwizzleDMode > 0) { - // Calculate the swizzling atom offset and in-atom offset - constexpr uint32_t kNumBankGroupBytes = 16; - auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); - - // Calculate the index of the bank group to be written in the atom - auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); - - // Reshape the atom in another view and swizzle - // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` - // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` - constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; - auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); - auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8); - col ^= row % (kSwizzleDMode / 16); - - // Add back into the base pointer - // NOTES: think twice before modifying this, as changes may affect the number of instructions - smem_ptr = reinterpret_cast(smem_d) + // Base pointer - warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset - m_offset * kSwizzleDMode + // Wave offset - atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) - row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset - } else { - // No swizzling, just padding - // NOTES: padding must be zero for BF16 output - DG_STATIC_ASSERT(BLOCK_N_PADDING == 0, "Padding must be zero for BF16 output"); - smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8); - } - - // NOTES: only 16 lanes' addresses are used - SM90_U32x2_STSM_N::copy( - __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), - __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), - smem_ptr - ); - } - } - cute::tma_store_fence(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Use TMA store to write back to global memory - // TODO: compatible with FP32 output - DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); - if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { - auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; - auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; - cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, - n_block_idx * BLOCK_N + in_block_n_offset, - scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); - cute::tma_store_arrive(); - } - __syncwarp(); - } - } -#else - if (blockIdx.x == 0 and threadIdx.x == 0) - DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); -#endif -} - -}; // namespace deep_gemm - -#pragma clang diagnostic pop \ No newline at end of file diff --git a/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh deleted file mode 100644 index 7b7e3d313..000000000 --- a/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh +++ /dev/null @@ -1,363 +0,0 @@ -#pragma once - -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wunknown-attributes" - -#include -#include - -#include -#include -#include - -#include "mma_utils.cuh" -#include "scheduler.cuh" -#include "tma_utils.cuh" -#include "utils.cuh" - -namespace deep_gemm { - -template -__global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) -fp8_wgrad_gemm_kernel(uint32_t shape_k, - const __grid_constant__ CUtensorMap tensor_map_a, - const __grid_constant__ CUtensorMap tensor_map_b, - const __grid_constant__ CUtensorMap tensor_map_scales_a, - const __grid_constant__ CUtensorMap tensor_map_scales_b, - const __grid_constant__ CUtensorMap tensor_map_d) { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) || defined(__CLION_IDE__) - // Scaling checks - DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); - - // Types - using WGMMA = typename FP8MMASelector::type; - using Barrier = cutlass::arch::ClusterTransactionBarrier; - DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); - - // Shared memory - static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float); - static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); - static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); - static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float); - static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE = BLOCK_N * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE = ceil_div(SMEM_SCALES_B_SIZE_PER_STAGE, 128U) * 128U; - - // Configs - constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; - constexpr uint32_t kNumThreads = get_num_threads_per_sm(BLOCK_M); - constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads; - - const uint32_t shape_k_scales = ceil_div(shape_k, BLOCK_K); - const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages); - const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const uint32_t lane_idx = get_lane_id(); - - // Prefetch TMA descriptors at the very beginning - if (threadIdx.x == kNumMathThreads) { - // NOTES: `reinterpret_cast` must be here, or NVRTC will fail - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_a)); - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_b)); - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); - } - __syncwarp(); - - // Align to 1024 bytes for swizzle-128B - extern __shared__ __align__(1024) uint8_t smem_buffer[]; - DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); - - // Data on shared memory - auto smem_d = reinterpret_cast(smem_buffer); - __nv_fp8_e4m3* smem_a[kNumStages]; - __nv_fp8_e4m3* smem_b[kNumStages]; - float* smem_scales_a[kNumStages]; - float* smem_scales_b[kNumStages]; - - // TMA Barrier for both divisible and non-divisible cases - Barrier* full_barriers[kNumStages + 1]; - Barrier* empty_barriers[kNumStages + 1]; - - // Fill shared memory pointers - #pragma unroll - for (int i = 0; i < kNumStages; ++ i) { - smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); - smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); - smem_scales_a[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) - + i * SMEM_SCALES_A_SIZE_PER_STAGE); - smem_scales_b[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE) - + i * ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE); - } - - // Fill barriers - DG_STATIC_ASSERT(sizeof(Barrier) % sizeof(float) == 0, "Misaligned barriers"); - auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages - * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE)); - #pragma unroll - for (int i = 0; i < kNumStages + 1; ++ i) { - full_barriers[i] = barrier_start_ptr + i; - empty_barriers[i] = barrier_start_ptr + kNumStages + 1 + i; - } - - // Initialize barriers - DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "To many TMA multicast"); - if (threadIdx.x == kNumMathThreads) { - // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, - // even with TMA multicast disabled, we want to make the behavior aligned - #pragma unroll - for (int i = 0; i < kNumStages; ++ i) { - full_barriers[i]->init(1); - empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); - } - full_barriers[kNumStages]->init(1); - empty_barriers[kNumStages]->init(1); - - // Make initialized barrier visible in async proxy - cutlass::arch::fence_view_async_shared(); - (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void(); - } - - // Synchronize all threads to make barrier visible in normal memory model - (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); - - // For pipeline unrolling - struct DivisibleK {}; - struct NotDivisibleK {}; - auto launch_k_iterations = [&](const auto& func) { - if constexpr (kNumLastStages == 0) { - for (int k_iter = 0; k_iter < num_iterations; ++ k_iter) - func(k_iter, DivisibleK{}); - } else { - for (int k_iter = 0; k_iter < num_iterations - 1; ++ k_iter) - func(k_iter, DivisibleK{}); - func(num_iterations - 1, NotDivisibleK{}); - } - }; - - // Register reconfigurations - constexpr int kNumTMARegisters = 40; - constexpr int kNumMathRegisters = 232; - - // Block scheduler - uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(SHAPE_M); - - if (threadIdx.x >= kNumMathThreads) { - // TMA warp-group for loading data - cutlass::arch::warpgroup_reg_dealloc(); - - // NOTES: only one thread (or warp) will be used - if (threadIdx.x == kNumMathThreads) { - // Persistently schedule over blocks - while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - launch_k_iterations([&](int k_iter, auto type) { - constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; - DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); - - // Assign TMA multicast number into A and B - // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. - const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); - const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; - const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; - DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); - - #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Wait consumer release - empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); - - // Issue TMA A - auto& full_barrier = *full_barriers[s]; - int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; - tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), - smem_a[s], k_idx, m_block_idx * BLOCK_M, num_tma_multicast_a); - tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), - smem_scales_a[s], m_block_idx * BLOCK_M, - k_idx / BLOCK_K, num_tma_multicast_a); - - // Issue TMA B - tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), - smem_b[s], k_idx, n_block_idx * BLOCK_N, num_tma_multicast_b); - tma_copy(&tensor_map_scales_b, reinterpret_cast(&full_barrier), - smem_scales_b[s], n_block_idx * BLOCK_N, k_idx / BLOCK_K, num_tma_multicast_b); - - full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE); - } - - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); - full_barriers[s]->arrive(); - } - }); - - // Issue TMA D - empty_barriers[kNumStages]->wait((scheduler.current_iter + 1) & 1); - auto& full_barrier = *full_barriers[kNumStages]; - tma_copy(&tensor_map_d, reinterpret_cast(&full_barrier), - smem_d, n_block_idx * BLOCK_N, m_block_idx * BLOCK_M, 1); - full_barrier.arrive_and_expect_tx(SMEM_D_SIZE); - } - - // To safely deconstruct distributed shared barriers, we need another round of empty waits - if constexpr (kNumTMAMulticast > 1) { - #pragma unroll - for (uint32_t s = 0; s < kNumStages; ++ s) - empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1); - } - } - } else { - // Math warp-groups for WGMMA - cutlass::arch::warpgroup_reg_alloc(); - - // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); - const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4; - const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8; - - // Empty barrier arrival - auto empty_barrier_arrive = [&](int s) { - if constexpr (kNumTMAMulticast == 1) { - lane_idx == 0 ? empty_barriers[s]->arrive() : void(); - } else { - auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); - lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); - } - }; - - // Persistently schedule over blocks - while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - // Decide the number of scales B to load - DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Accumulation for WGMMA or CUDA promotion - constexpr int WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M); - float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; - float2 scales_b[WGMMA::kNumAccum / 4]; - - // Launch MMAs - launch_k_iterations([&](int k_iter, auto type) { - constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; - DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); - - #pragma unroll - for (int s = 0; s < kNumInnerStages; ++ s) { - // Wait TMA arrivals - full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); - - #pragma unroll - for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { - auto m_offset = local_idx * WAVE_BLOCK_M; - - // Read A scales - auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset); - auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset); - - // Commit WGMMA instructions - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); - #pragma unroll - for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - warpgroup_commit_batch(); - - // Read B scales at the first warpgroup wave - if (local_idx == 0) { - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) - scales_b[i] = ld_shared(reinterpret_cast(smem_scales_b[s] + i * 8 + col_idx * 2)); - __syncwarp(); - } - - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); - - // Notify barrier arrival at the last warpgroup wave - if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) - empty_barrier_arrive(s); - - // Promote with scales - auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - const float &scale_b_0 = scales_b[i].x; - const float &scale_b_1 = scales_b[i].y; - shifted_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0]; - shifted_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1]; - shifted_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2]; - shifted_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3]; - } - } - } - - // Wait last TMA store to be finished - if (k_iter == 0 and scheduler.current_iter > 0) { - if (threadIdx.x == 0) { - cute::tma_store_wait<0>(); - empty_barriers[kNumStages]->arrive(); - } - __syncwarp(); - } - - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); - empty_barrier_arrive(s); - } - }); - - // Wait TMA D arrivals - full_barriers[kNumStages]->wait(scheduler.current_iter & 1); - - // Accumulate to D shared memory - #pragma unroll - for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { - auto m_offset = local_idx * WAVE_BLOCK_M; - auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; - auto smem_d_0 = reinterpret_cast(smem_d + (m_offset + r_0) * BLOCK_N + col_idx * 2); - auto smem_d_1 = reinterpret_cast(smem_d + (m_offset + r_1) * BLOCK_N + col_idx * 2); - #pragma unroll - for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - float2 d_0 = ld_shared(smem_d_0 + i * 4); - st_shared(smem_d_0 + i * 4, {d_0.x + shifted_accum[i * 4 + 0], d_0.y + shifted_accum[i * 4 + 1]}); - float2 d_1 = ld_shared(smem_d_1 + i * 4); - st_shared(smem_d_1 + i * 4, {d_1.x + shifted_accum[i * 4 + 2], d_1.y + shifted_accum[i * 4 + 3]}); - } - } - - cute::tma_store_fence(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Use TMA store to write back to global memory - if (threadIdx.x == 0) { - cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, m_block_idx * BLOCK_M); - cute::tma_store_arrive(); - } - __syncwarp(); - } - } -#else - if (blockIdx.x == 0 and threadIdx.x == 0) - DG_DEVICE_ASSERT(false && "This kernel only support sm_90a"); -#endif -} - -}; // namespace deep_gemm - -#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh b/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh new file mode 100644 index 000000000..46a668d56 --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh @@ -0,0 +1,497 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_bf16_gemm_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_c, + const __grid_constant__ cute::TmaDescriptor tensor_map_d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::conditional_t; + + // GEMM with accumulation must have FP32 output + if constexpr (kWithAccumulation) + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t kNumMWaves = BLOCK_M / LAYOUT_AD_M; + constexpr uint32_t kNumTMAStoreStages = 2; + DG_STATIC_ASSERT(BLOCK_K == 64, "Invalid block K"); + DG_STATIC_ASSERT(BLOCK_M % LAYOUT_AD_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Utils + bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // 2-CTA MMA + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); + DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M and BLOCK_M % LAYOUT_AD_M == 0, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); + + // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size + // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2` + constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N) > 512 ? 1 : 2; + + // Real tensor memory size and offsets + constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N; + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + + // Prefetch TMA descriptors at the very beginning + if (threadIdx.x == 0) { + // NOTES: `reinterpret_cast` must be here, or NVRTC will fail + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + if constexpr (kWithAccumulation) + cute::prefetch_tma_descriptor(&tensor_map_c); + } + + // Data on shared memory (layout as ordered below) + cd_dtype_t* smem_cd[kNumTMAStoreStages]; + cutlass::bfloat16_t* smem_a[kNumStages]; + cutlass::bfloat16_t* smem_b[kNumStages]; + + // Fill D/A/B pointers + #pragma unroll + for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i) + smem_cd[i] = reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + smem_a[i] = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + } + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); }); + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (threadIdx.x == 0) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive only at the leader CTA + full_barriers[i]->init(kNumMulticast); + // Arrive at all CTAs + empty_barriers[i]->init(1); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + cutlass::arch::fence_barrier_init(); + } else if (threadIdx.x >= 32 and threadIdx.x < 64) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); + + // For pipeline unrolling + struct DivisibleK {}; + struct NotDivisibleK {}; + uint32_t phase = 0; + auto launch_k_iterations = [&](const auto& func) { + const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k); + const uint32_t num_iterations = ceil_div(current_shape_k, kNumStages * BLOCK_K); + const uint32_t num_last_stages = ceil_div(current_shape_k, BLOCK_K) % kNumStages; + + // TODO: refactor here + if (num_last_stages == 0) { + for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter, phase ^= 1) + func(k_iter, DivisibleK{}, k_iter == num_iterations - 1, num_last_stages); + } else { + for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter, phase ^= 1) + func(k_iter, DivisibleK{}, false, num_last_stages); + func(num_iterations - 1, NotDivisibleK{}, true, num_last_stages), phase ^= 1; + } + }; + + auto dispatch_accum_stage_idx = [&](uint32_t accum_stage_idx, const auto& func) { + DG_STATIC_ASSERT(1 <= kNumEpilogueStages and kNumEpilogueStages <= 2, + "Too many epilogue stages, please modify the Python heuristic as well"); + accum_stage_idx == 0 ? func(0) : func(1); + }; + + // Dispatch warps into different roles + if (warp_idx == 0) { + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) { + constexpr bool kHasDivisibleStages = cute::is_same_v; + const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; + + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait consumer release + empty_barriers[s]->wait(phase ^ 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_block_idx = k_iter * kNumStages + s; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + if (cute::elect_one_sync()) { + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy(&tensor_map_a, full_barriers[s], smem_a[s], k_a_idx, m_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_a_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx); + } + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + if (is_leader_cta and cute::elect_one_sync()) + full_barriers[s]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast); + if (not is_leader_cta and cute::elect_one_sync()) + full_barriers[s]->arrive(0u); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + empty_barriers[s]->wait(phase ^ 1); + if (is_leader_cta and cute::elect_one_sync()) + full_barriers[s]->arrive(); + if (not is_leader_cta and cute::elect_one_sync()) + full_barriers[s]->arrive(0u); + } + }); + } + } else if (warp_idx == 1 and is_leader_cta) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + // TODO: refactor `UMMA_M` calculation + constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1); + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t); + auto instr_desc = cute::UMMA::make_instr_desc(); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = make_umma_desc(smem_a[0], 0, 0); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) { + // Wait tensor memory empty barrier arrival + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s, bool do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + umma_arrive(reinterpret_cast(empty_barriers[s])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + }; + + // Launch MMAs + launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) { + constexpr bool kHasDivisibleStages = cute::is_same_v; + const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; + + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait TMA arrival + full_barriers[s]->wait(phase); + tcgen05_after_thread_sync(); + + // Let tensor cores relax for lower possibility of frequency drop + DG_STATIC_ASSERT(kTensorCoreUtilControl > 0, "Invalid tensor utilization control"); + if constexpr (kTensorCoreUtilControl < 100) { + constexpr static uint64_t kNumUMMACycles = (2ull * BLOCK_M * BLOCK_N * BLOCK_K) / 8192ull; + constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl; + const auto& start_clock = clock64(); + if (cute::elect_one_sync()) + while (clock64() - start_clock < kNumDummyCycles) {} + __syncwarp(); + } + + // Issue UMMA in the leader CTA + using cute_mma_t = cute::conditional_t, + cute::SM100_MMA_F16BF16_2x1SM_SS>; + const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, s); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, s); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K); + cute_mma_t::fma(a_desc, b_desc, + accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, + k_iter > 0 or s > 0 or k > 0, + runtime_instr_desc); + } + } + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(s, is_last_iter and s == kNumInnerStages - 1); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait(phase); + empty_barrier_arrive(s, false); + } + }); + }); + } + } else if (warp_idx >= kNumNonEpilogueThreads / 32) { + // Epilogue warp groups + const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads; + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) { + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + // Flush TMA stores + // NOTES: for the first store, we have to flush all previous TMA, + // as we don't share pipeline stages between two blocks + if (epilogue_thread_idx == 0) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); + + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Iterate over M waves + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s) { + // Wait shared memory to be released + const uint32_t iter_idx = w * kNumStores + s; + if (iter_idx >= kNumTMAStoreStages) { + if (epilogue_thread_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); + } + + // The pipeline stage + const auto tma_stage_idx = iter_idx % kNumTMAStoreStages; + const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M; + const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset + w * BLOCK_N + // Wave offset + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, + cast_into_bf16_and_pack(values[0], values[1]), + cast_into_bf16_and_pack(values[2], values[3]), + cast_into_bf16_and_pack(values[4], values[5]), + cast_into_bf16_and_pack(values[6], values[7])); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + __syncwarp(); + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); + if (epilogue_thread_idx == 0) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx); + cute::tma_store_arrive(); + } + } + } + }); + } + + // Flush all stages in the pipeline to make TMA stores visible to the next kernel + if (epilogue_thread_idx == 0) + cute::tma_store_wait<0>(); + + // Deallocate tensor memory by warp 1 + // NOTES: warp 0 is waiting TMA store + if (epilogue_warp_idx == 1) + Allocator().free(0, kNumTmemCols); + } + + // To safely deconstruct all barriers, we need a cluster sync + // TODO: optimize it by another round of barrier waits + if constexpr (kNumMulticast > 1) + cute::cluster_sync(); +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh new file mode 100644 index 000000000..03c44cd6e --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -0,0 +1,602 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_fp8_gemm_1d1d_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, + const __grid_constant__ cute::TmaDescriptor tensor_map_c, + const __grid_constant__ cute::TmaDescriptor tensor_map_d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::conditional_t; + + // GEMM with accumulation must have FP32 output + if constexpr (kWithAccumulation) + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t kNumMWaves = BLOCK_M / LAYOUT_AD_M; + constexpr uint32_t kNumTMAStoreStages = 2; + constexpr uint32_t kNumSFStagesPerLoad = sizeof(uint32_t) / sizeof(cutlass::float_ue8m0_t); + constexpr uint32_t kNumUTCCPAlignedElems = 128; + DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); + DG_STATIC_ASSERT(BLOCK_M % LAYOUT_AD_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + const uint32_t shape_sf_k = ceil_div(shape_k, BLOCK_K * kNumSFStagesPerLoad); + + // Utils + bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // 2-CTA MMA + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); + DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M and BLOCK_M % LAYOUT_AD_M == 0, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + constexpr uint32_t SF_BLOCK_M = constexpr_align(BLOCK_M, kNumUTCCPAlignedElems); + constexpr uint32_t SF_BLOCK_N = constexpr_align(BLOCK_N, kNumUTCCPAlignedElems); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); + + // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size + // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2` + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; + constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N + kNumSFATmemCols + kNumSFBTmemCols) > 512 ? 1 : 2; + + // Real tensor memory size and offsets + constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N; + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; + + // Prefetch TMA descriptors at the very beginning + if (threadIdx.x == 0) { + // NOTES: `reinterpret_cast` must be here, or NVRTC will fail + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_sfb); + cute::prefetch_tma_descriptor(&tensor_map_d); + if constexpr (kWithAccumulation) + cute::prefetch_tma_descriptor(&tensor_map_c); + } + + // Data on shared memory (layout as ordered below) + cd_dtype_t* smem_cd[kNumTMAStoreStages]; + cutlass::float_e4m3_t* smem_a[kNumStages]; + cutlass::float_e4m3_t* smem_b[kNumStages]; + uint32_t* smem_sfa[kNumStages]; + uint32_t* smem_sfb[kNumStages]; + + // Fill D/A/B pointers + #pragma unroll + for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i) + smem_cd[i] = reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + smem_a[i] = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + } + + // Fill SFA/SFB + auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + smem_sfa[i] = reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + smem_sfb[i] = reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); + } + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto with_sf_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); }); + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (threadIdx.x == 0) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive at all CTAs + full_barriers[i]->init(1); + empty_barriers[i]->init(1); + // Arrive only at the leader CTA + with_sf_full_barriers[i]->init(kNumMulticast * 32); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + cutlass::arch::fence_barrier_init(); + } else if (threadIdx.x >= 32 and threadIdx.x < 64) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); + + // For pipeline unrolling + struct DivisibleK {}; + struct NotDivisibleK {}; + uint32_t phase = 0; + auto launch_k_iterations = [&](const auto& func) { + const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k); + const uint32_t num_iterations = ceil_div(current_shape_k, kNumStages * BLOCK_K); + const uint32_t num_last_stages = ceil_div(current_shape_k, BLOCK_K) % kNumStages; + + // TODO: refactor here + if (num_last_stages == 0) { + for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter, phase ^= 1) + func(k_iter, DivisibleK{}, k_iter == num_iterations - 1, num_last_stages); + } else { + for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter, phase ^= 1) + func(k_iter, DivisibleK{}, false, num_last_stages); + func(num_iterations - 1, NotDivisibleK{}, true, num_last_stages), phase ^= 1; + } + }; + + auto dispatch_accum_stage_idx = [&](uint32_t accum_stage_idx, const auto& func) { + DG_STATIC_ASSERT(1 <= kNumEpilogueStages and kNumEpilogueStages <= 2, + "Too many epilogue stages, please modify the Python heuristic as well"); + accum_stage_idx == 0 ? func(0) : func(1); + }; + + // Dispatch warps into different roles + if (warp_idx == 0) { + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) { + constexpr bool kHasDivisibleStages = cute::is_same_v; + const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; + + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait consumer release + empty_barriers[s]->wait(phase ^ 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_block_idx = k_iter * kNumStages + s; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + if (cute::elect_one_sync()) { + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy(&tensor_map_a, full_barriers[s], smem_a[s], k_a_idx, m_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_a_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx); + } + auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + + // Issue SFA and SFB TMAs at certain stages + // No swizzling, so one TMA for one SF is enough + const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad; + if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) { + tma_copy(&tensor_map_sfa, full_barriers[s], smem_sfa[s], m_block_idx * BLOCK_M, + scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad))); + tma_copy(&tensor_map_sfb, full_barriers[s], smem_sfb[s], n_block_idx * BLOCK_N, + scheduler.template get_global_idx(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad), m_block_idx)); + num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t); + } + + // Arrive at full barriers + if (cute::elect_one_sync()) + full_barriers[s]->arrive_and_expect_tx(num_arrival_bytes); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + empty_barriers[s]->wait(phase ^ 1); + if (cute::elect_one_sync()) + full_barriers[s]->arrive(); + } + }); + } + } else if (warp_idx == 1 and is_leader_cta) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + // TODO: refactor `UMMA_M` calculation + constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1); + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = make_sf_desc(nullptr); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = make_umma_desc(smem_a[0], 0, 0); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) { + // Wait tensor memory empty barrier arrival + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s, bool do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + umma_arrive(reinterpret_cast(empty_barriers[s])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + }; + + // Launch MMAs + launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) { + constexpr bool kHasDivisibleStages = cute::is_same_v; + const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; + + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait TMA and SF-transpose arrival + with_sf_full_barriers[s]->wait(phase); + tcgen05_after_thread_sync(); + + // Do SF copy at certain stages + // NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves + const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad; + if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) { + using cute_utccp_t = cute::conditional_t; + + // SFA and SFB copy + // TODO: process shared memory descriptor by addition + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfa[s] + i * kNumUTCCPAlignedElems; + replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfb[s] + i * kNumUTCCPAlignedElems; + replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + } + __syncwarp(); + + // Issue UMMA in the leader CTA + using cute_mma_t = cute::conditional_t, + cute::SM100_MMA_MXF8F6F4_2x1SM_SS>; + const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sf_stage_in_group_idx); + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, s); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, s); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K); + cute_mma_t::fma(a_desc, b_desc, + accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, + k_iter > 0 or s > 0 or k > 0, + runtime_instr_desc, + kTmemStartColOfSFA + w * (kNumUTCCPAlignedElems / 32), + kTmemStartColOfSFB); + } + } + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(s, is_last_iter and s == kNumInnerStages - 1); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + with_sf_full_barriers[s]->wait(phase); + empty_barrier_arrive(s, false); + } + }); + }); + } + } else if (warp_idx == 2) { + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) { + constexpr bool kHasDivisibleStages = cute::is_same_v; + const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; + + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait TMA arrival + full_barriers[s]->wait(phase); + + // Transpose for UTCCP at certain stages + const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad; + if (sf_stage_in_group_idx == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfa[s] + i * kNumUTCCPAlignedElems); + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfb[s] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + + // Arrive + with_sf_full_barriers[s]->arrive(0u); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait(phase); + with_sf_full_barriers[s]->arrive(0u); + } + }); + } + } else if (warp_idx >= kNumNonEpilogueThreads / 32) { + // Epilogue warp groups + const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads; + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) { + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + // Flush TMA stores + // NOTES: for the first store, we have to flush all previous TMA, + // as we don't share pipeline stages between two blocks + if (epilogue_thread_idx == 0) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); + + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Iterate over M waves + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s) { + // Wait shared memory to be released + const uint32_t iter_idx = w * kNumStores + s; + if (iter_idx >= kNumTMAStoreStages) { + if (epilogue_thread_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); + } + + // The pipeline stage + const auto tma_stage_idx = iter_idx % kNumTMAStoreStages; + const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M; + const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset + w * BLOCK_N + // Wave offset + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, + cast_into_bf16_and_pack(values[0], values[1]), + cast_into_bf16_and_pack(values[2], values[3]), + cast_into_bf16_and_pack(values[4], values[5]), + cast_into_bf16_and_pack(values[6], values[7])); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + __syncwarp(); + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); + if (epilogue_thread_idx == 0) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx); + cute::tma_store_arrive(); + } + } + } + }); + } + + // Flush all stages in the pipeline to make TMA stores visible to the next kernel + if (epilogue_thread_idx == 0) + cute::tma_store_wait<0>(); + + // Deallocate tensor memory by warp 1 + // NOTES: warp 0 is waiting TMA store + if (epilogue_warp_idx == 1) + Allocator().free(0, kNumTmemCols); + } + + // To safely deconstruct all barriers, we need a cluster sync + // TODO: optimize it by another round of barrier waits + if constexpr (kNumMulticast > 1) + cute::cluster_sync(); +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh new file mode 100644 index 000000000..e04db3ca0 --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh @@ -0,0 +1,537 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::conditional_t; + + // Scaling checks + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t kNumMWaves = BLOCK_M / LAYOUT_AD_M; + constexpr uint32_t kNumTMAStoreStages = 2; + DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); + DG_STATIC_ASSERT(BLOCK_M % LAYOUT_AD_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_M == kNumEpilogueThreads, "Invalid block M"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + const auto shape_k_scales = ceil_div(shape_k, BLOCK_K); + + // Utils + bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // 2-CTA MMA + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); + DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M and BLOCK_M % LAYOUT_AD_M == 0, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + + // Share memory sizes + // NOTES: do not use `LOAD_BLOCK_M` for SFA, as we need full SFA for promotion + constexpr bool kMustUseUniformedSFB = (BLOCK_K % BLOCK_N == 0); + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); + + // Must have 2 epilogue stages + constexpr uint32_t kNumEpilogueStages = 2; + + // Real tensor memory size and offsets + constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N; + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + + // Prefetch TMA descriptors at the very beginning + if (threadIdx.x == 0) { + // NOTES: `reinterpret_cast` must be here, or NVRTC will fail + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + } + + // Data on shared memory (layout as ordered below) + cd_dtype_t* smem_cd[kNumTMAStoreStages]; + cutlass::float_e4m3_t* smem_a[kNumStages]; + cutlass::float_e4m3_t* smem_b[kNumStages]; + float* smem_sfa[kNumStages]; + + // Fill D/A/B pointers + #pragma unroll + for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i) + smem_cd[i] = reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + smem_a[i] = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + } + + // Fill SFA/SFB + auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) + smem_sfa[i] = reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + + kNumStages * SMEM_SFA_SIZE_PER_STAGE); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); }); + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 2 + kNumEpilogueStages * 2); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (threadIdx.x == 0) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive at all CTAs + full_barriers[i]->init(kNumMulticast); + empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads / 32); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + cutlass::arch::fence_barrier_init(); + } else if (threadIdx.x >= 32 and threadIdx.x < 64) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // For pipeline unrolling + struct DivisibleK {}; + struct NotDivisibleK {}; + const uint32_t num_iterations = ceil_div(shape_k, kNumStages * BLOCK_K); + auto launch_k_iterations = [=](const auto& func) { + if constexpr (kNumLastStages == 0) { + for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter) + func(k_iter, DivisibleK{}); + } else { + for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter) + func(k_iter, DivisibleK{}); + func(num_iterations - 1, NotDivisibleK{}); + } + }; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); + + // Register configurations + constexpr uint32_t kNumNonEpilogueRegisters = 64; + constexpr uint32_t kNumEpilogueRegisters = 216; + DG_STATIC_ASSERT(kNumNonEpilogueRegisters * kNumNonEpilogueThreads + kNumEpilogueRegisters * kNumEpilogueThreads <= 65535, "Too many registers"); + + // Dispatch warps into different roles + if (warp_idx == 0) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](uint32_t k_iter, auto type) { + constexpr bool kHasDivisibleStages = cute::is_same_v; + constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.get_global_idx<(kGemmType != GemmType::MGroupedContiguous)>( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.get_global_idx<(kMajorB == cute::UMMA::Major::K)>( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_block_idx = k_iter * kNumStages + s; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_b_idx = scheduler.get_global_idx<(kMajorB == cute::UMMA::Major::MN)>( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + if (cute::elect_one_sync()) { + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy(&tensor_map_a, full_barriers[s], smem_a[s], k_idx, m_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx); + + // Issue SFA TMA + tma_copy( + &tensor_map_sfa, full_barriers[s], + smem_sfa[s], m_block_idx * BLOCK_M, + scheduler.get_global_idx<(kGemmType != GemmType::MGroupedContiguous)>(shape_k_scales, 1, k_block_idx)); + } + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE; + if (is_leader_cta and cute::elect_one_sync()) + full_barriers[s]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast); + if (not is_leader_cta and cute::elect_one_sync()) + full_barriers[s]->arrive(0u); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); + if (is_leader_cta and cute::elect_one_sync()) + full_barriers[s]->arrive(); + if (not is_leader_cta and cute::elect_one_sync()) + full_barriers[s]->arrive(0u); + } + }); + } + } else if (warp_idx == 1 and is_leader_cta) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + // TODO: refactor `UMMA_M` calculation + constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1); + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); + auto instr_desc = cute::UMMA::make_instr_desc(); + auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Launch MMAs + launch_k_iterations([&](uint32_t k_iter, auto type) { + constexpr bool kHasDivisibleStages = cute::is_same_v; + constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) { + // Wait TMA full + auto iter_idx = scheduler.current_iter * num_iterations + k_iter; + full_barriers[s]->wait(iter_idx & 1); + + // Wait tensor memory empty + auto accum_stage_idx = (iter_idx * kNumStages + s) % kNumEpilogueStages; + auto accum_stage_phase = ((iter_idx * kNumStages + s) / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_stage_phase ^ 1); + + // Issue UMMA in the leader CTA + if (s < kNumInnerStages) { + using cute_mma_t = cute::conditional_t; + tcgen05_after_thread_sync(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + auto b_desc = make_umma_desc(smem_b[s], 0, k * UMMA_K); + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + auto a_desc = make_umma_desc(smem_a[s], w * LAYOUT_AD_M, k * UMMA_K); + cute_mma_t::fma(a_desc, b_desc, + accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, + k > 0, + runtime_instr_desc); + } + } + tcgen05_before_thread_sync(); + } + + // Commit to the TMA empty and tensor memory full barrier + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + } + }); + } + } else if (warp_idx < kNumNonEpilogueThreads / 32) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx >= kNumNonEpilogueThreads / 32) { + // Adjust registers + cutlass::arch::warpgroup_reg_alloc(); + + // Epilogue warp groups + const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads; + const auto epilogue_thread_idx_in_warpgroup = epilogue_thread_idx % 128; + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + const auto epilogue_warpgroup_idx = epilogue_thread_idx / 128; + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + constexpr uint32_t kNumElemsPerLDTM = 16; + DG_STATIC_ASSERT(kNumElemsPerLDTM == 16 and BLOCK_N % kNumElemsPerLDTM == 0 and BLOCK_K % kNumElemsPerLDTM == 0, "Invalid LDTM width"); + + // SFB stuffs + uint32_t num_former_iters = BLOCK_N, num_full_iters = BLOCK_N; + if constexpr (not kMustUseUniformedSFB) { + num_former_iters = min(BLOCK_N, BLOCK_K - ((n_block_idx * BLOCK_N) % BLOCK_K)); + num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N); + } + num_former_iters /= kNumElemsPerLDTM, num_full_iters /= kNumElemsPerLDTM; + const auto sfb_offset = scheduler.get_global_idx(ceil_div(shape_n, BLOCK_K), 0, 0, m_block_idx); + const auto sfb_ptr = sfb + (sfb_offset + ((n_block_idx * BLOCK_N) / BLOCK_K)) * shape_k_scales; + + // Launch promotion + float accum[BLOCK_N] = {0}; + launch_k_iterations([&](uint32_t k_iter, auto type) { + constexpr bool kHasDivisibleStages = cute::is_same_v; + constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) { + // Load SFB + float sf_0 = 0, sf_1 = 0; + if (s < kNumInnerStages) { + const auto k_block_idx = k_iter * kNumStages + s; + sf_0 = __ldg(sfb_ptr + k_block_idx); + sf_1 = num_former_iters < num_full_iters ? __ldg(sfb_ptr + k_block_idx + shape_k_scales) : 0; + } + + // Wait UMMA arrival + auto iter_idx = scheduler.current_iter * num_iterations + k_iter; + auto accum_stage_idx = (iter_idx * kNumStages + s) % kNumEpilogueStages; + auto accum_stage_phase = ((iter_idx * kNumStages + s) / kNumEpilogueStages) & 1; + tmem_full_barriers[accum_stage_idx]->wait(accum_stage_phase); + tcgen05_after_thread_sync(); + + // Commit to the TMA empty barrier for all CTAs after loading SFA + float sfa = s < kNumInnerStages ? ld_shared(smem_sfa[s] + epilogue_thread_idx) : 0; + sf_0 *= sfa, sf_1 *= sfa; + __syncwarp(); + if (lane_idx < kNumMulticast) + empty_barriers[s]->arrive(lane_idx); + __syncwarp(); + + // Do promotion like the SM90 kernel + if (s < kNumInnerStages) { + uint32_t values[kNumElemsPerLDTM]; + #pragma unroll + for (uint32_t i = 0; i < BLOCK_N / kNumElemsPerLDTM; ++ i) { + // Load from tensor memory + cute::SM100_TMEM_LOAD_32dp32b16x::copy( + accum_stage_idx * kNumMWaves * BLOCK_N + epilogue_warpgroup_idx * BLOCK_N + i * kNumElemsPerLDTM, + values[ 0], values[ 1], values[ 2], values[ 3], + values[ 4], values[ 5], values[ 6], values[ 7], + values[ 8], values[ 9], values[10], values[11], + values[12], values[13], values[14], values[15]); + cutlass::arch::fence_view_async_tmem_load(); + + // Promote + const auto sf = (kMustUseUniformedSFB or i < num_former_iters) ? sf_0 : sf_1; + #pragma unroll + for (uint32_t j = 0; j < kNumElemsPerLDTM; ++ j) + accum[i * kNumElemsPerLDTM + j] += *reinterpret_cast(&values[j]) * sf; + } + } + + // Commit to the tensor memory empty barrier (only at the leader CTA) + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + }); + + // Flush TMA stores + // NOTES: for the first store, we have to flush all previous TMA, + // as we don't share pipeline stages between two blocks + if (epilogue_thread_idx_in_warpgroup == 0) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync(); + + // Write shared memory + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Epilogue store and addition + // Issue every swizzled atom and pipeline: store shared, add C, and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s) { + // Wait shared memory to be released + if (s >= kNumTMAStoreStages) { + if (epilogue_thread_idx_in_warpgroup == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync(); + } + + // The pipeline stage + const auto tma_stage_idx = s % kNumTMAStoreStages; + const auto m_idx = scheduler.get_global_idx<(kGemmType != GemmType::MGroupedContiguous)>(shape_m, BLOCK_M, m_block_idx); + const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + const auto local_smem_cd = smem_cd[tma_stage_idx] + epilogue_warpgroup_idx * STORE_BLOCK_M * STORE_BLOCK_N; + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + // NOTES: if you want to do accumulation, please notice that you need two accumulation barriers + const auto offset = s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; + if constexpr (cute::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + st_shared(smem_ptr, + *reinterpret_cast(&accum[offset + 0]), + *reinterpret_cast(&accum[offset + 1]), + *reinterpret_cast(&accum[offset + 2]), + *reinterpret_cast(&accum[offset + 3])); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + st_shared(smem_ptr, + cast_into_bf16_and_pack(accum[offset + 0], accum[offset + 1]), + cast_into_bf16_and_pack(accum[offset + 2], accum[offset + 3]), + cast_into_bf16_and_pack(accum[offset + 4], accum[offset + 5]), + cast_into_bf16_and_pack(accum[offset + 6], accum[offset + 7])); + } + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync(); + if (epilogue_thread_idx_in_warpgroup == 0) { + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_d, local_smem_cd, + n_idx, m_idx + epilogue_warpgroup_idx * STORE_BLOCK_M); + cute::tma_store_arrive(); + } + } + } + + // Flush all stages in the pipeline to make TMA stores visible to the next kernel + if (epilogue_thread_idx_in_warpgroup == 0) + cute::tma_store_wait<0>(); + + // Deallocate tensor memory by warp 1 + // NOTES: warp 0 is waiting TMA store + if (epilogue_warp_idx == 1) + Allocator().free(0, kNumTmemCols); + } + + // To safely deconstruct all barriers, we need a cluster sync + // TODO: optimize it by another round of barrier waits + if constexpr (kNumMulticast > 1) + cute::cluster_sync(); +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh b/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh new file mode 100644 index 000000000..23045e1fe --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh @@ -0,0 +1,343 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_bf16_gemm_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Types + using WGMMA = typename BF16MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Shared memory + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16); + + // Configs + constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; + const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages); + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (threadIdx.x == kNumMathThreads) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Data on shared memory + auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + __nv_bfloat16* smem_a[kNumStages]; + __nv_bfloat16* smem_b[kNumStages]; + + // TMA Barrier for both divisible and non-divisible cases + Barrier* full_barriers[kNumStages]; + Barrier* empty_barriers[kNumStages]; + + // Fill shared memory pointers + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + smem_a[i] = reinterpret_cast<__nv_bfloat16*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast<__nv_bfloat16*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + } + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i] = barrier_start_ptr + i; + empty_barriers[i] = barrier_start_ptr + kNumStages + i; + } + + // Initialize barriers + if (threadIdx.x == kNumMathThreads) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + struct DivisibleK {}; + struct NotDivisibleK {}; + auto launch_k_iterations = [=](const auto& func) { + if constexpr (kNumLastStages == 0) { + for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter) + func(k_iter, DivisibleK{}); + } else { + for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter) + func(k_iter, DivisibleK{}); + func(num_iterations - 1, NotDivisibleK{}); + } + }; + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 48; + constexpr uint32_t kNumMathRegisters = 224; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); + + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (threadIdx.x < kNumMathThreads + 32 and cute::elect_one_sync()) { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](uint32_t k_iter, auto divisible_type) { + constexpr bool kHasDivisibleStages = cute::is_same_v; + constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; + + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all + // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); + + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; + auto& full_barrier = *full_barriers[s]; + uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; + + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), + num_tma_multicast_a); + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), + num_tma_multicast_b); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + } + + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); + } + }); + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * (BLOCK_M <= 64 ? 1 : 2); + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); + float accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); + } + }; + + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Launch MMAs + launch_k_iterations([&](uint32_t k_iter, auto divisible_type) { + constexpr bool kHasDivisibleStages = cute::is_same_v; + constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; + + // TODO: remove some useless computation for unaligned Ms + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); + + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, shifted_accum, 1); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) + empty_barrier_arrive(s); + } + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }); + + // TMA checks + constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); + constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); + constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; + DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type"); + DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); + DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, + "Unaligned TMA store or too many TMA store instructions"); + DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + + // Wait last TMA store to be finished + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Write back to shared memory using STSM and issue TMA stores + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // Swizzle or padding into the correct address + uint8_t* smem_ptr = nullptr; + if constexpr (kSwizzleDMode > 0) { + // Calculate the swizzling atom offset and in-atom offset + constexpr uint32_t kNumBankGroupBytes = 16; + auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); + + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` + // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8); + col ^= row % (kSwizzleDMode / 16); + + // Add back into the base pointer + // NOTES: think twice before modifying this, as changes may affect the number of instructions + smem_ptr = reinterpret_cast(smem_d) + // Base pointer + warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset + m_offset * kSwizzleDMode + // Wave offset + atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + } else { + // No swizzling, just padding + // TODO: support more cases + smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8); + } + + // NOTES: only 16 lanes' addresses are used + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), + smem_ptr + ); + } + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Use TMA store to write back to global memory + // TODO: compatible with FP32 output + constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked; + DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; + auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, + n_block_idx * BLOCK_N + in_block_n_offset, + scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh new file mode 100644 index 000000000..28b5399a4 --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh @@ -0,0 +1,3 @@ +#pragma once + +// TODO: add implement \ No newline at end of file diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh new file mode 100644 index 000000000..e6f71b519 --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -0,0 +1,842 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, uint32_t num_former_iters) { + if (num_former_iters == kNumFormerIters) { + inner_launch_k_iterations(func, cute::Int{}); + return; + } + + if constexpr (kNumFormerIters + kGap <= kEnd) + outer_launch_k_iterations(inner_launch_k_iterations, func, num_former_iters); +} + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Shared memory + static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K); + const uint32_t& smem_sfb_size = align(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)); + + // Configs + constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; + const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages); + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (threadIdx.x == kNumMathThreads) { + // NOTES: `reinterpret_cast` must be here, or NVRTC will fail + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Data on shared memory + auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + __nv_fp8_e4m3* smem_a[kNumStages]; + __nv_fp8_e4m3* smem_b[kNumStages]; + float* smem_sfa[kNumStages]; + float* smem_sfb; + + // TMA Barrier for both divisible and non-divisible cases + Barrier* full_barriers[kNumStages]; + Barrier* empty_barriers[kNumStages]; + + // Fill shared memory pointers + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + smem_sfa[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SFA_SIZE_PER_STAGE); + } + smem_sfb = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE)); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_sfb) + smem_sfb_size); + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i] = barrier_start_ptr + i; + empty_barriers[i] = barrier_start_ptr + kNumStages + i; + } + + // Initialize barriers + DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); + if (threadIdx.x == kNumMathThreads) { + // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, + // even with TMA multicast disabled, we want to make the behavior aligned + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // For pipeline unrolling + struct DivisibleK {}; + struct NotDivisibleK {}; + struct SkipComputation {}; + struct NotSkipComputation {}; + auto launch_k_iterations = [=](const auto& func, bool skip_computation, uint32_t num_former_iters) { + constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; + constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8; + constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0; + + // NOTES: for too-many branches (> 5), we disable this optimization + // Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value + outer_launch_k_iterations<0, kGap, kEnd>([=](const auto& func, auto num_former_iters_type) { + if (skip_computation) { + for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter) + func(k_iter, DivisibleK{}, SkipComputation{}, num_former_iters_type); + } else if (shape_k % kFullKOfAllStages == 0) { + for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter) + func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type); + } else { + for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter) + func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type); + func(num_iterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type); + } + }, func, kShouldOptimize ? num_former_iters : 0); + }; + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); + + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (threadIdx.x < kNumMathThreads + 32 and cute::elect_one_sync()) { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) { + constexpr bool kHasDivisibleStages = cute::is_same_v; + constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; + + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all + // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); + + // Issue TMA A + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; + auto& full_barrier = *full_barriers[s]; + uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), + num_tma_multicast_a); + tma_copy(&tensor_map_sfa, reinterpret_cast(&full_barrier), + smem_sfa[s], m_block_idx * BLOCK_M, + scheduler.get_global_idx(shape_k_scales, 1, k_idx / BLOCK_K), + num_tma_multicast_a); + + // Issue TMA B + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), + num_tma_multicast_b); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); + } + }, false, 0); + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Decide the number of scales B to load + DG_TRAP_ONLY_DEVICE_ASSERT(shape_n % 8 == 0); + uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; + if constexpr (not kMustUseUniformedScaleB) { + num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; + num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N) / 8; + } + uint32_t num_sfb = shape_k_scales * (num_former_iters >= num_full_iters ? 1 : 2); + + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) { + auto num_previous_lines = scheduler.get_global_idx(ceil_div(shape_n, BLOCK_K), 0, 0, m_block_idx); + auto local_sfb = sfb + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * shape_k_scales; + #pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32) + st_shared(smem_sfb + i, __ldg(local_sfb + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Accumulation for WGMMA or CUDA promotion + constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * (BLOCK_M <= 64 ? 1 : 2); + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); + } + }; + + // Launch MMAs + launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto skip_type, auto _) { + constexpr bool kSkipComputation = cute::is_same_v; + constexpr bool kHasDivisibleStages = cute::is_same_v; + constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 : (kHasDivisibleStages ? kNumStages : kNumLastStages); + + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Read B scales + float scale_b_0 = ld_shared(smem_sfb + k_iter * kNumStages + s), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_sfb + k_iter * kNumStages + s + shape_k_scales); + + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); + + // TODO: remove some useless computation for unaligned Ms + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_sfa[s] + r_0 + m_offset); + auto scale_a_1 = ld_shared(smem_sfa[s] + r_1 + m_offset); + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) + empty_barrier_arrive(s); + + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant + bool predicate = kMustUseUniformedScaleB or i < num_former_iters; + shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + } + } + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }, not scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M), num_former_iters); + + // TMA checks + constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); + constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); + constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; + DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); + DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, + "Unaligned TMA store or too many TMA store instructions"); + DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + + // Wait last TMA store to be finished + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Write back to shared memory using STSM and issue TMA stores + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // Swizzle or padding into the correct address + uint8_t* smem_ptr = nullptr; + if constexpr (kSwizzleDMode > 0) { + // Calculate the swizzling atom offset and in-atom offset + constexpr uint32_t kNumBankGroupBytes = 16; + auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); + + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` + // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8); + col ^= row % (kSwizzleDMode / 16); + + // Add back into the base pointer + // NOTES: think twice before modifying this, as changes may affect the number of instructions + smem_ptr = reinterpret_cast(smem_d) + // Base pointer + warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset + m_offset * kSwizzleDMode + // Wave offset + atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + } else { + // No swizzling, just padding + smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8); + } + + // NOTES: only 16 lanes' addresses are used + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), + smem_ptr + ); + } + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Use TMA store to write back to global memory + // TODO: compatible with FP32 output + constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked; + DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; + auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, + n_block_idx * BLOCK_N + in_block_n_offset, + scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_fp8_gemm_per_tensor_1d2d_impl(float* scales_b, int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Shared memory + static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K); + const uint32_t& smem_sfb_size = align(1 * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)); + + // Configs + constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; + const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages); + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (threadIdx.x == kNumMathThreads) { + // NOTES: `reinterpret_cast` must be here, or NVRTC will fail + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + + // `tensor_map_d` is only used in swizzling mode + // For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode + cute::prefetch_tma_descriptor(&tensor_map_d); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Data on shared memory + auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + __nv_fp8_e4m3* smem_a[kNumStages]; + __nv_fp8_e4m3* smem_b[kNumStages]; + float* smem_scales_b; + + // TMA Barrier for both divisible and non-divisible cases + Barrier* full_barriers[kNumStages]; + Barrier* empty_barriers[kNumStages]; + + // Fill shared memory pointers + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + } + smem_scales_b = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_scales_b) + smem_sfb_size); + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i] = barrier_start_ptr + i; + empty_barriers[i] = barrier_start_ptr + kNumStages + i; + } + + // Initialize barriers + DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); + if (threadIdx.x == kNumMathThreads) { + // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, + // even with TMA multicast disabled, we want to make the behavior aligned + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // For pipeline unrolling + struct DivisibleK {}; + struct NotDivisibleK {}; + struct SkipComputation {}; + struct NotSkipComputation {}; + auto launch_k_iterations = [=](const auto& func, bool skip_computation, uint32_t num_former_iters) { + constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; + constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8; + constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0; + + // NOTES: for too-many branches (> 5), we disable this optimization + // Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value + outer_launch_k_iterations<0, kGap, kEnd>([=](const auto& func, auto num_former_iters_type) { + if (skip_computation) { + for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter) + func(k_iter, DivisibleK{}, SkipComputation{}, num_former_iters_type); + } else if (SHAPE_K % kFullKOfAllStages == 0) { + for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter) + func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type); + } else { + for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter) + func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type); + func(num_iterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type); + } + }, func, kShouldOptimize ? num_former_iters : 0); + }; + + // // For pipeline unrolling + // struct DivisibleK {}; + // struct NotDivisibleK {}; + // auto launch_k_iterations = [](const auto& func) { + // if constexpr (SHAPE_K % kFullKOfAllStages == 0) { + // for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter) + // func(k_iter, DivisibleK{}); + // } else { + // for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter) + // func(k_iter, DivisibleK{}); + // func(kNumIterations - 1, NotDivisibleK{}); + // } + // }; + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); + + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (threadIdx.x < kNumMathThreads + 32 and cute::elect_one_sync()) { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // launch_k_iterations([&](uint32_t k_iter, auto divisible_type) { + launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; + + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all + // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); + + // Issue TMA A + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; + auto& full_barrier = *full_barriers[s]; + uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), + num_tma_multicast_a); + + // Issue TMA B + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), + num_tma_multicast_b); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); + } + }, false, 0); + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Decide the number of scales B to load + DG_TRAP_ONLY_DEVICE_ASSERT(shape_n % 8 == 0); + uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; + if constexpr (not kMustUseUniformedScaleB) { + num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; + num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N) / 8; + } + uint32_t num_scales_b = 1 * (num_former_iters >= num_full_iters ? 1 : 2); + + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + // if (threadIdx.x == 32) { + // auto local_scales_b = scales_b + scheduler.curr_group_idx; + // st_shared(smem_scales_b, __ldg(local_scales_b)); + // } + int current_group_idx = scheduler.current_group_idx; + if constexpr (kGemmType == GemmType::MGroupedContiguous) { + current_group_idx = __ldg(scheduler.grouped_layout + m_block_idx * BLOCK_M); + } + auto local_scales_b = scales_b + current_group_idx; + float scale_b_0 = __ldg(local_scales_b); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Accumulation for WGMMA or CUDA promotion + constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * (BLOCK_M <= 64 ? 1 : 2); + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); + } + }; + + // Launch MMAs + launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto skip_type, auto _) { + constexpr bool kSkipComputation = std::is_same_v; + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 : + (kHasDivisibleStages ? kNumStages : kNumLastStages); + + // float scale_b_0 = ld_shared(smem_scales_b); + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); + + // TODO: remove some useless computation for unaligned Ms + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) + empty_barrier_arrive(s); + + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + shifted_accum[i * 4 + 0] += scale_b_0 * accum[i * 4 + 0]; + shifted_accum[i * 4 + 1] += scale_b_0 * accum[i * 4 + 1]; + shifted_accum[i * 4 + 2] += scale_b_0 * accum[i * 4 + 2]; + shifted_accum[i * 4 + 3] += scale_b_0 * accum[i * 4 + 3]; + } + } + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }, not scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M), num_former_iters); + + // TMA checks + constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); + constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); + constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; + DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); + DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, + "Unaligned TMA store or too many TMA store instructions"); + DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + + // Wait last TMA store to be finished + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Write back to shared memory using STSM and issue TMA stores + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // Swizzle or padding into the correct address + uint8_t* smem_ptr = nullptr; + if constexpr (kSwizzleDMode > 0) { + // Calculate the swizzling atom offset and in-atom offset + constexpr uint32_t kNumBankGroupBytes = 16; + auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); + + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` + // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8); + col ^= row % (kSwizzleDMode / 16); + + // Add back into the base pointer + // NOTES: think twice before modifying this, as changes may affect the number of instructions + smem_ptr = reinterpret_cast(smem_d) + // Base pointer + warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset + m_offset * kSwizzleDMode + // Wave offset + atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + } else { + // No swizzling, just padding + // NOTES: padding must be zero for BF16 output + smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8); + } + + // NOTES: only 16 lanes' addresses are used + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), + smem_ptr + ); + } + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Use TMA store to write back to global memory + // TODO: compatible with FP32 output + constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked; + DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; + auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, + n_block_idx * BLOCK_N + in_block_n_offset, + scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh b/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh new file mode 100644 index 000000000..bea700027 --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh @@ -0,0 +1,176 @@ +#pragma once + +#include + +namespace deep_gemm { + +template +__global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) { + typedef typename Vectorized::vec_t in_vec_t; + constexpr static uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float); + constexpr static uint32_t SF_VEC_K = SF_K / kNumElemsPerVec; + + // Shapes and strides + extern __shared__ float smem_buffer[]; + constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(float)); + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto tma_aligned_mn = align(mn, kNumTMAAlignedElems); + + // Shift into the block + sf = sf + static_cast(blockIdx.y) * mn * SF_K; + out = out + static_cast(blockIdx.y) * tma_aligned_mn * SF_K; + const auto& local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); + + // Load + for (uint32_t i = threadIdx.x; i < in_block_mn * SF_VEC_K; i += kNumThreads) { + auto in_vec = __ldg(local_sf + i); + const auto& in_values = reinterpret_cast(&in_vec); + + const auto& row = i / SF_VEC_K, col = (i % SF_VEC_K) * kNumElemsPerVec; + #pragma unroll + for (uint32_t j = 0; j < kNumElemsPerVec; ++ j) + smem_buffer[row * PADDED_SF_K + col + j] = in_values[j]; + } + __syncthreads(); + + // Store + #pragma unroll + for (uint32_t i = threadIdx.x; i < in_block_mn * SF_K; i += kNumThreads) { + const auto& sf_k_idx = i / in_block_mn, mn_idx = i % in_block_mn; + const auto& global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; + out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx); + } +} + +// NOTES: the two kernels below always pack the K dimension + +template +__global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) { + extern __shared__ uint32_t smem_buffer[]; + + // Shapes and strides + constexpr auto kNumPackedSFK = constexpr_ceil_div(SF_K, 4u); + constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(int)); + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto tma_aligned_mn = align(mn, kNumTMAAlignedElems); + + // Shift into the group + sf = sf + static_cast(blockIdx.y) * mn * SF_K; + out = out + static_cast(blockIdx.y) * tma_aligned_mn * kNumPackedSFK; + + // Load FP32 SFs + DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size"); + const auto local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); + const auto num_values = in_block_mn * SF_K; + const auto num_uint4 = num_values / 4; + #pragma unroll + for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) { + const auto& [x, y, z, w] = __ldg(reinterpret_cast(local_sf) + i); + st_shared(reinterpret_cast(smem_buffer) + i, x, y, z, w); + } + + // Fill unaligned values as well + if (const auto unaligned_idx = num_uint4 * 4 + threadIdx.x; unaligned_idx < num_values) + st_shared(smem_buffer + unaligned_idx, __ldg(local_sf + unaligned_idx)); + __syncthreads(); + + // Pack into UE8M0 and store + #pragma unroll + for (uint32_t i = threadIdx.x; i < (kNumPackedSFK * BLOCK_MN); i += kNumThreads) { + const auto sf_k_pack_idx = i / BLOCK_MN, mn_idx = i % BLOCK_MN; + + // Load shared memory + uint32_t values[4]; + #pragma unroll + for (uint32_t j = 0; j < 4; ++ j) { + const auto sf_k_idx = sf_k_pack_idx * 4 + j; + values[j] = sf_k_idx < SF_K ? ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0; + } + + // Pack and store + uint32_t packed = 0; + packed |= (values[0] >> 23u); + packed |= (values[1] >> 15u); + packed |= (values[2] >> 7u); + packed |= (values[3] << 1u); + if (const auto global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; global_mn_idx < mn) + out[sf_k_pack_idx * tma_aligned_mn + global_mn_idx] = packed; + } +} + +template +__global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, + const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k) { + // Always packing the K dimension + // NOTES: should also assert `mn % 4 == 0` at launch + DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)"); + DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(BLOCK_PACKED_SF_K == kNumThreads / 32, "Invalid block sizes"); + + // Shapes and strides + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto in_block_mn_uint4 = in_block_mn / 4; + const auto in_block_packed_sf_k = min(BLOCK_PACKED_SF_K, packed_sf_k - blockIdx.y * BLOCK_PACKED_SF_K); + + // Shift into the right block along MN + sf += blockIdx.x * BLOCK_MN; + out += blockIdx.x * BLOCK_MN; + + // Each warp is responsible for a packed row + const auto warp_idx = threadIdx.x / 32; + const auto lane_idx = get_lane_idx(); + const auto packed_sf_k_idx = static_cast(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx; + if (warp_idx >= in_block_packed_sf_k) + return; + + // Make an offset on the input + uint32_t input_offset = 0; + if constexpr (kNumGroups > 1) { + // Load each group's size + DG_STATIC_ASSERT(kNumGroups <= 128, "Too many groups"); + uint32_t group_ks[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) { + const auto group_idx = lane_idx * 4 + i; + group_ks[i] = group_idx < kNumGroups ? __ldg(ks + group_idx) : 0; + } + __syncwarp(); + + // Make the offset + sf_k = 0; + auto sum_packed_sf_k = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumGroups; ++ i) { + const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / 128, i / 4); + sf_k += sf_k_in_group; + sum_packed_sf_k += ceil_div(sf_k_in_group, 4u); + if (packed_sf_k_idx < sum_packed_sf_k) + break; + if (const auto remainder = sf_k_in_group % 4; remainder > 0) + input_offset += 4 - remainder; + } + } + + for (uint32_t mn_idx = get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) { + // Load + uint4 values[4]; + #pragma unroll + for (uint32_t j = 0; j < 4; ++ j) { + values[j] = make_uint4(0, 0, 0, 0); + if (const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset; sf_k_idx < sf_k) + values[j] = __ldg(reinterpret_cast(sf + sf_k_idx * mn) + mn_idx); + } + + // Pack and store + uint4 packed; + packed.x = (values[0].x >> 23u) | (values[1].x >> 15u) | (values[2].x >> 7u) | (values[3].x << 1u); + packed.y = (values[0].y >> 23u) | (values[1].y >> 15u) | (values[2].y >> 7u) | (values[3].y << 1u); + packed.z = (values[0].z >> 23u) | (values[1].z >> 15u) | (values[2].z >> 7u) | (values[3].z << 1u); + packed.w = (values[0].w >> 23u) | (values[1].w >> 15u) | (values[2].w >> 7u) | (values[3].w << 1u); + reinterpret_cast(out + packed_sf_k_idx * mn)[mn_idx] = packed; + } +} + +} // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/mma_utils.cuh deleted file mode 100644 index 85b2ccc0c..000000000 --- a/deep_gemm/include/deep_gemm/mma_utils.cuh +++ /dev/null @@ -1,212 +0,0 @@ -#pragma once - -#ifndef __CUDACC_RTC__ -#include -#endif - -#include -#include - -#include "utils.cuh" - -namespace deep_gemm { - -template -struct SM90_U32x2_STSM_N { - __device__ __forceinline__ static void - copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { - const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; - asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" - :: "l"(smem_dst), "r"(src[0]), "r"(src[1])); - } -}; - -template -struct SM90_U32x4_STSM_N { - __device__ __forceinline__ static void - copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) { - const uint32_t src[4] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1), - *reinterpret_cast(&src_2), *reinterpret_cast(&src_3)}; - asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" - :: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3])); - } -}; - -__forceinline__ __device__ void warpgroup_arrive() { - asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); -} - -__forceinline__ __device__ void warpgroup_commit_batch() { - asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); -} - -__forceinline__ __device__ void warpgroup_fence_operand(float& reg) { - asm volatile("" : "+f"(reg) :: "memory"); -} - -__forceinline__ __device__ uint32_t get_lane_id() { - uint32_t lane_id; - asm("mov.u32 %0, %laneid;" : "=r"(lane_id)); - return lane_id; -} - -__device__ __forceinline__ uint32_t ld_shared(const uint32_t* __restrict__ ptr) { - uint32_t ret; - asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr)); - return ret; -} - -__device__ __forceinline__ int4 ld_shared(const int4* __restrict__ ptr) { - int4 ret; - asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr)); - return ret; -} - -__device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) { - float ret; - asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr)); - return ret; -} - -__device__ __forceinline__ float2 ld_shared(const float2* __restrict__ ptr) { - float2 ret; - asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr)); - return ret; -} - -__device__ __forceinline__ void st_shared(const float* ptr, float val) { - asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val)); -} - -__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) { - asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val)); -} - -__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) { - asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(ptr), "f"(val.x), "f"(val.y)); -} - -template -__device__ void warpgroup_wait() { - DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]"); - asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); -} - -union GmmaDescriptor { - __host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {} - - __host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {} - - __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {} - - __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {} - - __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept { - desc_ = t.desc_; - return *this; - } - - __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept { - desc_ = t.desc_; - return *this; - } - - uint64_t desc_; - uint32_t reg32_[2]; - uint16_t reg16_[4]; - - struct { - uint16_t start_address_: 14, : 2; - uint16_t leading_byte_offset_: 14, : 2; - uint16_t stride_byte_offset_: 14, : 2; - uint8_t : 1, base_offset_: 3, : 4; - uint8_t : 6, layout_type_: 2; - } bitfield; - - // Decay to an `uint64_t` - __host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; } -}; - -template -__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type, - int leading_byte_offset = 0, - int stride_byte_offset = 1024) { - GmmaDescriptor desc; - auto uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); - desc.bitfield.start_address_ = uint_ptr >> 4; - desc.bitfield.layout_type_ = layout_type; - desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4; - desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4; - desc.bitfield.base_offset_ = 0; - return desc; -} - -template -struct FP8MMA { - - template - __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, std::index_sequence) { - using namespace cute::SM90::GMMA; - MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); - } - - __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - call_fma_impl(desc_a, desc_b, d, scale_d, std::make_index_sequence{}); - } - - static constexpr int M = 64; - static constexpr int N = N_; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -template -struct FP8MMASelector { - - static constexpr auto select_mma() { - using namespace cute::SM90::GMMA; - if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN(); - } - - static constexpr auto select_type() { - return FP8MMA(); - } - - using type = decltype(select_type()); -}; - -enum class Layout { - RowMajor, - ColMajor -}; - -__device__ __host__ constexpr int get_num_math_warpgroups(int block_m) { - return block_m == 64 ? 1 : 2; -} - -template -__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) { - DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group"); - return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads; -} - -} // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/nvrtc_std.cuh b/deep_gemm/include/deep_gemm/nvrtc_std.cuh deleted file mode 100644 index 00ce7341c..000000000 --- a/deep_gemm/include/deep_gemm/nvrtc_std.cuh +++ /dev/null @@ -1,103 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#ifdef __CUDACC_RTC__ - -using int8_t = signed char; -using uint8_t = unsigned char; -using int16_t = signed short; -using uint16_t = unsigned short; -using int32_t = signed int; -using uint32_t = unsigned int; -using int64_t = signed long long; -using uint64_t = unsigned long long; -using cuuint64_t = unsigned long long; - -#ifndef CU_TENSOR_MAP_NUM_QWORDS -#define CU_TENSOR_MAP_NUM_QWORDS 16 - -struct CUtensorMap_st { -#if defined(__cplusplus) && (__cplusplus >= 201103L) - alignas(64) -#elif __STDC_VERSION__ >= 201112L - _Alignas(64) -#endif - cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS]; -}; - -using CUtensorMap = CUtensorMap_st; -#endif - -namespace std { - -template struct integral_constant { - static constexpr T value = v; - - using value_type = T; - using type = integral_constant; - - __device__ constexpr operator value_type() const noexcept { return value; } - - __device__ constexpr value_type operator()() const noexcept { return value; } -}; - -using false_type = integral_constant; -using true_type = integral_constant; - -template struct is_same : false_type {}; - -template struct is_same : true_type {}; - -template -inline constexpr bool is_same_v = is_same::value; - -namespace index_sequence_impl { - -// Based on https://stackoverflow.com/a/32223343/11717224 -template struct index_sequence { - using type = index_sequence; - using value_type = size_t; - static constexpr size_t size() noexcept { return sizeof...(Ints); } -}; - -template struct _merge_and_renumber; - -template -struct _merge_and_renumber, index_sequence> - : index_sequence {}; - -template -struct make_index_sequence - : _merge_and_renumber::type, - typename make_index_sequence::type> {}; - -template <> struct make_index_sequence<0> : index_sequence<> {}; -template <> struct make_index_sequence<1> : index_sequence<0> {}; - -} // namespace index_sequence_impl - -template -using index_sequence = index_sequence_impl::index_sequence; - -template -using make_index_sequence = index_sequence_impl::make_index_sequence; - -} // namespace std - -#endif diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh deleted file mode 100644 index 81bfeba07..000000000 --- a/deep_gemm/include/deep_gemm/scheduler.cuh +++ /dev/null @@ -1,163 +0,0 @@ -#pragma once - -#include "utils.cuh" - -namespace deep_gemm { - -enum class GemmType { - Normal, - GroupedContiguous, - GroupedMasked -}; - -#pragma clang diagnostic push -#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" -template -struct Scheduler { - int current_iter = -1; - uint32_t num_aligned_m_blocks; - - // For normal GEMM - // Maybe not used in the masked grouped GEMM - uint32_t num_blocks; - uint32_t num_blocks_in_group; - bool is_peer_cta_alive = true; - - // For grouped GEMM - int* grouped_layout; - - // Only used for masked layout - uint32_t curr_group_idx, curr_cumsum; - - __device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, - int* grouped_layout = nullptr) { - num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M); - if constexpr (kGemmType == GemmType::Normal) { - num_blocks = num_aligned_m_blocks * kNumNBlocks; - } else if (kGemmType == GemmType::GroupedContiguous) { - num_blocks = num_aligned_m_blocks * kNumNBlocks; - this->grouped_layout = grouped_layout; - } else if (kGemmType == GemmType::GroupedMasked) { - curr_group_idx = curr_cumsum = 0; - this->grouped_layout = grouped_layout; - } - } - - // ReSharper disable once CppNotAllPathsReturnValue - __device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const { - if constexpr (kGemmType == GemmType::Normal) { - return true; - } else if constexpr (kGemmType == GemmType::GroupedContiguous) { - return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0; - } else if constexpr (kGemmType == GemmType::GroupedMasked) { - return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + curr_group_idx); - } - } - - __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const { - if (num_blocks_in_group == 1) - return false; - if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) { - return true; - } else { - DG_STATIC_ASSERT(kGemmType == GemmType::GroupedContiguous, "Invalid Gemm type"); - if constexpr (kIsTMAMulticastOnA) { - return true; - } else { - auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M); - auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M); - return group_idx == peer_group_idx; - } - } - } - - __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& num_m_blocks, const uint32_t& block_idx, - uint32_t& m_block_idx, uint32_t& n_block_idx) { - DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); - - // Swizzle for better L2 usages - auto primary_num_blocks = kIsTMAMulticastOnA ? kNumNBlocks : num_m_blocks; - auto secondary_num_blocks = kIsTMAMulticastOnA ? num_m_blocks : kNumNBlocks; - auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup; - auto group_idx = block_idx / num_blocks_per_group; - auto first_block_idx = group_idx * kNum1DBlocksPerGroup; - auto in_group_idx = block_idx % num_blocks_per_group; - num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx); - - // Fix unaligned TMA multicast - if (kNumTMAMulticast > 1 and num_blocks_in_group % 2 != 0) { - if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) { - num_blocks_in_group = num_blocks_in_group ^ 1; - } else { - in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks; - first_block_idx += num_blocks_in_group ^ 1; - num_blocks_in_group = 1; - } - } - - // Convert to final M/N block indices - if constexpr (kIsTMAMulticastOnA) { - m_block_idx = in_group_idx / num_blocks_in_group; - n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; - } else { - m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; - n_block_idx = in_group_idx / num_blocks_in_group; - } - } - - template - __device__ __forceinline__ uint32_t get_global_idx(const uint32_t& shape_dim, const uint32_t& block_size, - const uint32_t& block_idx, const uint32_t& m_block_idx=0) { - if constexpr (kGemmType == GemmType::Normal) { - return block_idx * block_size; - } else if constexpr (kGemmType == GemmType::GroupedContiguous) { - auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M); - return offset * shape_dim + block_idx * block_size; - } else if constexpr (kGemmType == GemmType::GroupedMasked) { - return curr_group_idx * shape_dim + block_idx * block_size; - } - } - - __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { - const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x; - - if constexpr (kGemmType == GemmType::GroupedMasked) { - uint32_t num_m_blocks; - while (true) { - // End of the task - if (curr_group_idx == kNumGroups) - return false; - - // Within the current group - num_m_blocks = ceil_div(static_cast(__ldg(grouped_layout + curr_group_idx)), BLOCK_M); - auto current_m_block_cumsum = curr_cumsum + num_m_blocks; - if (next_block_idx < current_m_block_cumsum * kNumNBlocks) - break; - - // Move to check the next group - curr_group_idx ++, curr_cumsum = current_m_block_cumsum; - } - - get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx); - } else { - if (next_block_idx >= num_blocks) - return false; - - // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned - is_peer_cta_alive = kNumNBlocks % kNumTMAMulticast == 0 or // Always aligned on N (constant bypass) - num_aligned_m_blocks % kNumTMAMulticast == 0 or // Always aligned on M (constant bypass) - (next_block_idx ^ 1) < num_blocks; // Peer CTA in bound - get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx); - } - return true; - } -}; - -#pragma clang diagnostic pop - -} // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/tma_utils.cuh b/deep_gemm/include/deep_gemm/tma_utils.cuh deleted file mode 100644 index 795dca6a2..000000000 --- a/deep_gemm/include/deep_gemm/tma_utils.cuh +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#include "utils.cuh" - -namespace deep_gemm { - -// TODO: move this function to other files -__device__ __forceinline__ void -tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr, - int32_t const& crd_0, int32_t const& crd_1, uint32_t num_tma_multicast) { - constexpr auto cache_hint = static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL); - if (num_tma_multicast == 1) { - cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1); - } else if (cute::block_rank_in_cluster() == 0) { - cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << num_tma_multicast) - 1, cache_hint, smem_ptr, crd_0, crd_1); - } -} - -} // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/utils.cuh b/deep_gemm/include/deep_gemm/utils.cuh deleted file mode 100644 index 598a41467..000000000 --- a/deep_gemm/include/deep_gemm/utils.cuh +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -#ifdef __CLION_IDE__ - -__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { - asm volatile("trap;"); -} - -#define printf host_device_printf -#endif - -#ifndef DG_DEVICE_ASSERT -#define DG_DEVICE_ASSERT(cond) \ -do { \ - if (not (cond)) { \ - printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ - asm("trap;"); \ - } \ -} while (0) -#endif - -#ifndef DG_STATIC_ASSERT -#define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason) -#endif - -template -__device__ __host__ constexpr T ceil_div(T a, T b) { - return (a + b - 1) / b; -} - -template -__device__ __host__ constexpr T constexpr_gcd(T a, T b) { - return b == 0 ? a : constexpr_gcd(b, a % b); -} diff --git a/deep_gemm/jit/__init__.py b/deep_gemm/jit/__init__.py deleted file mode 100644 index 06a519400..000000000 --- a/deep_gemm/jit/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .compiler import get_nvcc_compiler, build, NVCCCompiler, NVRTCCompiler -from .runtime import Runtime diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py deleted file mode 100644 index d3f1f7626..000000000 --- a/deep_gemm/jit/compiler.py +++ /dev/null @@ -1,284 +0,0 @@ -import functools -import hashlib -import os -import re -import subprocess -import time -import uuid -from typing import Any, Dict, List, Tuple, Type - -import cuda.bindings -import cuda.bindings.nvrtc as nvrtc -from torch.utils.cpp_extension import CUDA_HOME - -from . import interleave_ffma -from .runtime import Runtime, RuntimeCache - -runtime_cache = RuntimeCache() - - -def hash_to_hex(s: str) -> str: - md5 = hashlib.md5() - md5.update(s.encode('utf-8')) - return md5.hexdigest()[0:12] - - -@functools.lru_cache(maxsize=None) -def get_jit_include_dir() -> str: - return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'include') - - -@functools.lru_cache(maxsize=None) -def get_deep_gemm_version() -> str: - md5 = hashlib.md5() - - # Update include directories - include_dir = os.path.join(get_jit_include_dir(), 'deep_gemm') - assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}' - for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))): - with open(os.path.join(include_dir, filename), 'rb') as f: - md5.update(f.read()) - - # Update `interleave_ffma.py` - with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'interleave_ffma.py'), 'rb') as f: - md5.update(f.read()) - return md5.hexdigest()[0:12] - - -@functools.lru_cache(maxsize=None) -def get_nvcc_compiler() -> Tuple[str, str]: - paths = [] - if os.getenv('DG_JIT_NVCC_COMPILER'): - paths.append(os.getenv('DG_JIT_NVCC_COMPILER')) - paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc')) - - # Try to find the first available NVCC compiler - least_version_required = '12.3' - version_pattern = re.compile(r'release (\d+\.\d+)') - for path in paths: - if os.path.exists(path): - command = [path, '--version'] - result = subprocess.run(command, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, text=True) - match = version_pattern.search(result.stdout) - version = match.group(1) - assert match, f'Cannot get the version of NVCC compiler {path}' - assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}' - return path, version - raise RuntimeError('Cannot find any available NVCC compiler') - - -@functools.lru_cache(maxsize=None) -def get_default_user_dir(): - if 'DG_JIT_CACHE_DIR' in os.environ: - path = os.getenv('DG_JIT_CACHE_DIR') - os.makedirs(path, exist_ok=True) - return path - return os.path.join(os.path.expanduser('~'), '.deep_gemm') - - -@functools.lru_cache(maxsize=None) -def get_tmp_dir(): - return os.path.join(get_default_user_dir(), 'tmp') - - -@functools.lru_cache(maxsize=None) -def get_cache_dir(): - return os.path.join(get_default_user_dir(), 'cache') - - -def make_tmp_dir(): - tmp_dir = get_tmp_dir() - os.makedirs(tmp_dir, exist_ok=True) - return tmp_dir - - -def put(path, data): - # Write and do POSIX atomic replace - tmp_file_path = os.path.join(make_tmp_dir(), f'file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}') - with open(tmp_file_path, 'wb' if isinstance(data, bytes) else 'w') as f: - f.write(data) - os.replace(tmp_file_path, path) - - -class Compiler: - @classmethod - def signature(cls) -> str: - pass - - @staticmethod - def __version__() -> Tuple[int, int]: - pass - - @classmethod - def compile(cls, name: str, code: str, target_path: str) -> None: - pass - - @staticmethod - def flags() -> List[str]: - cpp_standard = int(os.getenv('DG_JIT_OVERRIDE_CPP_STANDARD', 20)) - return [f'-std=c++{cpp_standard}', - '--ptxas-options=--register-usage-level=10' + - (',--verbose' if 'DG_JIT_PTXAS_VERBOSE' in os.environ else ''), - # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases - '--diag-suppress=39,161,174,177,186,940'] - - @staticmethod - def include_dirs() -> List[str]: - return [get_jit_include_dir()] - - @classmethod - def build(cls, name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime: - # Compiler flags - flags = cls.flags() - - # Build signature - enable_sass_opt = cls.__version__() <= (12, 8) and not int(os.getenv('DG_JIT_DISABLE_FFMA_INTERLEAVE', 0)) - signature = f'{name}$${get_deep_gemm_version()}$${cls.signature()}$${flags}$${enable_sass_opt}$${code}' - name = f'kernel.{name}.{hash_to_hex(signature)}' - path = os.path.join(get_cache_dir(), name) - - # Check runtime cache or file system hit - global runtime_cache - cached_runtime = runtime_cache.get(path, runtime_cls, name, kwargs) - if cached_runtime is not None: - if int(os.getenv('DG_JIT_DEBUG', 0)): - print(f'Using cached JIT runtime {name} during build') - return cached_runtime - - # Compile into a temporary CU file - os.makedirs(path, exist_ok=True) - cubin_path = os.path.join(path, 'kernel.cubin') - tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin') - - start_time = time.time() - cls.compile(name, code, tmp_cubin_path) - end_time = time.time() - elapsed_time = end_time - start_time - if int(os.getenv('DG_JIT_DEBUG', 0)): - print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.') - - # Interleave FFMA reuse - if enable_sass_opt: - interleave_ffma.process(tmp_cubin_path) - - # Atomic replace files - os.replace(tmp_cubin_path, cubin_path) - - # Put cache and return - runtime = runtime_cache.get(path, runtime_cls, name, kwargs, force_enable_cache=True) - assert runtime is not None - return runtime - - -class NVCCCompiler(Compiler): - @staticmethod - def __version__() -> Tuple[int, int]: - _, version = get_nvcc_compiler() - major, minor = map(int, version.split('.')) - return major, minor - - @classmethod - def signature(cls) -> str: - return f'{get_nvcc_compiler()[0]}+{cls.__version__()}' - - @classmethod - def flags(cls) -> List[str]: - cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi'] - return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], - '-gencode=arch=compute_90a,code=sm_90a', - '-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda', - f'--compiler-options={",".join(cxx_flags)}'] - - @classmethod - def compile(cls, name: str, code: str, target_path: str) -> None: - # Write the code - path = os.path.join(get_cache_dir(), name) - src_path = os.path.join(path, 'kernel.cu') - put(src_path, code) - command = [get_nvcc_compiler()[0], - src_path, '-o', target_path, - *cls.flags()] - if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)): - print(f'Compiling JIT runtime {name} with command {command}') - - result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - if result.returncode != 0: - print(f'NVCC compilation failed: stdout: {result.stdout}, stderr: {result.stderr}') - assert False, f'Failed to compile {src_path}' - - -class NVRTCCompiler(Compiler): - @staticmethod - def __version__() -> Tuple[int, int]: - res, major, minor = nvrtc.nvrtcVersion() - if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: - # Failed to get the actual NVRTC version, use cuda-bindings version instead - major, minor = map(int, cuda.bindings.__version__.split('.')[:2]) - return major, minor - - @classmethod - def signature(cls) -> str: - return f'nvrtc+{cls.__version__()}' - - @staticmethod - def include_dirs() -> List[str]: - if CUDA_HOME is None: - raise RuntimeError('CUDA_HOME is required for NVRTC compilation') - return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include')] - - @classmethod - def flags(cls) -> List[str]: - flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], - '--gpu-architecture=sm_90a', '-default-device'] - # NOTES: PCH is vital for compilation speed - if cls.__version__() >= (12, 8): - flags += ['--pch'] - if int(os.getenv('DG_JIT_DEBUG', 0)): - flags += ['--pch-verbose=true'] - return flags - - @classmethod - def compile(cls, name: str, code: str, target_path: str) -> None: - # Create program - code_bytes = bytes(code, 'utf-8') - result, program = nvrtc.nvrtcCreateProgram( - code_bytes, bytes(name, 'utf-8'), 0, [], []) - assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to create program: {result}' - - # Compile - options = [bytes(flag, 'utf-8') for flag in cls.flags()] - if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)): - print(f'Compiling JIT runtime {name} with options: {options}') - compile_result = nvrtc.nvrtcCompileProgram(program, len(options), options)[0] - - # Print compiler log - if int(os.getenv('DG_JIT_DEBUG', 0)) or compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS: - result, log_size = nvrtc.nvrtcGetProgramLogSize(program) - assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log size: {result}' - - log_bytes = bytes(log_size) - result = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0] - assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log: {result}' - print(f'Compiler log: {log_bytes.decode("utf-8")}') - - # Exit if failed - assert compile_result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to compile program: {compile_result}' - - # Create CUBIN - result, cubin_size = nvrtc.nvrtcGetCUBINSize(program) - assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN size: {result}' - cubin_bytes = bytes(cubin_size) - result = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0] - assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN: {result}' - - # Write into the file system - put(target_path, cubin_bytes) - - # Destroy handler - assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}' - - -def build(name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime: - compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler - return compiler_cls.build(name, code, runtime_cls, kwargs) diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py deleted file mode 100644 index 7899a2219..000000000 --- a/deep_gemm/jit/interleave_ffma.py +++ /dev/null @@ -1,137 +0,0 @@ -import argparse -import mmap -import os -import re -import subprocess -from torch.utils.cpp_extension import CUDA_HOME - - -def run_cuobjdump(file_path): - command = [f'{CUDA_HOME}/bin/cuobjdump', '-sass', file_path] - result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - assert result.returncode == 0 - return result.stdout - - -def extract_ffma(sass): - lines = sass.splitlines() - collected = [] - current = [] - - arch_name, func_name = 'N/A', 'N/A' - skip_next_line = False - for line in lines: - if 'code for' in line: - arch_name = line.lstrip().lstrip('code for ').rstrip() - elif 'Function :' in line: - func_name = line.lstrip().lstrip('Function :').rstrip() - elif 'FFMA' in line: - current.append(line) - skip_next_line = True - elif skip_next_line: - current.append(line) - skip_next_line = False - else: - if len(current) >= 16: - assert len(current) % 2 == 0 - collected.append((f'{arch_name}::{func_name}', current)) - current = [] - - if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)): - print(f'Found {len(collected)} FFMA segments') - return collected - - -def extract_hex_from_line(line): - match = re.search(r'/\*\s*(0x[0-9a-fA-F]+)\s*\*/', line) - assert match - return int(match.group(1), 16) - - -def validate(m, offset, le_bytes, num_lines): - assert len(le_bytes) == num_lines // 2 - assert m[offset:offset + 16] == le_bytes[0] - for i in range(1, num_lines // 2): - if m[offset + i * 16:offset + i * 16 + 16] != le_bytes[i]: - return False - return True - - -def parse_registers(line): - line = re.sub(r'/\*.*?\*/', '', line) - line = line.replace(';', '') - tokens = line.strip().split(',') - registers = [] - for token in tokens: - token = token.strip() - words = token.split() - for word in words: - if word.startswith('R'): - reg = word.split('.')[0] - registers.append(reg) - return registers - - -def modify_segment(m, name, ffma_lines): - num_lines = (len(ffma_lines) * 9 // 16) // 2 * 2 - assert num_lines % 2 == 0 - - le_bytes, new_le_bytes = [], [] - reused_list = [] - dst_reg_set = set() - last_reused, last_dst_reg = False, '' - num_changed = 0 - for i in range(num_lines // 2): - dst_reg = parse_registers(ffma_lines[i * 2])[-2] - low_line, high_line = ffma_lines[i * 2], ffma_lines[i * 2 + 1] - low_hex, high_hex = extract_hex_from_line(low_line), extract_hex_from_line(high_line) - le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little')) - reused = (high_hex & 0x0800000000000000) != 0 - if reused: - is_first_occurred = dst_reg not in dst_reg_set - if is_first_occurred or (last_reused and dst_reg == last_dst_reg): - # Modify the `reuse` and `yield` bits - assert high_hex & 0x0800200000000000, f'{hex(high_hex)}' - high_hex ^= 0x0800200000000000 - reused = False - num_changed += 1 - else: - reused_list.append(i) - dst_reg_set.add(dst_reg) - new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little')) - last_reused, last_dst_reg = reused, dst_reg - if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)): - print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}') - - # Find the offset - offsets = [] - offset = m.find(le_bytes[0]) - while offset != -1: - offsets.append(offset) - offset = m.find(le_bytes[0], offset + 1) - offsets = list(filter(lambda x: validate(m, x, le_bytes, num_lines), offsets)) - - # Replace with `new_le_bytes` - for offset in offsets: - for i in range(num_lines // 2): - m[offset + i * 16:offset + i * 16 + 16] = new_le_bytes[i] - - -def process(path): - if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)): - print(f'Processing {path}') - output = run_cuobjdump(path) - segments = extract_ffma(output) - with open(path, 'r+b') as f: - mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_WRITE) - for segment in segments: - modify_segment(mm, *segment) - mm.close() - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Interleave FFMA reg reuse') - parser.add_argument('--so', help='Path to the SO file') - args = parser.parse_args() - - process(args.so) diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py deleted file mode 100644 index 7a63bf1ce..000000000 --- a/deep_gemm/jit/runtime.py +++ /dev/null @@ -1,105 +0,0 @@ -import os -import subprocess -import time -import torch -import cuda.bindings.driver as cbd - -from typing import Any, Dict, Optional, Type -from torch.utils.cpp_extension import CUDA_HOME - - -class Runtime: - def __init__(self, path: str) -> None: - self.path = path - self.lib = None - self.kernel = None - assert self.is_path_valid(self.path) - - @staticmethod - def is_path_valid(path: str) -> bool: - # Exists and is a directory - if not os.path.exists(path) or not os.path.isdir(path): - return False - - # Contains all necessary files - files = ['kernel.cubin'] - return all(os.path.exists(os.path.join(path, file)) for file in files) - - @staticmethod - def generate(kwargs: Dict[str, Any]) -> str: - raise NotImplemented - - @staticmethod - def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: - raise NotImplemented - - def __call__(self, **kwargs) -> cbd.CUresult: - # Load CUBIN - if self.kernel is None: - start_time = time.time_ns() - - # Load CUBIN - path = bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8') - result, self.lib = cbd.cuLibraryLoadFromFile(path, [], [], 0, [], [], 0) - assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load library: {result}' - - # Extract the kernel name - # TODO: use `cuda-bindings` API to do this (requires at least 12.8) - command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path] - result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - assert result.returncode == 0 - illegal_names = ['vprintf', '__instantiate_kernel', '__internal', '__assertfail'] - check_illegal = lambda line: any([name in line for name in illegal_names]) - kernel_names = [line.split()[-1] for line in result.stdout.splitlines() - if line.startswith('STT_FUNC') and not check_illegal(line)] - assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}' - - # Load kernel from the library - result, self.kernel = cbd.cuLibraryGetKernel(self.lib, bytes(kernel_names[0], encoding='utf-8')) - assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load kernel: {result}' - - end_time = time.time_ns() - elapsed_time = (end_time - start_time) / 1e6 - if int(os.getenv('DG_JIT_DEBUG', 0)): - print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.') - - # noinspection PyArgumentList - return self.launch(self.kernel, kwargs) - - def __del__(self) -> None: - if self.lib is not None: - res = cbd.cuLibraryUnload(self.lib)[0] - if res != cbd.CUresult.CUDA_SUCCESS: - raise Exception(f'Failed to unload library {self.path}: {res}') - - -class RuntimeCache: - def __init__(self) -> None: - self.cache = {} - - def __setitem__(self, path: str, runtime: Runtime) -> None: - self.cache[path] = runtime - - def get(self, path: str, runtime_cls: Type[Runtime], - name: str = '', kwargs: Dict[str, Any] = None, - force_enable_cache: bool = False) -> Optional[Runtime]: - # In Python runtime - if path in self.cache: - return self.cache[path] - - # Already compiled - use_cache = force_enable_cache or not int(os.getenv('DG_JIT_DISABLE_CACHE', 0)) - if use_cache and os.path.exists(path) and Runtime.is_path_valid(path): - # Print heuristic for the first time - if name and (int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_CONFIGS', 0))): - simplified_kwargs = dict() - for key, value in kwargs.items() if kwargs is not None else dict().items(): - value = f'torch.Tensor<{value.dtype}>' if isinstance(value, torch.Tensor) else value - value = f'cuda.bindings.driver.CUtensorMap' if isinstance(value, cbd.CUtensorMap) else value - simplified_kwargs[key] = value - print(f'Put kernel {name} with {simplified_kwargs} into runtime cache') - - runtime = runtime_cls(path) - self.cache[path] = runtime - return runtime - return None diff --git a/deep_gemm/jit_kernels/__init__.py b/deep_gemm/jit_kernels/__init__.py deleted file mode 100644 index f1fa7bb24..000000000 --- a/deep_gemm/jit_kernels/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from .gemm import gemm_fp8_fp8_bf16_nt -from .m_grouped_gemm import ( - m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, - m_grouped_gemm_fp8_fp8_bf16_nt_masked -) -from .wgrad_gemm import ( - wgrad_gemm_fp8_fp8_fp32_nt, - k_grouped_wgrad_gemm_fp8_fp8_fp32_nt -) -from .utils import ( - ceil_div, set_num_sms, get_num_sms, - get_col_major_tma_aligned_tensor, - get_m_alignment_for_contiguous_layout -) diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py deleted file mode 100644 index 574f821f7..000000000 --- a/deep_gemm/jit_kernels/gemm.py +++ /dev/null @@ -1,242 +0,0 @@ -import math -import torch -from functools import lru_cache -from typing import Tuple - -from ..jit import build -from .runtime import ( - FP8GemmRuntime, GemmType, - make_2d_tma_a_desc, make_2d_tma_b_desc, - make_2d_tma_d_desc, make_2d_tma_scales_desc) -from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout - - -def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int, - require_divisible: bool = False) -> bool: - divisible = ceil_div(shape_dim, block_dim) % num_tma_multicast == 0 or not require_divisible - return divisible and num_sms % num_tma_multicast == 0 - - -def get_swizzle_mode(block_n: int) -> int: - elem_size = 2 - for mode_bytes in (128, 64, 32): - if (block_n * elem_size) % mode_bytes == 0: - return mode_bytes - return 0 - - -def get_block_n_padding_for_smem_d(block_n: int) -> int: - # NOTES: padding is for solving bank conflicts, but wastes shared memory space - elem_size, requirement = 2, (4, 8) - bank_stride = (block_n * elem_size) // 4 - padding = (requirement[0] - bank_stride) % requirement[1] - return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size - - -def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128, - is_fp32_out: bool = False, is_wgrad: bool = False) -> Tuple[int, int, int]: - assert block_k == 128 - - # Try swizzle first, as it does not waste shared memory - swizzle_mode = get_swizzle_mode(block_n) - block_n_padding = get_block_n_padding_for_smem_d( - block_n) if swizzle_mode == 0 else 0 - - # NOTES: `scales_b` in a total manner or per-stage manner - smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2) - smem_a_per_stage = block_m * block_k - smem_scales_a_per_stage = block_m * 4 - smem_b_per_stage = block_n * block_k - smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0 - smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0 - smem_barrier = num_stages * 8 * 2 - - smem_size = 0 - smem_size += smem_d - smem_size += num_stages * smem_a_per_stage - smem_size += num_stages * smem_scales_a_per_stage - smem_size += num_stages * smem_b_per_stage - smem_size += num_stages * smem_scales_b_per_stage - smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8 - smem_size += smem_barrier - - # Swizzle and padding are not compatible - assert int(swizzle_mode > 0) + int(block_n_padding > 0) <= 1 - - return smem_size, swizzle_mode, block_n_padding - - -@lru_cache(maxsize=None) -def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, - is_grouped_contiguous: bool = False, is_grouped_masked: bool = False, - is_fp32_out: bool = False, is_wgrad: bool = False) -> \ - Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]: - if not is_grouped_contiguous: - block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ()) - else: - block_ms = (get_m_alignment_for_contiguous_layout(), ) - block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, )) - - # Avoid bank conflicts for FP32 output - if is_fp32_out: - block_ns = [x for x in block_ns if x % 16 == 8] - - fix_wave_saturate = lambda x: num_sms if x == 0 else x - get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) - get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms) - - # Decide block sizes by waves - best_block_m, best_block_n = None, None - for block_m in block_ms: - # NOTES: the block sizes cannot be too large, so at least one dim less than 128 - for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns): - success = False - num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n) - if best_block_m is None or best_block_n is None: - success = True - elif num_waves < best_num_waves: - success = True - elif num_waves == best_num_waves: - # Check last wave utilization - util = get_last_wave_util(block_m, block_n) - best_util = get_last_wave_util(best_block_m, best_block_n) - success = util > best_util - if util == best_util: - # Case 1: same `block_m`, smaller `block_n` (wasted) - success |= block_m == best_block_m and block_n < best_block_n - # Case 2: same `block_n`, smaller `block_m` (wasted) - success |= block_n == best_block_n and block_m < best_block_m - # Case 3: different for both `block_m` and `block_n`, `block_n` larger is better - success |= block_m != best_block_m and block_n > best_block_n - best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n) - assert best_block_m is not None and best_block_n is not None - - # Always pick the longest one - # NOTES: for double B scales, the best number of stages may be reduced - best_num_stages, best_smem_config, sm90_capacity = None, None, 232448 - stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (8, 7, 6, 5, 4, 3, 2, 1))) - if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4: - # Unrolling both stages and `num_former_iters` will cause large code size - stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (4, 3, 2, 1))) - for num_stages in stage_candidates: - best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n, is_fp32_out=is_fp32_out, is_wgrad=is_wgrad) - if best_smem_config[0] <= sm90_capacity: - best_num_stages = num_stages - break - assert best_smem_config is not None - assert best_num_stages is not None - - # Decide the number of TMA multicasts and whether broadcast on A - best_tma_multicast_config = (1, True) - - # Try to multicast on the larger block side first - # NOTES: currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the N-direction to be even - is_multicast_legal = { - 'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked), - 'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked, - } - for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'): - if m >= 512 and is_multicast_legal[i]: - best_tma_multicast_config = (2, i == 'A') - break - - # Recompute the minimal number of SMs required - # NOTES: less L2 cache usage and less GPU frequency drop - num_waves = get_num_waves(best_block_m, best_block_n) - num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves) - num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0] - assert num_min_sms <= num_sms - - return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config - - -def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], - rhs: Tuple[torch.Tensor, torch.Tensor], - out: torch.Tensor) -> None: - """ - Perform a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. - - Requirements: - LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1. - The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 8. - RHS and RHS scaling factors are required to be transposed. - The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement, - this function will do a transposing with a set of slow PyTorch operations. - - Arguments: - lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`, - the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`. - rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`, - the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`. - out: the BF16 output tensor of shape `[m, n]`, representing the result. - """ - lhs, lhs_scales = lhs - rhs, rhs_scales = rhs - m, k = lhs.shape - n, k_ = rhs.shape - m_, n_ = out.shape - - # Type and shape checks - assert m == m_ and n == n_ and k == k_ - assert n > 0 and k > 0 - assert lhs_scales.shape == (m, ceil_div(k, 128)) - assert rhs_scales.shape == (ceil_div(n, 128), ceil_div(k, 128)) - assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 - assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 - assert out.dtype == torch.bfloat16 - assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1 - - # LHS scales must be transposed for TMA loads, but not for RHS scales - # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels - lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) - assert rhs_scales.is_contiguous() - - # Do nothing if `m` is zero - if m == 0: - return - - # K must be aligned to 128 - aligned_k = ceil_div(k, 128) * 128 - - # Auto-tuning with compilation - num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms) - block_k = 128 - num_tma_threads = 128 - num_math_threads_per_group = 128 - - tensor_map_a = make_2d_tma_a_desc(GemmType.Normal, lhs, m, k, lhs.stride(0), block_m, block_k, 1) - tensor_map_b = make_2d_tma_b_desc(GemmType.Normal, rhs, n, k, rhs.stride(0), block_n, block_k, 1) - tensor_map_d = make_2d_tma_d_desc(GemmType.Normal, out, m, n, out.stride(0), block_m, block_n, 1, smem_config[1]) - tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.Normal, lhs_scales, m, k, block_m, block_k, 1) - - kwargs = { - # Templated arguments - 'GEMM_TYPE': GemmType.Normal, - 'NUM_TMA_THREADS': num_tma_threads, - 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, - 'M': m, 'N': n, 'K': aligned_k, - 'NUM_GROUPS': 1, - 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, - 'SWIZZLE_D_MODE': smem_config[1], - 'BLOCK_N_PADDING': smem_config[2], - 'NUM_STAGES': num_stages, - 'NUM_TMA_MULTICAST': tma_multicast_config[0], - 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], - # Runtime arguments - 'SCALES_B': rhs_scales, - 'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device), - 'NUM_SMS': num_sms, - 'SMEM_SIZE': smem_config[0], - 'TENSOR_MAP_A': tensor_map_a, - 'TENSOR_MAP_B': tensor_map_b, - 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, - 'TENSOR_MAP_D': tensor_map_d, - 'STREAM': torch.cuda.current_stream().cuda_stream, - 'DEVICE_INDEX': out.device.index - } - - # Generate, build and run the kernel - code = FP8GemmRuntime.generate(kwargs) - runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) - runtime(**kwargs) diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py deleted file mode 100644 index ca2fc79ae..000000000 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ /dev/null @@ -1,205 +0,0 @@ -import torch -from typing import Tuple - -from ..jit import build -from .gemm import get_best_configs -from .runtime import ( - FP8GemmRuntime, GemmType, - make_2d_tma_a_desc, make_2d_tma_b_desc, - make_2d_tma_d_desc, make_2d_tma_scales_desc) -from .utils import ceil_div, get_col_major_tma_aligned_tensor, get_num_sms - - -def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor], - rhs: Tuple[torch.Tensor, torch.Tensor], - out: torch.Tensor, m_indices: torch.Tensor) -> None: - """ - Perform a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. - - Requirements: - LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. - RHS and RHS scaling factors are required to be transposed. - The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement, - this function will do a transposing with a set of slow PyTorch operations. - On the M axis, inputs are grouped into several batches, of which batch sizes aligned to - `get_m_alignment_for_contiguous_layout()` (128). - - Arguments: - lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`, - the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`. - rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`, - the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. - out: the BF16 output tensor of shape `[m_sum, n]`, representing the result. - m_indices: a tensor of shape `[m_sum]` with type `torch.int`. - `m_indices[i]` records the group which the i-th row of the LHS belongs to, - which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`. - Values of `m_indices` in every-m-alignment-block must also be the same. - """ - lhs, lhs_scales = lhs - rhs, rhs_scales = rhs - m, k = lhs.shape - num_groups, n, k_ = rhs.shape - m_, n_ = out.shape - m__ = m_indices.numel() - - # Type and shape checks - assert m == m_ == m__ and k == k_ and n == n_ - assert lhs_scales.shape == (m, ceil_div(k, 128)) - assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128)) - assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 - assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 - assert out.dtype == torch.bfloat16 - assert m_indices.dtype == torch.int32 - assert lhs.is_contiguous() and rhs.is_contiguous() - assert out.is_contiguous() and m_indices.is_contiguous() - - # LHS scales must be transposed for TMA load, but not for RHS scales - lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) - assert rhs_scales.is_contiguous() - - # Do nothing if `m` is zero - if m == 0: - return - - # Auto-tuning with compilation - num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( - m, n, k, 1, num_sms, is_grouped_contiguous=True) - block_k = 128 - num_tma_threads = 128 - num_math_threads_per_group = 128 - - tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedContiguous, lhs, m, k, k, block_m, block_k, num_groups) - tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedContiguous, rhs, n, k, k, block_n, block_k, num_groups) - tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedContiguous, out, m, n, n, block_m, block_n, num_groups, smem_config[1]) - tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups) - - kwargs = { - # Templated arguments - 'NUM_TMA_THREADS': num_tma_threads, - 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, - 'M': m, 'N': n, 'K': k, - 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, - 'SWIZZLE_D_MODE': smem_config[1], - 'BLOCK_N_PADDING': smem_config[2], - 'NUM_GROUPS': num_groups, - 'NUM_STAGES': num_stages, - 'NUM_TMA_MULTICAST': tma_multicast_config[0], - 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], - 'GEMM_TYPE': GemmType.GroupedContiguous, - # Runtime arguments - 'SCALES_B': rhs_scales, - 'GROUPED_LAYOUT': m_indices, - 'NUM_SMS': num_sms, - 'SMEM_SIZE': smem_config[0], - 'TENSOR_MAP_A': tensor_map_a, - 'TENSOR_MAP_B': tensor_map_b, - 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, - 'TENSOR_MAP_D': tensor_map_d, - 'STREAM': torch.cuda.current_stream().cuda_stream, - 'DEVICE_INDEX': out.device.index - } - - # Generate, build and run the kernel - code = FP8GemmRuntime.generate(kwargs) - runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) - runtime(**kwargs) - - -def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor], - rhs: Tuple[torch.Tensor, torch.Tensor], - out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None: - """ - Perform a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. - - Requirements: - LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. - RHS and RHS scaling factors are required to be transposed. - The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement, - this function will do a transposing with a set of slow PyTorch operations. - Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch - should be separately transposed. - - Arguments: - lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`, - the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`. - rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`. - The second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. - out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result. - masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute - in the i-th group. - expected_m: a value hint (which is a value on CPU) for the M expectation of each batch, - correctly setting this value may lead to better performance. - """ - lhs, lhs_scales = lhs - rhs, rhs_scales = rhs - num_groups, m, k = lhs.shape - num_groups_, n, k_ = rhs.shape - num_groups__, m_, n_ = out.shape - num_groups___ = masked_m.numel() - - # Type and shape checks - assert num_groups == num_groups_ == num_groups__ == num_groups___ - assert m == m_ and n == n_ and k == k_ - assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0 - assert lhs_scales.shape == (num_groups, m, ceil_div(k, 128)) - assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128)) - assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 - assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 - assert out.dtype == torch.bfloat16 - assert masked_m.dtype == torch.int32 - assert lhs.is_contiguous() and rhs.is_contiguous() - assert out.is_contiguous() and masked_m.is_contiguous() - - # LHS scales must be transposed for TMA load, but not for RHS scales - lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) - assert rhs_scales.is_contiguous() - - # Auto-tuning with compilation - num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( - expected_m, n, k, num_groups, num_sms, is_grouped_masked=True) - - # Extra checks for TMA store - if num_groups > 1 and m > block_m: - assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})' - - block_k = 128 - num_tma_threads = 128 - num_math_threads_per_group = 128 - - tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedMasked, lhs, m, k, k, block_m, block_k, num_groups) - tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedMasked, rhs, n, k, k, block_n, block_k, num_groups) - tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedMasked, out, m, n, n, block_m, block_n, num_groups, smem_config[1]) - tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups) - - kwargs = { - # Templated arguments - 'NUM_TMA_THREADS': num_tma_threads, - 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, - 'M': m, 'N': n, 'K': k, - 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, - 'SWIZZLE_D_MODE': smem_config[1], - 'BLOCK_N_PADDING': smem_config[2], - 'NUM_GROUPS': num_groups, - 'NUM_STAGES': num_stages, - 'NUM_TMA_MULTICAST': tma_multicast_config[0], - 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], - 'GEMM_TYPE': GemmType.GroupedMasked, - # Runtime arguments - 'SCALES_B': rhs_scales, - 'GROUPED_LAYOUT': masked_m, - 'NUM_SMS': num_sms, - 'SMEM_SIZE': smem_config[0], - 'TENSOR_MAP_A': tensor_map_a, - 'TENSOR_MAP_B': tensor_map_b, - 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, - 'TENSOR_MAP_D': tensor_map_d, - 'STREAM': torch.cuda.current_stream().cuda_stream, - 'DEVICE_INDEX': out.device.index - } - - # Generate, build and run the kernel - code = FP8GemmRuntime.generate(kwargs) - runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) - runtime(**kwargs) diff --git a/deep_gemm/jit_kernels/runtime.py b/deep_gemm/jit_kernels/runtime.py deleted file mode 100644 index e65e85aa8..000000000 --- a/deep_gemm/jit_kernels/runtime.py +++ /dev/null @@ -1,318 +0,0 @@ -import ctypes -import os -import enum -import torch -import cuda.bindings.driver as cbd -from typing import Any, Dict, Tuple - -from .utils import get_tma_aligned_size -from ..jit.runtime import Runtime - - -class GemmType(enum.Enum): - Normal = 0 - GroupedContiguous = 1 - GroupedMasked = 2 - - def __str__(self) -> str: - return { - 0: 'Normal', - 1: 'GroupedContiguous', - 2: 'GroupedMasked', - }[self.value] - - -tmap_type_map: Dict[Any, str] = { - torch.int8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, - torch.int16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, - torch.int32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32, - torch.int64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64, - torch.uint8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, - torch.uint16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, - torch.uint32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32, - torch.uint64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64, - torch.float32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32, - torch.float16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16, - torch.bfloat16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, - torch.float8_e4m3fn: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, - torch.float8_e4m3fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, - torch.float8_e5m2: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, - torch.float8_e5m2fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, -} - -swizzle_type_map = { - 0: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE, - 32: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B, - 64: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B, - 128: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B, -} - - -def get_num_math_warpgroups(block_m: int) -> int: - return 1 if block_m == 64 else 2 - - -def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int: - assert num_math_threads_per_group == 128, 'Only support 128 threads per math group' - return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads - - -def make_2d_tma_copy_desc(t: torch.Tensor, - gmem_dims: Tuple[cbd.cuuint64_t, cbd.cuuint64_t], gmem_outer_stride: cbd.cuuint64_t, - smem_dims: Tuple[cbd.cuuint32_t, cbd.cuuint32_t], - swizzle_type: cbd.CUtensorMapSwizzle) -> cbd.CUtensorMap: - tensor_dtype = tmap_type_map[t.dtype] - res, tensor_map = cbd.cuTensorMapEncodeTiled( - tensor_dtype, - 2, - t.data_ptr(), - gmem_dims, - (gmem_outer_stride,), - smem_dims, - (cbd.cuuint32_t(1), cbd.cuuint32_t(1)), - cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE, - swizzle_type, - cbd.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B, - cbd.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE, - ) - - if res != cbd.CUresult.CUDA_SUCCESS: - raise Exception(f'Failed to encode tensor map: {res}') - return tensor_map - - -def make_2d_tma_desc(t: torch.Tensor, - gmem_inner_dim: int, gmem_outer_dim: int, gmem_outer_stride: int, - smem_inner_dim: int, smem_outer_dim: int, - swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap: - gmem_dim = (cbd.cuuint64_t(gmem_inner_dim), cbd.cuuint64_t(gmem_outer_dim)) - smem_dim = (cbd.cuuint32_t(smem_inner_dim), cbd.cuuint32_t(smem_outer_dim)) - return make_2d_tma_copy_desc(t, gmem_dim, cbd.cuuint64_t(gmem_outer_stride * t.element_size()), smem_dim, swizzle_type) - - -def make_2d_tma_a_desc(gemm_type: GemmType, t: torch.Tensor, - shape_m: int, shape_k: int, m_stride: int, - block_m: int, block_k: int, - num_groups: int) -> cbd.CUtensorMap: - return make_2d_tma_desc(t, - shape_k, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride, - block_k, block_m) - - -def make_2d_tma_b_desc(gemm_type: GemmType, t: torch.Tensor, - shape_n: int, shape_k: int, n_stride: int, - block_n: int, block_k: int, - num_groups: int) -> cbd.CUtensorMap: - return make_2d_tma_desc(t, - shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), n_stride, - block_k, block_n) - - -def make_2d_tma_d_desc(gemm_type: GemmType, t: torch.Tensor, - shape_m: int, shape_n: int, m_stride: int, - block_m: int, block_n: int, - num_groups: int, - swizzle_mode: int) -> cbd.CUtensorMap: - # Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode` - # bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required - return make_2d_tma_desc(t, - shape_n, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride, - block_n if swizzle_mode == 0 else swizzle_mode // t.element_size(), block_m, - swizzle_type_map[swizzle_mode]) - - -def make_2d_tma_scales_desc(gemm_type: GemmType, t: torch.Tensor, - shape_mn: int, shape_k: int, - block_mn: int, block_k: int, - num_groups: int) -> cbd.CUtensorMap: - # Make TMA aligned to 16 bytes - shape_mn = get_tma_aligned_size(shape_mn, t.element_size()) - return make_2d_tma_desc(t, - shape_mn, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_mn, - block_mn, 1, - cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) - - -class FP8GemmRuntime(Runtime): - def __init__(self, path: str) -> None: - super().__init__(path) - - @staticmethod - def generate(kwargs: Dict[str, Any]) -> str: - code = f''' -#ifdef __CUDACC_RTC__ -#include -#else -#include -#include -#endif - -#include -#include - -#include - -using namespace deep_gemm; - -static void __instantiate_kernel() {{ - auto ptr = reinterpret_cast(&fp8_gemm_kernel< - {kwargs['N']}, - {kwargs['K']}, - {kwargs['BLOCK_M']}, - {kwargs['BLOCK_N']}, - {kwargs['BLOCK_K']}, - {kwargs['BLOCK_N_PADDING']}, - {kwargs['SWIZZLE_D_MODE']}, - {kwargs['NUM_GROUPS']}, - {kwargs['NUM_STAGES']}, - {kwargs['NUM_TMA_THREADS']}, - {kwargs['NUM_MATH_THREADS_PER_GROUP']}, - {kwargs['NUM_TMA_MULTICAST']}, - {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}, - GemmType::{kwargs['GEMM_TYPE']} - >); -}}; -''' - if int(os.getenv('DG_JIT_DEBUG', 0)): - print(f'Generated FP8 GEMM code:\n{code}') - return code - - # noinspection PyMethodOverriding - @staticmethod - def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: - num_tma_threads = 128 - num_math_threads_per_group = 128 - - result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0] - assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}' - - attr_val = cbd.CUlaunchAttributeValue() - attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST'] - attr_val.clusterDim.y = 1 - attr_val.clusterDim.z = 1 - attr = cbd.CUlaunchAttribute() - attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION - attr.value = attr_val - - config = cbd.CUlaunchConfig() - config.numAttrs = 1 - config.attrs = [attr] - config.gridDimX = kwargs['NUM_SMS'] - config.gridDimY = 1 - config.gridDimZ = 1 - config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M']) - config.blockDimY = 1 - config.blockDimZ = 1 - config.sharedMemBytes = kwargs['SMEM_SIZE'] - config.hStream = kwargs['STREAM'] - - arg_values = ( - kwargs['SCALES_B'].data_ptr(), - kwargs['GROUPED_LAYOUT'].data_ptr(), - kwargs['M'], - kwargs['TENSOR_MAP_A'], - kwargs['TENSOR_MAP_B'], - kwargs['TENSOR_MAP_SCALES_A'], - kwargs['TENSOR_MAP_D'], - ) - arg_types = ( - ctypes.c_void_p, - ctypes.c_void_p, - ctypes.c_uint32, - None, - None, - None, - None, - ) - return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0) - - -class FP8WGradGemmRuntime(Runtime): - def __init__(self, path: str) -> None: - super().__init__(path) - - @staticmethod - def generate(kwargs: Dict[str, Any]) -> str: - code = f''' -#ifdef __CUDACC_RTC__ -#include -#else -#include -#include -#endif - -#include -#include - -#include - -using namespace deep_gemm; - -static void __instantiate_kernel() {{ - auto ptr = reinterpret_cast(&fp8_wgrad_gemm_kernel< - {kwargs['M']}, - {kwargs['N']}, - {kwargs['BLOCK_M']}, - {kwargs['BLOCK_N']}, - {kwargs['BLOCK_K']}, - {kwargs['NUM_STAGES']}, - {kwargs['NUM_LAST_STAGES']}, - {kwargs['NUM_TMA_THREADS']}, - {kwargs['NUM_MATH_THREADS_PER_GROUP']}, - {kwargs['NUM_TMA_MULTICAST']}, - {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'} - >); -}}; -''' - if int(os.getenv('DG_JIT_DEBUG', 0)): - print(f'Generated FP8 WGrad GEMM code:\n{code}') - return code - - # noinspection PyMethodOverriding - @staticmethod - def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: - num_tma_threads = 128 - num_math_threads_per_group = 128 - - result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0] - assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}' - - attr_val = cbd.CUlaunchAttributeValue() - attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST'] - attr_val.clusterDim.y = 1 - attr_val.clusterDim.z = 1 - attr = cbd.CUlaunchAttribute() - attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION - attr.value = attr_val - - config = cbd.CUlaunchConfig() - config.numAttrs = 1 - config.attrs = [attr] - config.gridDimX = kwargs['NUM_SMS'] - config.gridDimY = 1 - config.gridDimZ = 1 - config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M']) - config.blockDimY = 1 - config.blockDimZ = 1 - config.sharedMemBytes = kwargs['SMEM_SIZE'] - config.hStream = kwargs['STREAM'] - - arg_values = ( - kwargs['K'], - kwargs['TENSOR_MAP_A'], - kwargs['TENSOR_MAP_B'], - kwargs['TENSOR_MAP_SCALES_A'], - kwargs['TENSOR_MAP_SCALES_B'], - kwargs['TENSOR_MAP_D'], - ) - arg_types = ( - ctypes.c_uint32, - None, - None, - None, - None, - None, - ) - return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0) diff --git a/deep_gemm/jit_kernels/utils.py b/deep_gemm/jit_kernels/utils.py deleted file mode 100644 index c6da56b0e..000000000 --- a/deep_gemm/jit_kernels/utils.py +++ /dev/null @@ -1,109 +0,0 @@ -import torch - -_num_sms = None - - -def set_num_sms(num_sms: int) -> None: - """ - Set the maximum SM count for all GEMM kernels to use. - - Arguments: - num_sms: the desired maximum SM count for all GEMM kernels to use. - """ - global _num_sms - assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count - _num_sms = num_sms - - -def get_num_sms() -> int: - """ - Get the current maximum limit of SM count for all GEMM kernels to use. - If the count is never specified, the function will return the number of device SMs. - - Returns: - Current maximum limit of SM count for all GEMM kernels to use. - """ - global _num_sms - if _num_sms is None: - _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count - return _num_sms - - -def ceil_div(x: int, y: int) -> int: - """ - Perform ceiling division of two integers. - - Args: - x: the dividend. - y: the divisor. - - Returns: - The result of the ceiling division. - """ - return (x + y - 1) // y - - -def get_m_alignment_for_contiguous_layout(): - """ - When we do a grouped GEMM in contiguous format, LHS are grouped into several batches along the M axis. - Since we deal with exactly one sub-matrix of RHS for each GEMM block, batch sizes above should align well - with GEMM block shape. - - Returns: - Group-level alignment requirement for grouped contiguous layout, which is always 128. - """ - return 128 - - -def get_tma_aligned_size(x: int, element_size: int) -> int: - """ - Global memory address of TMA must be 16-byte aligned. - Since we use column-major layout for the LHS scaling tensor, - the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes. - - Arguments: - x: original M-axis shape of the LHS scaling tensor. - element_size: element size of the LHS scaling tensor. - - Returns: - M-axis shape of the LHS scaling tensor after padding. - """ - tma_alignment_bytes = 16 - assert tma_alignment_bytes % element_size == 0 - alignment = tma_alignment_bytes // element_size - return ceil_div(x, alignment) * alignment - - -def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: - """ - Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary. - If the input tensor is already column-major layout and 16-byte aligned along the M axis - (thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing. - - Arguments: - x: usually the LHS scaling tensor in GEMM. - - Returns: - The LHS scaling tensor of TMA-aligned transposed format. - """ - # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA - assert x.dim() in (2, 3) - remove_dim = False - m, n = x.shape[-2], x.shape[-1] - aligned_m = get_tma_aligned_size(m, x.element_size()) - if x.dim() == 2: - if x.stride(0) == 1 and x.stride(1) == aligned_m: - return x - x, remove_dim = x.unsqueeze(0), True - - b = x.shape[0] - - # The last kernel gives a column-major TMA aligned layout - if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m: - return x.squeeze(0) if remove_dim else x - - # Normal layout requires transposing - aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) - aligned_x[:, :m, :] = x - aligned_x = aligned_x[:, :m, :] - return aligned_x.squeeze(0) if remove_dim else aligned_x diff --git a/deep_gemm/jit_kernels/wgrad_gemm.py b/deep_gemm/jit_kernels/wgrad_gemm.py deleted file mode 100644 index 00b8cd100..000000000 --- a/deep_gemm/jit_kernels/wgrad_gemm.py +++ /dev/null @@ -1,158 +0,0 @@ -import torch -from typing import List, Tuple - -from ..jit import build -from .runtime import ( - FP8WGradGemmRuntime, GemmType, - make_2d_tma_a_desc, make_2d_tma_b_desc, - make_2d_tma_d_desc, make_2d_tma_scales_desc) -from .gemm import get_best_configs -from .utils import ceil_div, get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size - - -def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], - rhs: Tuple[torch.Tensor, torch.Tensor], - out: torch.Tensor): - """ - Perform a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling. - Results will be accumulated into the output tensor. - - Requirements: - LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1. - The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 4. - RHS and RHS scaling factors are required to be transposed. - The LHS scaling and RHS scaling tensor require a TMA-aligned transposed format. - If your input does not match the requirement, this function will do a transposing with a set of slow PyTorch operations. - - Arguments: - lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`, - the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`. - rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`, - the second element is an FP32 1x128 scaling tensor for RHS of shape `[n, ⌈k / 128⌉]`. - out: the FP32 output tensor of shape `[m, n]`, which will be accumulated. - """ - lhs, lhs_scales = lhs - rhs, rhs_scales = rhs - m, k = lhs.shape - n, k_ = rhs.shape - m_, n_ = out.shape - - # Type and shape checks - assert m == m_ and n == n_ and k == k_ - assert n > 0 and m > 0 - assert lhs_scales.shape == (m, ceil_div(k, 128)) or lhs_scales.shape == (ceil_div(k, 128), m) - assert rhs_scales.shape == (n, ceil_div(k, 128)) or rhs_scales.shape == (ceil_div(k, 128), n) - assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 - assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 - assert out.dtype == torch.float - assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1 - - # LHS and RHS scales must be transposed for TMA load - # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels - def get_valid_scales(scales: torch.Tensor, mn: int): - if scales.shape == (ceil_div(k, 128), mn): - # For k-grouped GEMMs - scales = scales.permute(1, 0) - assert get_tma_aligned_size(mn, 4) == scales.stride(1) == mn - else: - scales = get_col_major_tma_aligned_tensor(scales) - return scales - - lhs_scales = get_valid_scales(lhs_scales, m) - rhs_scales = get_valid_scales(rhs_scales, n) - - # Do nothing if `k` is zero - if k == 0: - return - - # K must be aligned to 128 - aligned_k = ceil_div(k, 128) * 128 - - # Auto-tuning with compilation - num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( - m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True) - num_last_stages = ceil_div(k, 128) % num_stages - block_k = 128 - num_tma_threads = 128 - num_math_threads_per_group = 128 - - tensor_map_a = make_2d_tma_a_desc(GemmType.Normal, lhs, m, k, lhs.stride(0), block_m, block_k, 1) - tensor_map_b = make_2d_tma_b_desc(GemmType.Normal, rhs, n, k, rhs.stride(0), block_n, block_k, 1) - tensor_map_d = make_2d_tma_d_desc(GemmType.Normal, out, m, n, out.stride(0), block_m, block_n, 1, smem_config[1]) - tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.Normal, lhs_scales, m, k, block_m, block_k, 1) - tensor_map_scales_b = make_2d_tma_scales_desc(GemmType.Normal, rhs_scales, n, k, block_n, block_k, 1) - - kwargs = { - # Templated arguments - 'GEMM_TYPE': GemmType.Normal, - 'NUM_TMA_THREADS': num_tma_threads, - 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, - 'M': m, 'N': n, 'K': aligned_k, - 'NUM_GROUPS': 1, - 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, - 'NUM_STAGES': num_stages, - 'NUM_LAST_STAGES': num_last_stages, - 'NUM_TMA_MULTICAST': tma_multicast_config[0], - 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], - # Runtime arguments - 'NUM_SMS': num_sms, - 'SMEM_SIZE': smem_config[0], - 'TENSOR_MAP_A': tensor_map_a, - 'TENSOR_MAP_B': tensor_map_b, - 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, - 'TENSOR_MAP_SCALES_B': tensor_map_scales_b, - 'TENSOR_MAP_D': tensor_map_d, - 'STREAM': torch.cuda.current_stream().cuda_stream, - 'DEVICE_INDEX': out.device.index - } - - # Generate, build and run the kernel - code = FP8WGradGemmRuntime.generate(kwargs) - runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime, kwargs) - runtime(**kwargs) - - -def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], - rhs: Tuple[torch.Tensor, torch.Tensor], - out: torch.Tensor, - batch_sizes: List[int]): - """ - Perform a k-grouped weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling. - Results will be accumulated into the output tensor. - - Requirements: - This function handles multiple batches with varying k-dimensions, processing each batch sequentially. - Each batch's LHS, RHS, and output tensors must be contiguous. - The RHS and RHS scaling factors are required to be transposed. - The LHS scaling and RHS scaling tensors require a TMA-aligned transposed format. - - Arguments: - lhs: The first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data, - and the flattened shape is `[sum(m * k for k in batch_sizes)]`, where m is the number of rows. - The second element is an FP32 scaling tensor for LHS with shape `[⌈k / 128⌉ for k in batch_sizes), m]`, - representing the per-128-channel scaling factors. - rhs: The first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data, - and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows. - The second element is an FP32 scaling tensor for RHS with shape `[⌈k / 128⌉ for k in batch_sizes), n]`, - representing the per-128-channel scaling factors. - out: The FP32 output tensor of shape [num_batches, m, n], which will be accumulated. - batch_sizes: A list of integers specifying the k-dimension for each batch. - """ - lhs, lhs_scales = lhs[0].view(-1), lhs[1] - rhs, rhs_scales = rhs[0].view(-1), rhs[1] - num_batches, m, n = out.shape - - lhs_offset, rhs_offset, scales_offset = 0, 0, 0 - - for i in range(num_batches): - k = batch_sizes[i] - lhs_slice = lhs[lhs_offset:lhs_offset + m * k].view(m, k) - rhs_slice = rhs[rhs_offset:rhs_offset + n * k].view(n, k) - lhs_scales_slice = lhs_scales[scales_offset:scales_offset + ceil_div(k, 128)] - rhs_scales_slice = rhs_scales[scales_offset:scales_offset + ceil_div(k, 128)] - wgrad_gemm_fp8_fp8_fp32_nt((lhs_slice, lhs_scales_slice), (rhs_slice, rhs_scales_slice), out[i]) - - lhs_offset += m * k - rhs_offset += n * k - scales_offset += ceil_div(k, 128) diff --git a/deep_gemm/testing/__init__.py b/deep_gemm/testing/__init__.py new file mode 100644 index 000000000..2537dbf1a --- /dev/null +++ b/deep_gemm/testing/__init__.py @@ -0,0 +1,3 @@ +from . import bench, numeric +from .bench import * +from .numeric import * diff --git a/deep_gemm/utils.py b/deep_gemm/testing/bench.py similarity index 78% rename from deep_gemm/utils.py rename to deep_gemm/testing/bench.py index 55a9affaf..7e77866d9 100644 --- a/deep_gemm/utils.py +++ b/deep_gemm/testing/bench.py @@ -1,8 +1,6 @@ import os import sys -import time import torch -import torch.distributed as dist def bench(fn, num_warmups: int = 5, num_tests: int = 10, @@ -31,7 +29,7 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10, end_event.record() torch.cuda.synchronize() - return start_event.elapsed_time(end_event) / num_tests + return start_event.elapsed_time(end_event) / num_tests / 1e3 class empty_suppress: @@ -77,8 +75,9 @@ def __exit__(self, *_): self.errnull_file.close() -def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False, - trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True, +def bench_kineto(fn, kernel_names, num_tests: int = 30, + suppress_kineto_output: bool = False, + trace_path: str = None, flush_l2: bool = True, with_multiple_kernels: bool = False): # Conflict with Nsight Systems using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0)) @@ -96,12 +95,6 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress() with profiler: for i in range(2): - # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead - if barrier_comm_profiling: - lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') - rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') - lhs @ rhs - dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda')) for _ in range(num_tests): if flush_l2: torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() @@ -116,7 +109,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: # Parse the profiling table assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) - is_tupled = isinstance(kernel_names, tuple) + is_tuple = isinstance(kernel_names, tuple) prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names assert all([isinstance(name, str) for name in kernel_names]) @@ -145,21 +138,4 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: break kernel_times.append(total_time / total_num) - return tuple(kernel_times) if is_tupled else kernel_times[0] - - -def calc_diff(x, y): - x, y = x.double(), y.double() - denominator = (x * x + y * y).sum() - sim = 2 * (x * y).sum() / denominator - return 1 - sim - - -def count_bytes(tensors): - total = 0 - for t in tensors: - if isinstance(t, tuple): - total += count_bytes(t) - else: - total += t.numel() * t.element_size() - return total + return tuple(kernel_times) if is_tuple else kernel_times[0] diff --git a/deep_gemm/testing/numeric.py b/deep_gemm/testing/numeric.py new file mode 100644 index 000000000..d06a03b9b --- /dev/null +++ b/deep_gemm/testing/numeric.py @@ -0,0 +1,19 @@ +import torch +from typing import Iterable + + +def calc_diff(x: torch.Tensor, y: torch.Tensor): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def count_bytes(*tensors): + total = 0 + for t in tensors: + if isinstance(t, (tuple, list)): + total += count_bytes(*t) + elif t is not None: + total += t.numel() * t.element_size() + return total diff --git a/deep_gemm/utils/__init__.py b/deep_gemm/utils/__init__.py new file mode 100644 index 000000000..e8f859a20 --- /dev/null +++ b/deep_gemm/utils/__init__.py @@ -0,0 +1,3 @@ +from . import math, layout +from .layout import * +from .math import * diff --git a/deep_gemm/utils/layout.py b/deep_gemm/utils/layout.py new file mode 100644 index 000000000..ac8c070b1 --- /dev/null +++ b/deep_gemm/utils/layout.py @@ -0,0 +1,11 @@ +from deep_gemm_cpp import ( + get_tma_aligned_size, + get_mk_alignment_for_contiguous_layout, + get_mn_major_tma_aligned_tensor, + get_mn_major_tma_aligned_packed_ue8m0_tensor, + get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor +) + +# Some alias +get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout +get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout diff --git a/deep_gemm/utils/math.py b/deep_gemm/utils/math.py new file mode 100644 index 000000000..eeef37739 --- /dev/null +++ b/deep_gemm/utils/math.py @@ -0,0 +1,66 @@ +import torch +from typing import Tuple + + +def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +def align(x: int, y: int) -> int: + return ceil_div(x, y) * y + + +def ceil_to_ue8m0(x: torch.Tensor): + assert x.view(-1).amax().item() > 0 + return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + + +def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), sf + + +def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(0) % 128 == 0 + m, n = x.shape + x_view = x.view(-1, 128, n) + x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf + + +def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros((align(m, 128), align(n, 128)), dtype=x.dtype, device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(x_view.size(0), x_view.size(2)) + + +def per_tensor_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + x_amax = x.abs().float().amax().clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled, sf + + +def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) + x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled, sf.squeeze() \ No newline at end of file diff --git a/develop.sh b/develop.sh new file mode 100755 index 000000000..3a71e2493 --- /dev/null +++ b/develop.sh @@ -0,0 +1,25 @@ +# Change current directory into project root +original_dir=$(pwd) +script_dir=$(realpath "$(dirname "$0")") +cd "$script_dir" + +# Link CUTLASS includes +ln -sf $script_dir/third-party/cutlass/include/cutlass deep_gemm/include +ln -sf $script_dir/third-party/cutlass/include/cute deep_gemm/include + +# Remove old dist file, build files, and build +rm -rf build dist +rm -rf *.egg-info +python setup.py build + +# Find the .so file in build directory and create symlink in current directory +so_file=$(find build -name "*.so" -type f | head -n 1) +if [ -n "$so_file" ]; then + ln -sf "$so_file" . +else + echo "Error: No SO file found in build directory" >&2 + exit 1 +fi + +# Open users' original directory +cd "$original_dir" diff --git a/figures/design.png b/figures/design.png deleted file mode 100644 index b3761d60e..000000000 Binary files a/figures/design.png and /dev/null differ diff --git a/indexing/main.cu b/indexing/main.cu deleted file mode 100644 index 5b15256ad..000000000 --- a/indexing/main.cu +++ /dev/null @@ -1,8 +0,0 @@ -#include "deep_gemm/fp8_gemm.cuh" -#include "deep_gemm/fp8_wgrad_gemm.cuh" - -using namespace deep_gemm; - -int main() { - return 0; -} diff --git a/install.sh b/install.sh new file mode 100755 index 000000000..0cfed3ff3 --- /dev/null +++ b/install.sh @@ -0,0 +1,15 @@ +pip uninstall -y deep_gemm + +# Change current directory into project root +original_dir=$(pwd) +script_dir=$(realpath "$(dirname "$0")") +cd "$script_dir" + +# Remove old dist file, build files, and install +rm -rf build dist +rm -rf *.egg-info +python setup.py bdist_wheel +pip install dist/*.whl + +# Open users' original directory +cd "$original_dir" diff --git a/setup.py b/setup.py index b39efd03f..8ececfcd2 100644 --- a/setup.py +++ b/setup.py @@ -2,34 +2,34 @@ import setuptools import shutil import subprocess +import torch +from setuptools import find_packages from setuptools.command.build_py import build_py -from setuptools.command.develop import develop +from torch.utils.cpp_extension import CUDAExtension, CUDA_HOME current_dir = os.path.dirname(os.path.realpath(__file__)) -jit_include_dirs = ('deep_gemm/include/deep_gemm', ) -third_party_include_dirs = ( +cxx_flags = ['-std=c++17', '-O3', '-fPIC', '-Wno-psabi', '-Wno-deprecated-declarations', + f'-D_GLIBCXX_USE_CXX11_ABI={int(torch.compiled_with_cxx11_abi())}'] +sources = ['csrc/python_api.cpp'] +build_include_dirs = [ + f'{CUDA_HOME}/include', + 'deep_gemm/include', + 'third-party/cutlass/include', + 'third-party/fmt/include', +] +build_libraries = ['cuda', 'cudart', 'nvrtc'] +build_library_dirs = [ + f'{CUDA_HOME}/lib64', + f'{CUDA_HOME}/lib64/stubs' +] +third_party_include_dirs = [ 'third-party/cutlass/include/cute', 'third-party/cutlass/include/cutlass', -) +] - -class PostDevelopCommand(develop): - def run(self): - develop.run(self) - self.make_jit_include_symlinks() - - @staticmethod - def make_jit_include_symlinks(): - # Make symbolic links of third-party include directories - for d in third_party_include_dirs: - dirname = d.split('/')[-1] - src_dir = f'{current_dir}/{d}' - dst_dir = f'{current_dir}/deep_gemm/include/{dirname}' - assert os.path.exists(src_dir) - if os.path.exists(dst_dir): - assert os.path.islink(dst_dir) - os.unlink(dst_dir) - os.symlink(src_dir, dst_dir, target_is_directory=True) +# Use driver API for older CUDA compatibility +if int(os.environ.get('DG_JIT_USE_DRIVER_API', '0')): + cxx_flags.append('-DDG_JIT_USE_DRIVER_API') class CustomBuildPy(build_py): @@ -37,9 +37,21 @@ def run(self): # First, prepare the include directories self.prepare_includes() - # Then run the regular build + # Second, make clusters' cache setting default into `envs.py` + self.generate_default_envs() + + # Finally, run the regular build build_py.run(self) + def generate_default_envs(self): + code = '# Pre-installed environment variables\n' + code += 'persistent_envs = dict()\n' + for name in ('DG_JIT_CACHE_DIR', 'DG_JIT_PRINT_COMPILER_COMMAND', 'DG_JIT_CPP_STANDARD'): + code += f"persistent_envs['{name}'] = '{os.environ[name]}'\n" if name in os.environ else '' + + with open(os.path.join(self.build_lib, 'deep_gemm', 'envs.py'), 'w') as f: + f.write(code) + def prepare_includes(self): # Create temporary build directory instead of modifying package directory build_include_dir = os.path.join(self.build_lib, 'deep_gemm/include') @@ -67,19 +79,28 @@ def prepare_includes(self): except: revision = '' + # noinspection PyTypeChecker setuptools.setup( name='deep_gemm', - version='1.0.0' + revision, - packages=['deep_gemm', 'deep_gemm/jit', 'deep_gemm/jit_kernels'], + version='2.0.0' + revision, + packages=find_packages('.'), package_data={ 'deep_gemm': [ - 'include/deep_gemm/*', + 'include/deep_gemm/**/*', 'include/cute/**/*', 'include/cutlass/**/*', ] }, + ext_modules=[ + CUDAExtension(name='deep_gemm_cpp', + sources=sources, + include_dirs=build_include_dirs, + libraries=build_libraries, + library_dirs=build_library_dirs, + extra_compile_args=cxx_flags) + ], + zip_safe=False, cmdclass={ - 'develop': PostDevelopCommand, 'build_py': CustomBuildPy, }, ) diff --git a/tests/generators.py b/tests/generators.py new file mode 100644 index 000000000..dae223f10 --- /dev/null +++ b/tests/generators.py @@ -0,0 +1,305 @@ +import enum +import random +import torch +from typing import Generator, Tuple, List + +from deep_gemm.utils import ( + align, ceil_div, + per_token_cast_to_fp8, per_channel_cast_to_fp8, per_block_cast_to_fp8, per_tensor_cast_to_fp8, + get_mk_alignment_for_contiguous_layout +) + + +class KernelType(enum.Enum): + # For SM100 GEMMs + Kernel1D1D = 0 + Kernel1D2D = 1 + KernelNoSF = 2 + + def is_1d1d(self): + return self.value == 0 + + def is_1d2d(self): + return self.value == 1 + + def is_nosf(self): + return self.value == 2 + + +class MajorTypeAB(enum.Enum): + KMajor = 0 + MNMajor = 1 + + def is_k_major(self): + return self.value == 0 + + def is_mn_major(self): + return self.value == 1 + + +def get_arch_major() -> int: + major, minor = torch.cuda.get_device_capability() + return major + + +def get_ue8m0_usage(kernel_type: KernelType) -> bool: + if get_arch_major() == 9: + return False + return kernel_type.is_1d1d() + + +def get_kernel_types(use_bf16: bool = False) -> tuple: + if use_bf16: + return (KernelType.KernelNoSF, ) + return (KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, KernelType.Kernel1D2D) + + +def get_out_dtype() -> tuple: + return (torch.bfloat16, ) if get_arch_major() == 9 else (torch.bfloat16, torch.float) + + +def get_major_ab(freeze_a: bool) -> tuple: + # TODO: test other major-ness for SM90 BF16 GEMMs + if get_arch_major() == 9: + return ((MajorTypeAB.KMajor, MajorTypeAB.KMajor), ) + if freeze_a: + return (MajorTypeAB.KMajor, MajorTypeAB.KMajor), (MajorTypeAB.KMajor, MajorTypeAB.MNMajor) + return (MajorTypeAB.KMajor, MajorTypeAB.KMajor), (MajorTypeAB.KMajor, MajorTypeAB.MNMajor), \ + (MajorTypeAB.MNMajor, MajorTypeAB.KMajor), (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor) + + +def enumerate_normal(use_bf16: bool = False) -> Generator: + for kernel_type in get_kernel_types(use_bf16): + for m in (128, 4096): + for n, k in [(2112, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]: + for major_a, major_b in get_major_ab(False): + for out_dtype in get_out_dtype(): + for accumulate in (False, ) if out_dtype == torch.bfloat16 or kernel_type.is_1d2d() else (False, True): + yield kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype + + +def enumerate_m_grouped_contiguous(use_bf16: bool = False) -> Generator: + for kernel_type in get_kernel_types(use_bf16): + for num_groups, expected_m_per_group, n, k in ((4, 8192, 4096, 7168), (4, 8192, 7168, 2048), (8, 4096, 4096, 7168), (8, 4096, 7168, 2048)): + for major_a, major_b in get_major_ab(True): + yield kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b + + +def enumerate_m_grouped_masked() -> Generator: + max_m = 4096 + for kernel_type in get_kernel_types(): + for num_groups, m in ((1, 1024), (2, 512), (4, 256)): + for n, k in ((4096, 7168), (7168, 2048), ): + yield kernel_type, num_groups, max_m, m, n, k + + +def enumerate_k_grouped_contiguous(): + # TODO: support SM90 kernels + if get_arch_major() == 9: + return [] + + # Must with FP32 accumulation and 1D1D kernels + for num_groups, m, n, expected_k_per_group in (( 4, 4096, 7168, 8192), ( 4, 7168, 2048, 8192), # EP64 + ( 8, 4096, 7168, 4096), ( 8, 7168, 2048, 4096), # EP32 + (16, 4096, 7168, 2048), (16, 7168, 2048, 2048)): # EP16 + ks = [align(int(expected_k_per_group * random.uniform(0.7, 1.3)), get_mk_alignment_for_contiguous_layout()) for _ in range(num_groups)] + yield num_groups, m, n, ks, expected_k_per_group + + +def enumerate_sf_layout(): + for use_ue8m0 in (False, True): + for with_transpose in (True, False): + for mn in (4096, 4097, 8192): + for k in (128, 7168, 7296): + for num_groups in (1, 2, 4): + yield mn, k, with_transpose, use_ue8m0, num_groups + + +def enumerate_k_grouped_sf_layout(): + alignment = get_mk_alignment_for_contiguous_layout() + assert alignment % 128 == 0 + for mn in (4096, 7168): + for num_groups, avg_k in ((16, 2048), (8, 4096), (72, 384), (128, 256)): + ks = [align(int(random.uniform(0.7, 1.3) * avg_k), alignment) for _ in range(num_groups)] + yield mn, ks, num_groups + + +def enumerate_transpose(): + for mn in (64, 4096, 16384): + for delta in (0, 101, 202, 303): + for k in (128, 1024, 4096, 9984, 16384): + yield mn + delta, k + + +def generate_normal(m: int, n: int, k: int, + major_a: MajorTypeAB, major_b: MajorTypeAB, + accumulate: bool, out_dtype: torch.dtype, + use_ue8m0: bool = False, use_bf16: bool = False): + a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) + d = torch.randn((m, n), device='cuda', dtype=out_dtype) * 32 if accumulate else \ + torch.empty((m, n), device='cuda', dtype=out_dtype) + c = d if accumulate else None + ref_d = (a.float() @ b.float().t() + (c if accumulate else 0)).to(out_dtype) + + if use_bf16: + a = a if major_a.is_k_major() else a.T.contiguous().T + b = b if major_b.is_k_major() else b.T.contiguous().T + return a, b, c, d, ref_d + + a_fp8, b_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0), per_block_cast_to_fp8(b, use_ue8m0=use_ue8m0) + a_fp8 = a_fp8 if major_a.is_k_major() else (a_fp8[0].T.contiguous().T, a_fp8[1]) + b_fp8 = b_fp8 if major_b.is_k_major() else (b_fp8[0].T.contiguous().T, b_fp8[1]) + return a_fp8, b_fp8, c, d, ref_d + + +def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: int, k: int, + major_a: MajorTypeAB, major_b: MajorTypeAB, + use_ue8m0: bool = False, use_bf16: bool = False): + actual_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)] + aligned_ms = [align(actual_m, get_mk_alignment_for_contiguous_layout()) for actual_m in actual_ms] + m = sum(aligned_ms) + + a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) + m_indices = torch.empty(m, device='cuda', dtype=torch.int32) + d = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + ref_d = torch.randn((m, n), device='cuda', dtype=torch.bfloat16) + + start = 0 + for i, (actual_m, aligned_m) in enumerate(zip(actual_ms, aligned_ms)): + actual_end = start + actual_m + aligned_end = start + aligned_m + m_indices[start:actual_end] = i + m_indices[actual_end:aligned_end] = -1 + ref_d[start:aligned_end] = a[start:aligned_end] @ b[i].t() + start = aligned_end + ref_d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_d), ref_d) + + if use_bf16: + b = b if major_b.is_k_major() else b.mT.contiguous().mT + return m, a, b, m_indices, d, ref_d + + assert major_a.is_k_major() + a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0) + b_fp8 = (torch.empty_like(b, dtype=torch.float8_e4m3fn), + torch.empty((num_groups, ceil_div(n, 128), ceil_div(k, 128)), device='cuda', dtype=torch.float)) + for i in range(num_groups): + b_fp8[0][i], b_fp8[1][i] = per_block_cast_to_fp8(b[i], use_ue8m0=use_ue8m0) + b_fp8 = b_fp8 if major_b.is_k_major() else (b_fp8[0].mT.contiguous().mT, b_fp8[1]) + return m, a_fp8, b_fp8, m_indices, d, ref_d + + +def generate_m_grouped_contiguous_per_tensor(num_groups: int, expected_m_per_group: int, n: int, k: int, + major_a: MajorTypeAB, major_b: MajorTypeAB, + use_ue8m0: bool = False, use_bf16: bool = False): + actual_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)] + aligned_ms = [align(actual_m, get_mk_alignment_for_contiguous_layout()) for actual_m in actual_ms] + m = sum(aligned_ms) + + a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) + m_indices = torch.empty(m, device='cuda', dtype=torch.int32) + d = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + ref_d = torch.randn((m, n), device='cuda', dtype=torch.bfloat16) + + start = 0 + for i, (actual_m, aligned_m) in enumerate(zip(actual_ms, aligned_ms)): + actual_end = start + actual_m + aligned_end = start + aligned_m + m_indices[start:actual_end] = i + m_indices[actual_end:aligned_end] = -1 + ref_d[start:aligned_end] = a[start:aligned_end] @ b[i].t() + start = aligned_end + ref_d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_d), ref_d) + + if use_bf16: + b = b if major_b.is_k_major() else b.mT.contiguous().mT + return m, a, b, m_indices, d, ref_d + + assert major_a.is_k_major() + a_fp8 = per_tensor_cast_to_fp8(a, use_ue8m0=use_ue8m0) + b_fp8 = (torch.empty_like(b, dtype=torch.float8_e4m3fn), + torch.empty((num_groups,), device='cuda', dtype=torch.float)) + for i in range(num_groups): + b_fp8[0][i], b_fp8[1][i] = per_tensor_cast_to_fp8(b[i], use_ue8m0=use_ue8m0) + b_fp8[1][i] = b_fp8[1][i] * a_fp8[1] + a_fp8 = (a_fp8[0], torch.tensor(1.0, device='cuda', dtype=torch.float).view(1,)) + + b_fp8 = b_fp8 if major_b.is_k_major() else (b_fp8[0].mT.contiguous().mT, b_fp8[1]) + return m, a_fp8, b_fp8, m_indices, d, ref_d + + +def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int, + use_ue8m0: bool = False, use_bf16: bool = False): + a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16) + b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) + d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16) + ref_d = torch.einsum('gmk,gnk->gmn', a, b) + + masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) + for j in range(num_groups): + masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3)) + assert masked_m.amax().item() <= max_m + + if use_bf16: + return a, b, masked_m, d, ref_d + + a_fp8 = (torch.empty_like(a, dtype=torch.float8_e4m3fn), torch.empty((num_groups, max_m, ceil_div(k, 128)), device='cuda', dtype=torch.float)) + b_fp8 = (torch.empty_like(b, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), ceil_div(k, 128)), device='cuda', dtype=torch.float)) + for i in range(num_groups): + a_fp8[0][i], a_fp8[1][i] = per_token_cast_to_fp8(a[i], use_ue8m0=use_ue8m0) + b_fp8[0][i], b_fp8[1][i] = per_block_cast_to_fp8(b[i], use_ue8m0=use_ue8m0) + + return a_fp8, b_fp8, masked_m, d, ref_d + + +def generate_m_grouped_masked_per_tensor(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int, + use_ue8m0: bool = False, use_bf16: bool = False): + a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16) + b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) + d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16) + ref_d = torch.einsum('gmk,gnk->gmn', a, b) + + masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) + for j in range(num_groups): + masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3)) + assert masked_m.amax().item() <= max_m + + if use_bf16: + return a, b, masked_m, d, ref_d + + a_fp8 = [torch.empty_like(a, dtype=torch.float8_e4m3fn), torch.empty((1,), device='cuda', dtype=torch.float)] + a_view = a.view(-1) + a_amax = a_view.abs().float().amax().clamp(1e-4) + a_fp8[0] = (a * (448.0 / a_amax)).to(torch.float8_e4m3fn) + a_fp8[1] = (a_amax / 448.0).view(1,) + + b_fp8 = [torch.empty_like(b, dtype=torch.float8_e4m3fn), torch.empty((num_groups,), device='cuda', dtype=torch.float)] + for i in range(num_groups): + b_fp8[0][i], b_fp8[1][i] = per_tensor_cast_to_fp8(b[i], use_ue8m0=use_ue8m0) + b_fp8[1][i] = a_fp8[1] * b_fp8[1][i] + + a_fp8 = (a_fp8[0], torch.tensor(1.0, device='cuda', dtype=torch.float).view(1,)) + return a_fp8, b_fp8, masked_m, d, ref_d + + +def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, ks: List[int], use_ue8m0: bool): + assert get_mk_alignment_for_contiguous_layout() % 128 == 0 + k = sum(ks) + + a = torch.randn((k, m), device='cuda', dtype=torch.bfloat16) + b = torch.randn((k, n), device='cuda', dtype=torch.bfloat16) + c = torch.randn((num_groups, m, n), device='cuda', dtype=torch.float) * 32 + d = c + ref_d = torch.empty_like(c) + + start = 0 + for i, group_k in enumerate(ks): + end = start + group_k + ref_d[i] = c[i] + (a[start:end].T @ b[start:end]) + start = end + + a_fp8 = per_channel_cast_to_fp8(a, use_ue8m0=use_ue8m0) + b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0) + return k, a_fp8, b_fp8, c, d, ref_d diff --git a/tests/test_bf16.py b/tests/test_bf16.py new file mode 100644 index 000000000..790f700af --- /dev/null +++ b/tests/test_bf16.py @@ -0,0 +1,125 @@ +import torch +import random + +import deep_gemm +from deep_gemm.testing import ( + bench_kineto, + calc_diff, count_bytes +) +from generators import ( + enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, generate_normal, + generate_m_grouped_contiguous, generate_m_grouped_masked +) + + +def test_gemm() -> None: + print('Testing GEMM:') + for _, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(use_bf16=True): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + out_opt = 'FP32' if out_dtype == torch.float else 'BF16' + acc_opt = f'acc={int(accumulate)}' + + for test_alias in (False, True): + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_bf16=True) + func_name = f'bf16_gemm_{major_opt.lower() if test_alias else "nt"}' + if test_alias: + a = a if major_a.is_k_major() else a.T + b = b if major_b.is_k_major() else b.T + assert a.is_contiguous() and b.is_contiguous() + getattr(deep_gemm, func_name)(a, b, d, c=c) + diff = calc_diff(d, ref_d) + assert diff < 0.0001, (f'{m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=}, ' + f'{diff:.5f}, alias={test_alias}') + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_bf16=True) + + cublas_t = 0 + t = bench_kineto(lambda: deep_gemm.bf16_gemm_nt(a, b, d, c=c), 'bf16_gemm', suppress_kineto_output=True) + if accumulate == 0 and out_dtype == torch.bfloat16: + # noinspection PyBroadException + try: + cublas_t = bench_kineto(lambda: a @ b.T, 'nvjet', suppress_kineto_output=True) + except Exception: + pass + print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, layout={major_opt}, {out_opt}, {acc_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | ' + f'{cublas_t / t:.2f}x cuBLAS') + print() + + +def test_m_grouped_gemm_contiguous() -> None: + print('Testing m-grouped contiguous GEMM:') + + for _, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(use_bf16=True): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + + for test_alias in (False, True): + m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_bf16=True) + func_name = f"m_grouped_bf16_gemm_{(major_opt.lower() if test_alias else 'nt')}_contiguous" + if test_alias: + assert major_a.is_k_major() + b = b if major_b.is_k_major() else b.mT + assert a[0].is_contiguous() and b[0].is_contiguous() + getattr(deep_gemm, func_name)(a, b, d, m_indices) + d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(d), d) + diff = calc_diff(d, ref_d) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}' + m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_bf16=True) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a, b, d, m_indices) + + t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, layout={major_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') + print() + + +def test_m_grouped_gemm_masked() -> None: + print('Testing m-grouped masked GEMM:') + + # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. + for _, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(): + # Test correctness + for i in range(10): + a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_bf16=True) + deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group) + for j in range(num_groups): + diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}' + + # Construct full cases + a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_bf16=True) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group) + + # Test performance with fixed shapes + valid_m = masked_m.sum().item() + t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') + print() + + +if __name__ == '__main__': + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + test_gemm() + test_m_grouped_gemm_contiguous() + test_m_grouped_gemm_masked() diff --git a/tests/test_core.py b/tests/test_core.py deleted file mode 100644 index 3b88539cc..000000000 --- a/tests/test_core.py +++ /dev/null @@ -1,312 +0,0 @@ -# PyTorch has its own NVRTC, which may have a lower version than the system -# So try to disable PyTorch's NVRTC, or import NVRTC before PyTorch -import cuda.bindings.nvrtc as nvrtc -print(f'NVRTC version: {nvrtc.nvrtcVersion()[1:]}') - -import random -import torch -from typing import List, Tuple - -import deep_gemm -from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor -from deep_gemm.jit_kernels.utils import get_m_alignment_for_contiguous_layout - - -def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - pad_size = (128 - (n % 128)) % 128 - x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x - x_view = x.view(m, -1, 128) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) - return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) - - -def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device) - x_padded[:m, :n] = x - x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) - x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) - x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) - - -def construct(m: int, k: int, n: int) -> \ - Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: - x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) - y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) - out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) - ref_out = x @ y.t() - - x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y) - # Transpose earlier so that the testing will not trigger transposing kernels - x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) - return x_fp8, y_fp8, out, ref_out - - -def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \ - Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: - alignment = get_m_alignment_for_contiguous_layout() - group_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)] - m = sum([ceil_div(x, alignment) * alignment for x in group_ms]) - - x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) - y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) - m_indices = torch.empty(m, device='cuda', dtype=torch.int32) - out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) - ref_out = torch.randn((m, n), device='cuda', dtype=torch.bfloat16) - - start = 0 - for i, group_m in enumerate(group_ms): - actual_end = start + group_m - aligned_end = start + ceil_div(group_m, alignment) * alignment - m_indices[start:actual_end] = i - m_indices[actual_end:aligned_end] = -1 - ref_out[start:aligned_end] = x[start:aligned_end] @ y[i].t() - start = aligned_end - ref_out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_out), ref_out) - - assert m % 4 == 0, f'TMA alignment error: {m}' - x_fp8 = per_token_cast_to_fp8(x) - y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float)) - for i in range(num_groups): - y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) - - return m, x_fp8, y_fp8, m_indices, out, ref_out - - -def construct_masked_grouped(num_groups: int, max_m: int, expected_m_per_group: int, k: int, n: int) -> \ - Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: - x = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16) - y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) - out = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16) - ref_out = torch.einsum('gmk,gnk->gmn', x, y) - - assert max_m % 4 == 0, f'TMA alignment error: {max_m}' - x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, max_m, k // 128), device='cuda', dtype=torch.float)) - y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float)) - for i in range(num_groups): - x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i]) - y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) - - # Transpose earlier so that the testing will not trigger transposing kernels - x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) - - # Construct mask - masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) - for j in range(num_groups): - masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3)) - assert masked_m.amax().item() <= max_m - return x_fp8, y_fp8, masked_m, out, ref_out - - -def construct_wgrad(m: int, k: int, n: int) -> \ - Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: - x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) - y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) - residual = torch.randn((m, n), device='cuda', dtype=torch.float) * 10 - out = residual.clone() - ref_out = residual + (x.float() @ y.float().t()) - - x_fp8 = per_token_cast_to_fp8(x) - y_fp8 = per_token_cast_to_fp8(y) - - # NOTES: please do inplace add on the `out` later - return x_fp8, y_fp8, residual, out, ref_out - - -def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \ - Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, List[int]]: - num_groups, total_k = len(k_sizes), sum(k_sizes) - - x_flat = torch.empty((m * total_k,), device='cuda', dtype=torch.bfloat16) - y_flat = torch.empty((n * total_k,), device='cuda', dtype=torch.bfloat16) - out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float) - ref_out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float) - - # Fill tensors with data and compute reference output - x_offset, y_offset = 0, 0 - for idx, k in enumerate(k_sizes): - x_chunk = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) - y_chunk = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) - - x_flat[x_offset:x_offset + m * k].copy_(x_chunk.flatten()) - y_flat[y_offset:y_offset + n * k].copy_(y_chunk.flatten()) - ref_out[idx] = x_chunk.float() @ y_chunk.float().t() - - x_offset += m * k - y_offset += n * k - - x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn) - y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn) - - total_scale_factors = sum(ceil_div(k, 128) for k in k_sizes) - x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float) - y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float) - - # Cast to FP8 and prepare scale factors - x_offset, y_offset, scale_offset = 0, 0, 0 - for k in k_sizes: - x_fp8_chunk, x_scale_chunk = per_token_cast_to_fp8(x_flat[x_offset:x_offset + m * k].view(m, k)) - y_fp8_chunk, y_scale_chunk = per_token_cast_to_fp8(y_flat[y_offset:y_offset + n * k].view(n, k)) - - x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten()) - y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten()) - - num_scales = ceil_div(k, 128) - x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T) - y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T) - - x_offset += m * k - y_offset += n * k - scale_offset += num_scales - - return (x_fp8_flat, x_scales), (y_fp8_flat, y_scales), out, ref_out, k_sizes - - -def test_gemm() -> None: - print('Testing GEMM:') - for m in (64, 128, 4096): - for k, n in [(576, 7168), (7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]: - x_fp8, y_fp8, out, ref_out = construct(m, k, n) - deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) - diff = calc_diff(out, ref_out) - assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' - - # noinspection PyShadowingNames - def test_func(): - deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) - - t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - print(f' > Perf (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | ' - f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, ' - f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s') - print() - - -def test_m_grouped_gemm_contiguous() -> None: - print('Testing grouped contiguous GEMM:') - - for num_groups, expected_m_per_group, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), - (8, 4096, 7168, 4096), (8, 4096, 2048, 7168), - (32, 256, 7168, 4096), (32, 256, 2048, 7168)): - # NOTES: we should mask the unfilled part before calculating difference - m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n) - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) - out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(out), out) - diff = calc_diff(out, ref_out) - assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' - - # noinspection PyShadowingNames - def test_func(): - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) - - t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - valid_m = (m_indices != -1).sum().item() - print(f' > Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' - f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, ' - f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s') - print() - - -def test_m_grouped_gemm_masked() -> None: - print('Testing grouped masked GEMM:') - - for num_groups, expected_m_per_group in ((1, 1024), (2, 512), (4, 256)): - for k, n in ((7168, 4096), (2048, 7168), ): - # Test correctness - for i in range(10): - x_fp8, y_fp8, masked_m, out, ref_out = construct_masked_grouped(num_groups, 4096, expected_m_per_group, k, n) - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m_per_group) - for j in range(num_groups): - diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()]) - assert diff < 0.001, f'{expected_m_per_group=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}' - - # noinspection PyShadowingNames - def test_func(): - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m_per_group) - - # Test performance with fixed shapes - # noinspection PyUnboundLocalVariable - valid_m = masked_m.sum().item() - t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' - f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, ' - f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s') - print() - - -def test_wgrad_gemm(): - print('Testing weight gradient GEMM:') - - for k in (4096, 8192): - for m, n in ((7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)): - # Test correctness - x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n) - deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out) - diff = calc_diff(out, ref_out) - assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' - - # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2) - x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n) - - # noinspection PyShadowingNames - def test_func(): - deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out) - - t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True) - print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | ' - f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, ' - f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s') - print() - - -def test_k_grouped_wgrad_gemm(): - print('Testing grouped weight gradient GEMM:') - - for num_groups, base_k in ((4, 4096), (4, 8192), (8, 4096)): - for m, n in ((7168, 4096), (2048, 7168)): - # Vary k sizes around base_k - k_sizes = [base_k + random.randint(-1, 1) * 128 for _ in range(num_groups - 1)] - k_sizes.append(base_k * num_groups - sum(k_sizes)) - - # Test correctness - x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes) - deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes) - - for idx in range(num_groups): - diff = calc_diff(out[idx], ref_out[idx]) - assert diff < 0.001, f'{num_groups=}, {m=}, {n=}, k={k_sizes[idx]}, batch={idx}, {diff:.5f}' - - # Construct new tensors to avoid L2 cache acceleration - x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes) - total_k = sum(k_sizes) - - def test_func(): - deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes) - - t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True, with_multiple_kernels=True) * num_groups - print(f' > Performance ({num_groups=}, m={m:5}, n={n:5}, avg_k={total_k//num_groups:5}): {t * 1e6:4.0f} us | ' - f'throughput: {2 * num_groups * m * n * (total_k/num_groups) / t / 1e12:4.0f} TFLOPS, ' - f'{(m * total_k + n * total_k + num_groups * m * n * 2) / 1e9 / t:4.0f} GB/s') - print() - - -if __name__ == '__main__': - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - torch.manual_seed(0) - random.seed(0) - - print('Library path:') - print(f' > {deep_gemm.__path__}\n') - - test_gemm() - test_m_grouped_gemm_contiguous() - test_m_grouped_gemm_masked() - - test_wgrad_gemm() - test_k_grouped_wgrad_gemm() diff --git a/tests/test_fp8.py b/tests/test_fp8.py new file mode 100644 index 000000000..1bd8d4637 --- /dev/null +++ b/tests/test_fp8.py @@ -0,0 +1,246 @@ +import copy +import random +import time +import torch + +import deep_gemm +from deep_gemm.testing import ( + bench, bench_kineto, + calc_diff, count_bytes +) + +from generators import ( + KernelType, get_ue8m0_usage, + enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous, + generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous, + generate_m_grouped_contiguous_per_tensor, generate_m_grouped_masked_per_tensor +) + + +def test_gemm() -> None: + print('Testing GEMM:') + for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + out_opt = 'FP32' if out_dtype == torch.float else 'BF16' + acc_opt = f'acc={int(accumulate)}' + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + + for test_alias in (False, True): + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_ue8m0=use_ue8m0) + func_name = f'fp8_gemm_{major_opt.lower() if test_alias else "nt"}' + if test_alias: + a = a if major_a.is_k_major() else (a[0].T, a[1].T) + b = b if major_b.is_k_major() else (b[0].T, b[1].T) + assert a[0].is_contiguous() and b[0].is_contiguous() + getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast) + diff = calc_diff(d, ref_d) + assert diff < 0.001, (f'{m=}, {n=}, {k=}, {kernel_opt}, {major_opt=}, {accumulate=}, {out_dtype=}, ' + f'{diff:.5f}, alias={test_alias}') + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_ue8m0=use_ue8m0) + + # Test launch overhead + launch_start_t = time.time_ns() + deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast) + launch_end_t = time.time_ns() + torch.cuda.synchronize() + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast) + + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): ' + f'launch {(launch_end_t - launch_start_t) / 1e3:4.0f} us | {t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s') + print() + + +def test_m_grouped_gemm_contiguous() -> None: + print('Testing m-grouped contiguous GEMM:') + + for kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + + for test_alias in (False, True): + m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_ue8m0=use_ue8m0) + func_name = f"m_grouped_fp8_gemm_{(major_opt.lower() if test_alias else 'nt')}_contiguous" + if test_alias: + assert major_a.is_k_major() + b = b if major_b.is_k_major() else (b[0].mT, b[1].mT) + assert a[0].is_contiguous() and b[0].is_contiguous() + getattr(deep_gemm, func_name)(a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast) + d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(d), d) + diff = calc_diff(d, ref_d) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}' + m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_ue8m0=use_ue8m0) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_fp8_gemm_nt_contiguous(a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast) + + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, {kernel_opt}, layout={major_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') + print() + + +def test_m_grouped_gemm_contiguous_per_tensor() -> None: + print('Testing m-grouped contiguous per tensor GEMM:') + + for kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + + for test_alias in (False, True): + m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous_per_tensor(num_groups, expected_m_per_group, n, k, major_a, major_b, use_ue8m0=use_ue8m0) + func_name = f"m_grouped_fp8_gemm_{(major_opt.lower() if test_alias else 'nt')}_contiguous_per_tensor" + if test_alias: + assert major_a.is_k_major() + b = b if major_b.is_k_major() else (b[0].mT, b[1].mT) + assert a[0].is_contiguous() and b[0].is_contiguous() + getattr(deep_gemm, func_name)(a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast) + d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(d), d) + diff = calc_diff(d, ref_d) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}' + m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous_per_tensor(num_groups, expected_m_per_group, n, k, major_a, major_b, use_ue8m0=use_ue8m0) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_fp8_gemm_nt_contiguous_per_tensor(a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast) + + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, {kernel_opt}, layout={major_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') + print() + + +def test_m_grouped_gemm_masked() -> None: + print('Testing m-grouped masked GEMM:') + + # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. + for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(): + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + + # Test correctness + for i in range(10): + a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + for j in range(num_groups): + diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) + assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' + + # Construct full cases + a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + + # Test performance with fixed shapes + valid_m = masked_m.sum().item() + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') + print() + + +def test_m_grouped_gemm_masked_per_tensor() -> None: + print('Testing m-grouped masked per tensor GEMM:') + + # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. + for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(): + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + + # Test correctness + for i in range(10): + a, b, masked_m, d, ref_d = generate_m_grouped_masked_per_tensor(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + deep_gemm.m_grouped_fp8_gemm_nt_masked_per_tensor(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + for j in range(num_groups): + diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) + assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' + + # Construct full cases + a, b, masked_m, d, ref_d = generate_m_grouped_masked_per_tensor(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_fp8_gemm_nt_masked_per_tensor(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + + # Test performance with fixed shapes + valid_m = masked_m.sum().item() + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') + print() + + +def test_k_grouped_gemm_contiguous() -> None: + print('Testing k-grouped contiguous GEMM:') + + for num_groups, m, n, ks, expected_k_per_group in enumerate_k_grouped_contiguous(): + use_ue8m0 = get_ue8m0_usage(KernelType.Kernel1D1D) + + for test_empty_groups in (False, True): + new_ks = copy.deepcopy(ks) + if test_empty_groups: + new_ks[random.randint(0, num_groups - 1)] = 0 + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, new_ks, use_ue8m0=use_ue8m0) + new_ks_tensor = torch.tensor(new_ks, dtype=torch.int, device='cuda') + deep_gemm.k_grouped_fp8_gemm_tn_contiguous(a, b, d, new_ks, new_ks_tensor, c=c) + diff = calc_diff(d, ref_d) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {i=}, {diff:.5f}' + + # Test performance + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, ks, use_ue8m0=use_ue8m0) + ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda') + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.k_grouped_fp8_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c=c) + + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s') + print() + + +if __name__ == '__main__': + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + test_gemm() + test_m_grouped_gemm_contiguous() + test_m_grouped_gemm_contiguous_per_tensor() + test_m_grouped_gemm_masked() + test_m_grouped_gemm_masked_per_tensor() + test_k_grouped_gemm_contiguous() diff --git a/tests/test_jit.py b/tests/test_jit.py deleted file mode 100644 index 26b7b36c3..000000000 --- a/tests/test_jit.py +++ /dev/null @@ -1,98 +0,0 @@ -import ctypes -import os -import torch -import cuda.bindings.driver as cbd -from typing import Any, Dict - -from deep_gemm import jit - -# Essential debugging staffs -os.environ['DG_JIT_DEBUG'] = os.getenv('DG_JIT_DEBUG', '1') -os.environ['DG_JIT_DISABLE_CACHE'] = os.getenv('DG_JIT_DISABLE_CACHE', '1') - - -class VectorAddRuntime(jit.Runtime): - def __init__(self, path: str) -> None: - super().__init__(path) - - @staticmethod - def generate(kwargs: Dict[str, Any]) -> str: - return f""" -#ifdef __CUDACC_RTC__ -#include -#else -#include -#endif - -#include -#include - -template -__global__ void vector_add(T* a, T* b, T* c, uint32_t n) {{ - uint32_t i = blockDim.x * blockIdx.x + threadIdx.x; - if (i < n) {{ - c[i] = a[i] + b[i]; - }} -}} - -static void __instantiate_kernel() {{ - auto ptr = reinterpret_cast(&vector_add<{kwargs['T']}>); -}} -""" - - # noinspection PyShadowingNames,PyMethodOverriding - @staticmethod - def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: - assert kwargs['A'].shape == kwargs['B'].shape == kwargs['C'].shape - assert kwargs['A'].device == kwargs['B'].device == kwargs['C'].device - assert kwargs['A'].dim() == 1 - - config = cbd.CUlaunchConfig() - config.gridDimX = (kwargs['A'].numel() + 127) // 128 - config.gridDimY = 1 - config.gridDimZ = 1 - config.blockDimX = 128 - config.blockDimY = 1 - config.blockDimZ = 1 - config.hStream = kwargs['STREAM'] - - arg_values = ( - kwargs['A'].data_ptr(), - kwargs['B'].data_ptr(), - kwargs['C'].data_ptr(), - kwargs['A'].numel(), - ) - arg_types = ( - ctypes.c_void_p, - ctypes.c_void_p, - ctypes.c_void_p, - ctypes.c_uint32, - ) - - return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0] - - -if __name__ == '__main__': - print('Generated code:') - kwargs = {'T': 'float'} - code = VectorAddRuntime.generate(kwargs) - print(code) - print() - - for compiler_name in ('NVCC', 'NVRTC'): - # Get compiler - compiler_cls = getattr(jit, f'{compiler_name}Compiler') - print(f'Compiler: {compiler_name}, version: {compiler_cls.__version__()}') - - # Build - print('Building ...') - func = compiler_cls.build('test_func', code, VectorAddRuntime, kwargs) - - # Run and check - a = torch.randn((1024, ), dtype=torch.float32, device='cuda') - b = torch.randn((1024, ), dtype=torch.float32, device='cuda') - c = torch.empty_like(a) - ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream) - assert ret == cbd.CUresult.CUDA_SUCCESS, ret - torch.testing.assert_close(c, a + b) - print(f'JIT test for {compiler_name} passed\n') diff --git a/tests/test_layout.py b/tests/test_layout.py new file mode 100644 index 000000000..42d7208bb --- /dev/null +++ b/tests/test_layout.py @@ -0,0 +1,116 @@ +import time +import torch +import random +from deep_gemm.testing import bench_kineto, count_bytes, calc_diff +from deep_gemm.utils import ( + align, ceil_div, + per_token_cast_to_fp8, per_channel_cast_to_fp8, + get_tma_aligned_size, + get_mn_major_tma_aligned_tensor, + get_mn_major_tma_aligned_packed_ue8m0_tensor, + get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor +) + +from generators import ( + enumerate_transpose, + enumerate_sf_layout, + enumerate_k_grouped_sf_layout +) + + +def get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(x: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.float and x.dim() in (2, 3) + + # First, convert into UE8M0 `uint8_t` + ue8m0_tensor = (x.view(torch.int) >> 23).to(torch.uint8) + + # Second, make padded packed tensors + mn, k = x.shape[-2], x.shape[-1] + remove_dim = False + if x.dim() == 2: + x, remove_dim = x.unsqueeze(0), True + b = x.shape[0] + aligned_mn = get_tma_aligned_size(mn, 4) + aligned_k = align(k, 4) + padded = torch.zeros((b, aligned_mn, aligned_k), device=x.device, dtype=torch.uint8) + padded[:, :mn, :k] = ue8m0_tensor + padded = padded.view(-1).view(dtype=torch.int).view(b, aligned_mn, aligned_k // 4) + + # Finally, transpose + transposed = torch.zeros((b, aligned_k // 4, aligned_mn), device=x.device, dtype=torch.int).mT + transposed[:, :, :] = padded + aligned_x = transposed[:, :mn, :] + return aligned_x.squeeze(0) if remove_dim else aligned_x + + +def test_sf_layout_kernels() -> None: + print('Testing SF layout kernels:') + for mn, k, with_transpose, use_ue8m0, num_groups in enumerate_sf_layout(): + x = torch.randn((num_groups * mn, k), dtype=torch.bfloat16, device='cuda') + x, fp32_sf = per_token_cast_to_fp8(x, use_ue8m0=use_ue8m0) + fp32_sf = fp32_sf if num_groups == 1 else fp32_sf.view(num_groups, mn, -1) + fp32_sf = fp32_sf if with_transpose else fp32_sf.transpose(-1, -2).contiguous().transpose(-1, -2) + + # Correctness + if use_ue8m0: + impl, name = get_mn_major_tma_aligned_packed_ue8m0_tensor, 'pack_fp32_into_ue8m0' + packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf) + ref_packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(fp32_sf) + assert torch.equal(packed_sf, ref_packed_sf), f'{mn=}, {k=}, {with_transpose=}, {num_groups=}' + assert packed_sf.shape == ref_packed_sf.shape + assert all([packed_sf.stride(i) == ref_packed_sf.stride(i) for i in range(packed_sf.dim())]) + else: + impl, name = get_mn_major_tma_aligned_tensor, 'transpose' + transposed_sf = get_mn_major_tma_aligned_tensor(fp32_sf) + tma_aligned_mn, sf_k = get_tma_aligned_size(mn, fp32_sf.element_size()), ceil_div(k, 128) + if num_groups > 1: + assert transposed_sf.size(0) == num_groups + assert transposed_sf.stride(0) == tma_aligned_mn * sf_k + assert transposed_sf.shape[-2:] == (mn, sf_k) + assert transposed_sf.stride()[-2:] == (1, tma_aligned_mn) + assert torch.equal(fp32_sf, transposed_sf) + + # Performance + try: + t = bench_kineto(lambda: impl(fp32_sf), name) + except AssertionError as e: + # Some cases may fallback to PyTorch impl + t = 0 + print(f' > Perf ({num_groups=:2}, {mn=:5}, {k=:5}, transpose={int(with_transpose)}, use_ue8m0={int(use_ue8m0)}): ' + f'{t * 1e6:4.0f} us | {count_bytes(fp32_sf, impl(fp32_sf)) / 1e9 / t if t else 0:4.0f} GB/s') + print() + + +def test_k_grouped_sf_layout_kernels() -> None: + print('Testing k-grouped SF layout kernels:') + for mn, ks, num_groups in enumerate_k_grouped_sf_layout(): + sf_ks = [k // 128 for k in ks] + packed_sf_ks = [ceil_div(k, 512) for k in ks] + ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda') + x = torch.randn((sum(ks), mn), dtype=torch.bfloat16, device='cuda') + x, fp32_sf = per_channel_cast_to_fp8(x, use_ue8m0=True) + + # Correctness + packed_sf = get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf, ks_tensor, ks) + split_packed_sf = packed_sf.split(packed_sf_ks) + split_fp32_sf = fp32_sf.split(sf_ks) + for i in range(num_groups): + ref_packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(split_fp32_sf[i].T).T + assert torch.equal(split_packed_sf[i], ref_packed_sf), f'{i=}' + + # Performance + t = bench_kineto(lambda: get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf, ks_tensor, ks), 'pack_fp32_into_ue8m0') + print(f' > Perf ({num_groups=:3}, {mn=:5}, sum_k={sum(ks):5}):' + f'{t * 1e6:4.0f} us | ' + f'{count_bytes(fp32_sf, packed_sf, ks_tensor) / 1e9 / t:4.0f} GB/s') + print() + + +if __name__ == '__main__': + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.manual_seed(1) + random.seed(1) + + test_sf_layout_kernels() + test_k_grouped_sf_layout_kernels() diff --git a/tests/test_lazy_init.py b/tests/test_lazy_init.py new file mode 100644 index 000000000..5363b6db3 --- /dev/null +++ b/tests/test_lazy_init.py @@ -0,0 +1,15 @@ +import torch +import torch.multiprocessing as mp +import deep_gemm + + +def main(local_rank: int): + torch.cuda.set_device(local_rank) + + +if __name__ == '__main__': + procs = [mp.Process(target=main, args=(i, ), ) for i in range(8)] + for p in procs: + p.start() + for p in procs: + p.join() diff --git a/third-party/cutlass b/third-party/cutlass index eefa17131..b244379d9 160000 --- a/third-party/cutlass +++ b/third-party/cutlass @@ -1 +1 @@ -Subproject commit eefa171318b79cbe2e78514d4cce5cd0fe919d0c +Subproject commit b244379d9b15574e07b73b814b88bd2233f0b3ce diff --git a/third-party/fmt b/third-party/fmt new file mode 160000 index 000000000..553ec11ec --- /dev/null +++ b/third-party/fmt @@ -0,0 +1 @@ +Subproject commit 553ec11ec06fbe0beebfbb45f9dc3c9eabd83d28