Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 35 additions & 21 deletions csrc/apis/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ static void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::optional<torch::Tensor>& c,
std::optional<std::tuple<int, int, int>> recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
const bool& disable_ue8m0_cast,
const std::optional<torch::Tensor>& bias = std::nullopt) {
// Shape must be `[M, K] @ [N, K].T`
const auto& major_a = get_major_type_ab(a.first);
const auto& major_b = get_major_type_ab(b.first);
Expand Down Expand Up @@ -94,7 +95,7 @@ static void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, major_sfb, compiled_dims);
}
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims, std::nullopt, bias);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
Expand All @@ -106,9 +107,10 @@ static void fp8_gemm_nn(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::optional<torch::Tensor>& c,
const std::optional<std::tuple<int, int, int>>& recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
const bool& disable_ue8m0_cast,
const std::optional<torch::Tensor>& bias = std::nullopt) {
fp8_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)},
d, c, recipe, compiled_dims, disable_ue8m0_cast);
d, c, recipe, compiled_dims, disable_ue8m0_cast, bias);
}

static void fp8_gemm_tn(const std::pair<torch::Tensor, torch::Tensor>& a,
Expand All @@ -117,10 +119,11 @@ static void fp8_gemm_tn(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::optional<torch::Tensor>& c,
const std::optional<std::tuple<int, int, int>>& recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
const bool& disable_ue8m0_cast,
const std::optional<torch::Tensor>& bias = std::nullopt) {
fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)},
{b.first.transpose(0, 1), b.second.transpose(0, 1)},
d, c, recipe, compiled_dims, disable_ue8m0_cast);
d, c, recipe, compiled_dims, disable_ue8m0_cast, bias);
}

static void fp8_gemm_tt(const std::pair<torch::Tensor, torch::Tensor>& a,
Expand All @@ -129,9 +132,10 @@ static void fp8_gemm_tt(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::optional<torch::Tensor>& c,
const std::optional<std::tuple<int, int, int>>& recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
const bool& disable_ue8m0_cast,
const std::optional<torch::Tensor>& bias = std::nullopt) {
fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b,
d, c, recipe, compiled_dims, disable_ue8m0_cast);
d, c, recipe, compiled_dims, disable_ue8m0_cast, bias);
}

static void m_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
Expand All @@ -140,7 +144,8 @@ static void m_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torc
const torch::Tensor& m_indices,
std::optional<std::tuple<int, int, int>> recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
const bool& disable_ue8m0_cast,
const std::optional<torch::Tensor>& bias = std::nullopt) {
// Shape must be `[M, K] @ [G, N, K].mT`
const auto& major_a = get_major_type_ab(a.first);
const auto& major_b = get_major_type_ab(b.first);
Expand Down Expand Up @@ -182,7 +187,7 @@ static void m_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torc
num_groups, m, n, k, major_a, major_b, major_sfb, compiled_dims);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_m_grouped_fp8_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, m_indices,
num_groups, m, n, k, major_a, major_b, compiled_dims);
num_groups, m, n, k, major_a, major_b, compiled_dims, bias);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
Expand All @@ -194,9 +199,10 @@ static void m_grouped_fp8_gemm_nn_contiguous(const std::pair<torch::Tensor, torc
const torch::Tensor& m_indices,
const std::optional<std::tuple<int, int, int>>& recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
const bool& disable_ue8m0_cast,
const std::optional<torch::Tensor>& bias = std::nullopt) {
m_grouped_fp8_gemm_nt_contiguous(a, {b.first.transpose(1, 2), b.second.transpose(1, 2)},
d, m_indices, recipe, compiled_dims, disable_ue8m0_cast);
d, m_indices, recipe, compiled_dims, disable_ue8m0_cast, bias);
}

