Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
0c88cd0
Fix illegal memory address when skipping -1 m indices (#113)
shixianc Jun 16, 2025
ac428e2
Fixed the bug in get_swizzle_mode function related to elem_size setti…
Ther-LF Jun 23, 2025
e82c413
Revert "Fixed the bug in get_swizzle_mode function related to elem_si…
Jun 23, 2025
3fc6728
[add] fix smem_barrier size in wgrad way (#122)
fy1214 Jul 2, 2025
03d0be3
Simplify expression
LyricZhao Jul 2, 2025
9da4a23
Add more GPU architectures support (#112)
RayWang96 Jul 18, 2025
6c9558e
Update CUDA toolkits requirement (#128)
RayWang96 Jul 18, 2025
4ca3cdf
fix: update .gitmodules (#130)
zhyncs Jul 20, 2025
c1db17e
Updated submodules to use https:// vs git@ (#129)
smarterclayton Jul 21, 2025
436a563
Use std::filesystem::directory_iterator instead of std::filesystem::r…
RayWang96 Jul 21, 2025
1876566
Code lint
LyricZhao Jul 21, 2025
8987798
Update setup.py (#134)
danthe3rd Jul 28, 2025
4b4e4f2
Update system.hpp (#133)
danthe3rd Jul 28, 2025
dd6ed14
Add torch as build dependency. (#139)
yuxianq Jul 28, 2025
fb7c687
Merge pull request #135 from danthe3rd/patch-3
danthe3rd Jul 29, 2025
a581263
Fix indent
LJC00118 Jul 29, 2025
6bc75b5
Fix smxx layout assertion (#141)
LJC00118 Jul 30, 2025
c50deed
Code lint
LyricZhao Jul 30, 2025
aff9da0
Fix SM90 GEMM (#149)
yukuai26 Aug 1, 2025
d9c363f
Make various updates and fixes:
RayWang96 Aug 3, 2025
3979c05
Merge pull request #151 from RayWang96/update_jit
RayWang96 Aug 3, 2025
7b6b556
Fix smxx layout assertion (#154)
LJC00118 Aug 5, 2025
6d3717d
Update test_fp8.py (#159)
fzyzcjy Aug 14, 2025
3254b75
Polish get_best_configs modeling. (#158)
Insideyyy Aug 14, 2025
f85ec64
Make various updates and fixes: (#164)
RayWang96 Aug 15, 2025
affdb1c
Add sm_100f support and make nvcc 13 happy (#157)
VALLIS-NERIA Aug 22, 2025
f20256f
Compatible with CUDA 13
LyricZhao Aug 22, 2025
e38c2e3
Remove comments
LyricZhao Aug 22, 2025
2da871e
Fix grouped gemms performance issue. (#168)
Insideyyy Aug 22, 2025
89b4089
Update test files in README documentation (#169)
skyloevil Aug 25, 2025
61a7f9c
refactor: rebase to deepgemm 2.0.0+
KyeeHuang Nov 10, 2025
e2dfa4e
fix: pass per tensor kernel compile and benchmark
KyeeHuang Nov 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[submodule "third-party/cutlass"]
path = third-party/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "third-party/fmt"]
path = third-party/fmt
url = https://github.com/fmtlib/fmt.git
55 changes: 22 additions & 33 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,44 +1,33 @@
# NOTE: currently only used for CMake-based IDE (e.g. CLion) indexing; the real compilation is done via JIT
# TODO: add CUDA utils' library via CMake
cmake_minimum_required(VERSION 3.10)
project(deep_gemm LANGUAGES CXX CUDA)

set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CUDA_STANDARD 20)
set(CMAKE_VERBOSE_MAKEFILE ON)

find_package(CUDAToolkit REQUIRED)
find_package(pybind11 REQUIRED)

file(WRITE ${CMAKE_BINARY_DIR}/test_cuda.cu "extern \"C\" __global__ void testKernel() { }")
execute_process(
COMMAND ${CUDA_NVCC_EXECUTABLE} ${CMAKE_CUDA_FLAGS} -gencode arch=compute_90a,code=sm_90a -o ${CMAKE_BINARY_DIR}/test_cuda.o -c ${CMAKE_BINARY_DIR}/test_cuda.cu
RESULT_VARIABLE NVCC_RESULT
OUTPUT_VARIABLE NVCC_OUTPUT
ERROR_VARIABLE NVCC_ERROR_OUTPUT
WORKING_DIRECTORY ${CMAKE_BINARY_DIR}
)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC -Wno-psabi")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi")
set(CUDA_SEPARABLE_COMPILATION ON)
list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG")
list(APPEND CUDA_NVCC_FLAGS "-O3")
list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage")

if (NVCC_RESULT EQUAL "0")
set(NVCC_SUPPORTS_SM90 TRUE)
message(STATUS "NVCC supports SM90")
else()
message(STATUS "NVCC does not support SM90")
endif()
set(USE_SYSTEM_NVTX on)
set(CUDA_ARCH_LIST "9.0" CACHE STRING "List of CUDA architectures to compile")
set(TORCH_CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")

if (NVCC_SUPPORTS_SM90)
set(TORCH_CUDA_ARCH_LIST "8.6" CACHE STRING "Add arch tag 90a to NVCC" FORCE)
list(APPEND CUDA_NVCC_FLAGS "-gencode;arch=compute_90a,code=sm_90a")
endif()
find_package(CUDAToolkit REQUIRED)
find_package(pybind11 REQUIRED)
find_package(Torch REQUIRED)

include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)

include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)

set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 -fPIC -DNDEBUG")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3 -std=c++17 -DNDEBUG --ptxas-options=--register-usage-level=10")
# The main Python API entrance
pybind11_add_module(deep_gemm_cpp csrc/python_api.cpp)
target_link_libraries(deep_gemm_cpp PRIVATE ${TORCH_LIBRARIES} torch_python cuda)

cuda_add_library(example_gemm STATIC indexing/main.cu)
# Enable kernel code indexing with CMake-based IDEs
cuda_add_library(deep_gemm_indexing_cuda STATIC csrc/indexing/main.cu)
172 changes: 57 additions & 115 deletions README.md

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Build a wheel for the project from its root directory, wherever the script is invoked from.

# Change current directory into project root
original_dir=$(pwd)
script_dir=$(realpath "$(dirname "$0")")
# Abort if the project root cannot be entered, so the `rm -rf` below never runs in another directory
cd "$script_dir" || exit 1

# Remove old dist files and build files, then build the wheel
rm -rf build dist
rm -rf *.egg-info
python setup.py bdist_wheel

# Return to the user's original directory
cd "$original_dir"
576 changes: 576 additions & 0 deletions csrc/apis/gemm.hpp

Large diffs are not rendered by default.

93 changes: 93 additions & 0 deletions csrc/apis/layout.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#pragma once

#include "../utils/layout.hpp"
#include "../jit_kernels/impls/smxx_layout.hpp"

namespace deep_gemm::layout {

// Converts a scaling-factor (SF) tensor into the layout selected by its dtype, the recipe
// granularity, and the device architecture major version.
//
// - `sf`: the scaling-factor tensor to validate and/or transform.
// - `mn`, `k`: problem sizes forwarded to `check_sf_layout`; `mn` also sizes the row broadcast.
// - `recipe`: element 0 is the MN granularity used when `is_sfa` is true, element 1 is the MN
//   granularity used otherwise, element 2 is the K granularity.
// - `num_groups`: optional group count forwarded to `check_sf_layout`.
// - `disable_ue8m0_cast`: when true, FP32 factors take the same path as on arch major 9.
// - `is_per_tensor`: forwarded to `check_sf_layout` for the pre-check and the (FP32, 128, 128) case.
//
// Hits `DG_HOST_UNREACHABLE` when no supported combination matches.
static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
                                                       const int& mn, const int& k,
                                                       const std::tuple<int, int, int>& recipe,
                                                       const std::optional<int>& num_groups,
                                                       const bool& is_sfa,
                                                       const bool& disable_ue8m0_cast,
                                                       const bool& is_per_tensor) {
    const int gran_mn = is_sfa ? std::get<0>(recipe) : std::get<1>(recipe);
    const int gran_k = std::get<2>(recipe);
    const auto arch_major = device_runtime->get_arch_major();

    // Validate the incoming layout before any transformation
    check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, false, std::nullopt, is_per_tensor);

    // Name each predicate once instead of repeating it in every branch
    const bool is_fp32 = sf.scalar_type() == torch::kFloat;
    const bool is_int = sf.scalar_type() == torch::kInt;
    const bool is_1x128 = gran_mn == 1 and gran_k == 128;
    const bool is_128x128 = gran_mn == 128 and gran_k == 128;
    const bool keep_fp32 = arch_major == 9 or disable_ue8m0_cast;

    if (is_fp32 and is_1x128) {
        // Stay in FP32: make the tensor TMA-aligned and MN-major
        if (keep_fp32)
            return get_mn_major_tma_aligned_tensor(sf);

        // SM100: pack into UE8M0 (INT), TMA-aligned and MN-major
        if (arch_major == 10) {
            DG_HOST_ASSERT(not disable_ue8m0_cast);
            return get_mn_major_tma_aligned_packed_ue8m0_tensor(sf);
        }
    }

    if (is_fp32 and is_128x128) {
        // Stay in FP32: nothing to transform, only check shape and contiguity
        if (keep_fp32) {
            if (is_per_tensor)
                return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat, true);
            return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat);
        }

        // SM100: repeat each 128-row block factor for every row, then pack as for (1, 128)
        if (arch_major == 10) {
            DG_HOST_ASSERT(not disable_ue8m0_cast);
            const auto row_to_block = torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(128);
            const auto per_row_sf = sf.index_select(-2, row_to_block);
            return get_mn_major_tma_aligned_packed_ue8m0_tensor(per_row_sf);
        }
    }

    // (INT, 1, 128) on SM100: already packed, only the layout is checked
    if (is_int and is_1x128 and arch_major == 10)
        return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt);

    DG_HOST_UNREACHABLE("Unknown SF transformation");
}

// Transforms a 2D K-grouped scaling-factor tensor into the layout required by the kernels.
//
// - `sf`: 2D scaling-factor tensor (asserted).
// - `ks`, `ks_tensor`: per-group K sizes as a host vector and as a tensor, both forwarded to
//   the packing helper.
// - `recipe`: must be exactly (1, 1, 128) (asserted).
//
// Only FP32 factors on SM100 are implemented; every other combination hits
// `DG_HOST_UNREACHABLE`.
static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Tensor& sf,
                                                                 const std::vector<int>& ks,
                                                                 const torch::Tensor& ks_tensor,
                                                                 const std::tuple<int, int, int>& recipe) {
    DG_HOST_ASSERT(sf.dim() == 2);
    DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));
    const auto& arch_major = device_runtime->get_arch_major();

    // FP32 on SM90
    if (sf.scalar_type() == torch::kFloat and arch_major == 9)
        DG_HOST_UNREACHABLE("Unimplemented");

    // FP32 on SM100
    if (sf.scalar_type() == torch::kFloat and arch_major == 10)
        return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks);

    // INT on SM100: the dtype test must be `kInt`; testing `kFloat` here would repeat the
    // branch above and could never be reached
    if (sf.scalar_type() == torch::kInt and arch_major == 10)
        DG_HOST_UNREACHABLE("Unimplemented");

    DG_HOST_UNREACHABLE("Unknown cases");
}

// Exposes the layout helpers of this namespace on the given pybind11 module.
static void register_apis(pybind11::module_& m) {
    // Keyword-capable binding; everything after `recipe` is optional on the Python side
    m.def("transform_sf_into_required_layout",
          &transform_sf_into_required_layout,
          py::arg("sf"),
          py::arg("mn"),
          py::arg("k"),
          py::arg("recipe"),
          py::arg("num_groups") = std::nullopt,
          py::arg("is_sfa") = false,
          py::arg("disable_ue8m0_cast") = false,
          py::arg("is_per_tensor") = false);

    // Positional-only bindings for the lower-level helpers
    m.def("get_tma_aligned_size",
          &get_tma_aligned_size);
    m.def("get_mk_alignment_for_contiguous_layout",
          &get_mk_alignment_for_contiguous_layout);
    m.def("get_mn_major_tma_aligned_tensor",
          &get_mn_major_tma_aligned_tensor);
    m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor",
          &get_mn_major_tma_aligned_packed_ue8m0_tensor);
    m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor",
          &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor);
}

} // namespace deep_gemm::layout
28 changes: 28 additions & 0 deletions csrc/apis/runtime.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#pragma once

#include "../jit/compiler.hpp"
#include "../jit/device_runtime.hpp"

namespace deep_gemm::runtime {

// Exposes runtime controls (SM count, `tc_util`, JIT initialization) on the given pybind11 module.
//
// The callbacks are stored by pybind11 and outlive this function, so they are written without a
// capture default: they only touch the global `device_runtime` and static `prepare_init`
// functions, and an empty capture list prevents a reference to a local of this function from
// being captured by accident later on.
static void register_apis(pybind11::module_& m) {
    // Getter/setter pairs forwarding to the device runtime
    m.def("set_num_sms", [](const int& new_num_sms) {
        device_runtime->set_num_sms(new_num_sms);
    });
    m.def("get_num_sms", []() {
        return device_runtime->get_num_sms();
    });
    m.def("set_tc_util", [](const int& new_tc_util) {
        device_runtime->set_tc_util(new_tc_util);
    });
    m.def("get_tc_util", []() {
        return device_runtime->get_tc_util();
    });

    // Initialization: hands the library root and the CUDA home path reported by Python to the
    // compiler, and the CUDA home path to the kernel runtime
    m.def("init", [](const std::string& library_root_path, const std::string& cuda_home_path_by_python) {
        Compiler::prepare_init(library_root_path, cuda_home_path_by_python);
        KernelRuntime::prepare_init(cuda_home_path_by_python);
    });
}

} // namespace deep_gemm::runtime
13 changes: 13 additions & 0 deletions csrc/indexing/main.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include <deep_gemm/impls/sm90_bf16_gemm.cuh>
#include <deep_gemm/impls/sm100_bf16_gemm.cuh>
#include <deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh>
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>
#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>
#include <deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh>
#include <deep_gemm/impls/smxx_layout.cuh>

using namespace deep_gemm;

// Intentionally empty entry point. NOTE(review): this translation unit appears to exist only so
// that the kernel headers included above are compiled/indexed (e.g. by CMake-based IDEs); it
// performs no work at runtime.
int main() {
return 0;
}
31 changes: 31 additions & 0 deletions csrc/jit/cache.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#pragma once

#include <filesystem>
#include <memory>
#include <unordered_map>

#include "kernel_runtime.hpp"

namespace deep_gemm {

class KernelRuntimeCache {
std::unordered_map<std::string, std::shared_ptr<KernelRuntime>> cache;

public:
// TODO: consider cache capacity
KernelRuntimeCache() = default;

std::shared_ptr<KernelRuntime> get(const std::filesystem::path& dir_path) {
// Hit the runtime cache
if (const auto& iterator = cache.find(dir_path); iterator != cache.end())
return iterator->second;

if (KernelRuntime::check_validity(dir_path))
return cache[dir_path] = std::make_shared<KernelRuntime>(dir_path);
return nullptr;
}
};

// Shared `KernelRuntimeCache` instance, created during static initialization.
// NOTE(review): `static` at namespace scope gives this variable internal linkage, so every
// translation unit that includes this header gets its own separate cache — confirm this is intended.
static auto kernel_runtime_cache = std::make_shared<KernelRuntimeCache>();

} // namespace deep_gemm
Loading