diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp index fb205de2..ccddc050 100644 --- a/csrc/apis/gemm.hpp +++ b/csrc/apis/gemm.hpp @@ -52,7 +52,8 @@ static void fp8_gemm_nt(const std::pair& a, const std::optional& c, std::optional> recipe, const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { + const bool& disable_ue8m0_cast, + const std::optional& bias = std::nullopt) { // Shape must be `[M, K] @ [N, K].T` const auto& major_a = get_major_type_ab(a.first); const auto& major_b = get_major_type_ab(b.first); @@ -94,7 +95,7 @@ static void fp8_gemm_nt(const std::pair& a, sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, major_sfb, compiled_dims); } } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { - sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); + sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims, std::nullopt, bias); } else { DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); } @@ -106,9 +107,10 @@ static void fp8_gemm_nn(const std::pair& a, const std::optional& c, const std::optional>& recipe, const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { + const bool& disable_ue8m0_cast, + const std::optional& bias = std::nullopt) { fp8_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)}, - d, c, recipe, compiled_dims, disable_ue8m0_cast); + d, c, recipe, compiled_dims, disable_ue8m0_cast, bias); } static void fp8_gemm_tn(const std::pair& a, @@ -117,10 +119,11 @@ static void fp8_gemm_tn(const std::pair& a, const std::optional& c, const std::optional>& recipe, const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { + const bool& disable_ue8m0_cast, + const std::optional& bias = std::nullopt) { fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, {b.first.transpose(0, 1), b.second.transpose(0, 1)}, - d, c, recipe, compiled_dims, disable_ue8m0_cast); + d, c, recipe, compiled_dims, disable_ue8m0_cast, bias); } static void fp8_gemm_tt(const std::pair& a, @@ -129,9 +132,10 @@ static void fp8_gemm_tt(const std::pair& a, const std::optional& c, const std::optional>& recipe, const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { + const bool& disable_ue8m0_cast, + const std::optional& bias = std::nullopt) { fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b, - d, c, recipe, compiled_dims, disable_ue8m0_cast); + d, c, recipe, compiled_dims, disable_ue8m0_cast, bias); } static void m_grouped_fp8_gemm_nt_contiguous(const std::pair& a, @@ -140,7 +144,8 @@ static void m_grouped_fp8_gemm_nt_contiguous(const std::pair> recipe, const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { + const bool& disable_ue8m0_cast, + const std::optional& bias = std::nullopt) { // Shape must be `[M, K] @ [G, N, K].mT` const auto& major_a = get_major_type_ab(a.first); const auto& major_b = get_major_type_ab(b.first); @@ -182,7 +187,7 @@ static void m_grouped_fp8_gemm_nt_contiguous(const std::pair>& recipe, const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { + const bool& disable_ue8m0_cast, + const std::optional& bias = std::nullopt) { m_grouped_fp8_gemm_nt_contiguous(a, {b.first.transpose(1, 2), b.second.transpose(1, 2)}, - d, m_indices, recipe, compiled_dims, disable_ue8m0_cast); + d, m_indices, recipe, compiled_dims, disable_ue8m0_cast, bias); } static void m_grouped_fp8_gemm_nt_masked(const std::pair& a, @@ -206,7 +212,8 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair> recipe, const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { + const bool& disable_ue8m0_cast, + const std::optional& bias = std::nullopt) { // Shape must be `[G, M, K] @ [G, N, K].mT` const auto& major_a = get_major_type_ab(a.first); const auto& major_b = get_major_type_ab(b.first); @@ -243,7 +250,7 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair& epilogue_type; @@ -31,6 +32,7 @@ class SM100FP8Gemm1D1DRuntime final: public LaunchRuntime); }}; )", @@ -65,7 +67,7 @@ static void __instantiate_kernel() {{ args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype), - get_default_epilogue_type(args.epilogue_type)); + get_default_epilogue_type(args.epilogue_type), args.with_bias); } static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { @@ -74,7 +76,7 @@ static void __instantiate_kernel() {{ args.grouped_layout, args.m, args.n, args.k, args.tensor_map_a, args.tensor_map_b, args.tensor_map_sfa, args.tensor_map_sfb, - args.tensor_map_cd)); + args.tensor_map_cd, args.bias_ptr)); } }; @@ -85,7 +87,8 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa const int& m, const int& n, const int& k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const std::string& compiled_dims, - const std::optional& epilogue_type = std::nullopt) { + const std::optional& epilogue_type = std::nullopt, + const std::optional& bias = std::nullopt) { const auto& aligned_k = align(k, 128); const auto& config = get_best_config( GemmType::Normal, KernelType::Kernel1D1D, @@ -118,6 +121,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa const SM100FP8Gemm1D1DRuntime::Args& args = { .m = m, .n = n, .k = aligned_k, .num_groups = 1, + .with_bias = bias.has_value(), .compiled_dims = compiled_dims, .epilogue_type = epilogue_type, .gemm_config = config, @@ -129,7 +133,8 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa .tensor_map_b = tensor_map_b, .tensor_map_sfa = tensor_map_sfa, .tensor_map_sfb = tensor_map_sfb, - .tensor_map_cd = tensor_map_cd + .tensor_map_cd = tensor_map_cd, + .bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr }; const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code); @@ -142,7 +147,8 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con const torch::Tensor& m_indices, const int& num_groups, const int& m, const int& n, const int& k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const std::string& compiled_dims) { + const std::string& compiled_dims, + const std::optional& bias = std::nullopt) { const auto& aligned_k = align(k, 128); const auto& config = get_best_config( GemmType::MGroupedContiguous, KernelType::Kernel1D1D, @@ -176,6 +182,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con const SM100FP8Gemm1D1DRuntime::Args& args = { .m = m, .n = n, .k = aligned_k, .num_groups = num_groups, + .with_bias = bias.has_value(), .compiled_dims = compiled_dims, .epilogue_type = std::nullopt, .gemm_config = config, @@ -187,7 +194,8 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con .tensor_map_b = tensor_map_b, .tensor_map_sfa = tensor_map_sfa, .tensor_map_sfb = tensor_map_sfb, - .tensor_map_cd = tensor_map_cd + .tensor_map_cd = tensor_map_cd, + .bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr }; const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d1d", code); @@ -201,7 +209,8 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t const int& num_groups, const int& m, const int& n, const int& k, const int& expected_m, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const std::string& compiled_dims) { + const std::string& compiled_dims, + const std::optional& bias = std::nullopt) { const auto& aligned_k = align(k, 128); const auto& config = get_best_config( GemmType::MGroupedMasked, KernelType::Kernel1D1D, @@ -234,6 +243,7 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t const SM100FP8Gemm1D1DRuntime::Args& args = { .m = m, .n = n, .k = aligned_k, .num_groups = num_groups, + .with_bias = bias.has_value(), .compiled_dims = compiled_dims, .epilogue_type = std::nullopt, .gemm_config = config, @@ -245,7 +255,8 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t .tensor_map_b = tensor_map_b, .tensor_map_sfa = tensor_map_sfa, .tensor_map_sfb = tensor_map_sfb, - .tensor_map_cd = tensor_map_cd + .tensor_map_cd = tensor_map_cd, + .bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr }; const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d1d", code); @@ -302,6 +313,7 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& const SM100FP8Gemm1D1DRuntime::Args& args = { .m = m, .n = n, .k = sum_k, .num_groups = num_groups, + .with_bias = false, .compiled_dims = compiled_dims, .epilogue_type = std::nullopt, .gemm_config = config, @@ -313,7 +325,8 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& .tensor_map_b = tensor_map_b, .tensor_map_sfa = tensor_map_sfa, .tensor_map_sfb = tensor_map_sfb, - .tensor_map_cd = tensor_map_cd + .tensor_map_cd = tensor_map_cd, + .bias_ptr = nullptr }; const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_fp8_k_grouped_gemm_1d1d", code); @@ -367,6 +380,7 @@ static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, const SM100FP8Gemm1D1DRuntime::Args& args = { .m = m, .n = n, .k = k, .num_groups = batch_size, + .with_bias = false, .compiled_dims = compiled_dims, .epilogue_type = std::nullopt, .gemm_config = config, @@ -378,7 +392,8 @@ static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, .tensor_map_b = tensor_map_b, .tensor_map_sfa = tensor_map_sfa, .tensor_map_sfb = tensor_map_sfb, - .tensor_map_cd = tensor_map_cd + .tensor_map_cd = tensor_map_cd, + .bias_ptr = nullptr }; const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code); diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh index da7f461c..be436444 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -23,7 +23,7 @@ template + typename epilogue_type_t, bool kWithBias = false> __global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) sm100_fp8_gemm_1d1d_impl(int* grouped_layout, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, @@ -31,7 +31,8 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, const __grid_constant__ cute::TmaDescriptor tensor_map_b, const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, - const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { + const __grid_constant__ cute::TmaDescriptor tensor_map_cd, + const cd_dtype_t* __restrict__ bias_ptr) { #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) using Barrier = cutlass::arch::ClusterTransactionBarrier; using Allocator = cute::conditional_t; @@ -479,6 +480,11 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + const cd_dtype_t* __restrict__ current_bias_ptr = nullptr; + if constexpr (kWithBias) { + current_bias_ptr = bias_ptr + n_idx + i * kNumElemsPerBankGroup; + } // Load from tensor memory, store into shared memory uint32_t values[kNumElemsPerBankGroup]; @@ -488,6 +494,13 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, values[0], values[1], values[2], values[3]); cutlass::arch::fence_view_async_tmem_load(); + if constexpr (kWithBias) { + #pragma unroll + for (int o = 0; o < 4; o++) { + float val = __uint_as_float(values[o]); + values[o] = __float_as_uint(val + current_bias_ptr[o]); + } + } st_shared(smem_ptr, values[0], values[1], values[2], values[3]); } else { // For BF16 output, read, cast and store @@ -496,6 +509,13 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, values[0], values[1], values[2], values[3], values[4], values[5], values[6], values[7]); cutlass::arch::fence_view_async_tmem_load(); + if constexpr (kWithBias) { + #pragma unroll + for (int o = 0; o < 8; o++) { + float val = __uint_as_float(values[o]); + values[o] = __float_as_uint(val + static_cast(current_bias_ptr[o])); + } + } st_shared(smem_ptr, cast_into_bf16_and_pack(values[0], values[1]), cast_into_bf16_and_pack(values[2], values[3]),