static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
Expand All @@ -206,7 +212,8 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::T
const int& expected_m,
std::optional<std::tuple<int, int, int>> recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
const bool& disable_ue8m0_cast,
const std::optional<torch::Tensor>& bias = std::nullopt) {
// Shape must be `[G, M, K] @ [G, N, K].mT`
const auto& major_a = get_major_type_ab(a.first);
const auto& major_b = get_major_type_ab(b.first);
Expand Down Expand Up @@ -243,7 +250,7 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::T
num_groups, m, n, k, expected_m, major_a, major_b, major_sfb, compiled_dims);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_m_grouped_fp8_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims, bias);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
Expand Down Expand Up @@ -568,34 +575,41 @@ static void register_apis(pybind11::module_& m) {
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);
py::arg("disable_ue8m0_cast") = false,
py::arg("bias") = std::nullopt);
m.def("fp8_gemm_nn", &fp8_gemm_nn,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);
py::arg("disable_ue8m0_cast") = false,
py::arg("bias") = std::nullopt);
m.def("fp8_gemm_tn", &fp8_gemm_tn,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
py::arg("compiled_dims") = "mn",
py::arg("disable_ue8m0_cast") = false);
py::arg("disable_ue8m0_cast") = false,
py::arg("bias") = std::nullopt);
m.def("fp8_gemm_tt", &fp8_gemm_tt,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
py::arg("compiled_dims") = "mn",
py::arg("disable_ue8m0_cast") = false);
py::arg("disable_ue8m0_cast") = false,
py::arg("bias") = std::nullopt);
m.def("m_grouped_fp8_gemm_nt_contiguous", &m_grouped_fp8_gemm_nt_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);
py::arg("disable_ue8m0_cast") = false,
py::arg("bias") = std::nullopt);
m.def("m_grouped_fp8_gemm_nn_contiguous", &m_grouped_fp8_gemm_nn_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);
py::arg("disable_ue8m0_cast") = false,
py::arg("bias") = std::nullopt);
m.def("m_grouped_fp8_gemm_nt_masked", &m_grouped_fp8_gemm_nt_masked,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
py::arg("expected_m"), py::arg("recipe") = std::nullopt,
py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false);
py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false,
py::arg("bias") = std::nullopt);
m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
Expand Down
37 changes: 26 additions & 11 deletions csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class SM100FP8Gemm1D1DRuntime final: public LaunchRuntime<SM100FP8Gemm1D1DRuntim
public:
struct Args {
int m, n, k, num_groups;
bool with_bias;
const std::string& compiled_dims;
const std::optional<std::string>& epilogue_type;

Expand All @@ -31,6 +32,7 @@ class SM100FP8Gemm1D1DRuntime final: public LaunchRuntime<SM100FP8Gemm1D1DRuntim
CUtensorMap tensor_map_sfa;
CUtensorMap tensor_map_sfb;
CUtensorMap tensor_map_cd;
void* bias_ptr;
};

static std::string generate_impl(const Args& args) {
Expand All @@ -51,7 +53,7 @@ static void __instantiate_kernel() {{
{}, {},
{},
{}, {}, {},
{}
{}, {}
>);
}};
)",
Expand All @@ -65,7 +67,7 @@ static void __instantiate_kernel() {{
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms,
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype),
get_default_epilogue_type(args.epilogue_type));
get_default_epilogue_type(args.epilogue_type), args.with_bias);
}

static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
Expand All @@ -74,7 +76,7 @@ static void __instantiate_kernel() {{
args.grouped_layout, args.m, args.n, args.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_sfa, args.tensor_map_sfb,
args.tensor_map_cd));
args.tensor_map_cd, args.bias_ptr));
}
};

Expand All @@ -85,7 +87,8 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims,
const std::optional<std::string>& epilogue_type = std::nullopt) {
const std::optional<std::string>& epilogue_type = std::nullopt,
const std::optional<torch::Tensor>& bias = std::nullopt) {
const auto& aligned_k = align(k, 128);
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::Normal, KernelType::Kernel1D1D,
Expand Down Expand Up @@ -118,6 +121,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
const SM100FP8Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = aligned_k,
.num_groups = 1,
.with_bias = bias.has_value(),
.compiled_dims = compiled_dims,
.epilogue_type = epilogue_type,
.gemm_config = config,
Expand All @@ -129,7 +133,8 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_cd = tensor_map_cd
.tensor_map_cd = tensor_map_cd,
.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code);
Expand All @@ -142,7 +147,8 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con
const torch::Tensor& m_indices,
const int& num_groups, const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const std::string& compiled_dims,
const std::optional<torch::Tensor>& bias = std::nullopt) {
const auto& aligned_k = align(k, 128);
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::MGroupedContiguous, KernelType::Kernel1D1D,
Expand Down Expand Up @@ -176,6 +182,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con
const SM100FP8Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = aligned_k,
.num_groups = num_groups,
.with_bias = bias.has_value(),
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
Expand All @@ -187,7 +194,8 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_cd = tensor_map_cd
.tensor_map_cd = tensor_map_cd,
.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d1d", code);
Expand All @@ -201,7 +209,8 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t
const int& num_groups, const int& m, const int& n, const int& k,
const int& expected_m,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const std::string& compiled_dims,
const std::optional<torch::Tensor>& bias = std::nullopt) {
const auto& aligned_k = align(k, 128);
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::MGroupedMasked, KernelType::Kernel1D1D,
Expand Down Expand Up @@ -234,6 +243,7 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t
const SM100FP8Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = aligned_k,
.num_groups = num_groups,
.with_bias = bias.has_value(),
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
Expand All @@ -245,7 +255,8 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_cd = tensor_map_cd
.tensor_map_cd = tensor_map_cd,
.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d1d", code);
Expand Down Expand Up @@ -302,6 +313,7 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor&
const SM100FP8Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = sum_k,
.num_groups = num_groups,
.with_bias = false,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
Expand All @@ -313,7 +325,8 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor&
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_cd = tensor_map_cd
.tensor_map_cd = tensor_map_cd,
.bias_ptr = nullptr
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_k_grouped_gemm_1d1d", code);
Expand Down Expand Up @@ -367,6 +380,7 @@ static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
const SM100FP8Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = k,
.num_groups = batch_size,
.with_bias = false,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
Expand All @@ -378,7 +392,8 @@ static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_cd = tensor_map_cd
.tensor_map_cd = tensor_map_cd,
.bias_ptr = nullptr
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code);
Expand Down
Loading