Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@

#include "run_gemm_quant_example.inc"

// template <typename T>
// using GemmConfig = GemmConfigQuantDecode<T>;

template <typename T>
using GemmConfig = GemmConfigQuantDecode<T>;
using GemmConfig = GemmConfigQuantInterwave<T>;

// GemmConfigQuantPrefill is also supported for aquant grouped quantization
// template <typename T>
Expand Down
36 changes: 36 additions & 0 deletions example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,42 @@ struct GemmConfigQuantDecode : public GemmConfigBase
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
};

template <typename PrecType>
struct GemmConfigQuantIntrawave : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);

static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;

static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
};

template <typename PrecType>
struct GemmConfigQuantInterwave : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);

static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;

static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();

static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};

template <typename PrecType>
struct GemmConfigRowColQuant : public GemmConfigBase
{
Expand Down
10 changes: 5 additions & 5 deletions example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
std::mt19937 gen(rd());
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);

if(init_method == 0)
if(init_method == 0) // uniform distribution
{
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
Expand Down Expand Up @@ -594,7 +594,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
}
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
ck_tile::FillUniformDistribution<AQDataType>{1.0f, 1.0f, fill_seed(gen)}(
*aq_tensor_ptr);
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
}
Expand Down Expand Up @@ -627,12 +627,12 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
*bq_tensor_ptr);
}
}
else if(init_method == 1)
else if(init_method == 1) // monotonic initialization
{
std::cout << "Monotonic initialization is not supported." << std::endl;
return 0;
}
else if(init_method == 2)
else if(init_method == 2) // constant initialization
{
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
Expand All @@ -650,7 +650,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
else
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(*aq_tensor_ptr);
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(1.0f)}(*aq_tensor_ptr);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);

if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ struct AQuantBlockUniversalGemmAsBsCr
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
CWarpTensor c_warp_tensor;

// for every column in AQ
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
Expand Down Expand Up @@ -322,6 +323,180 @@ struct AQuantBlockUniversalGemmAsBsCr
}
};

template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Interwave, GemmTraits>
{
static constexpr index_t KPerThread = GemmTraits::KPerThread;
static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters;
// static constexpr index_t QuantGroupSizeK = GemmTraits::QuantGroupSize::kK;

// Match the base Interwave loop structure; quantization handling will be reintroduced
// later.
static constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread;

// static constexpr index_t KIterPerQScale = GemmTraits::KIterPerQScale;

static constexpr auto ALdsTileDistr =
make_static_tile_distribution(MakeABlockDistributionEncode());
static constexpr auto BLdsTileDistr =
make_static_tile_distribution(MakeBBlockDistributionEncode());

using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));

ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;

template <index_t KIdx,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
constexpr auto a_lds_load_distr = [&]() {
if constexpr(ALoadTranspose)
return make_static_tile_distribution(typename InputTileDistributionTraits<
decltype(MakeABlockDistributionEncode()),
ADataType>::TransposedDstrEncode{});
else
return make_static_tile_distribution(MakeABlockDistributionEncode());
}();
constexpr auto b_lds_load_distr = [&]() {
if constexpr(BLoadTranspose)
return make_static_tile_distribution(typename InputTileDistributionTraits<
decltype(MakeBBlockDistributionEncode()),
BDataType>::TransposedDstrEncode{});
else
return make_static_tile_distribution(MakeBBlockDistributionEncode());
}();
constexpr auto a_lds_shape = []() {
if constexpr(ALoadTranspose)
return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::MPerBlock>{});
else
return make_tuple(number<GemmTraits::MPerBlock>{}, number<KPerInnerLoop>{});
}();
constexpr auto b_lds_shape = []() {
if constexpr(BLoadTranspose)
return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::NPerBlock>{});
else
return make_tuple(number<GemmTraits::NPerBlock>{}, number<KPerInnerLoop>{});
}();
constexpr auto k_idx_offset = KIdx * KPerInnerLoop;
constexpr auto a_offset =
ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
constexpr auto b_offset =
BLoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};

auto a_lds_gemm_window = make_tile_window(
a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, a_lds_load_distr);
auto b_lds_gemm_window = make_tile_window(
b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);

load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_lds_gemm_window);
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_lds_gemm_window);
}

// C += A * B (quantization/scaling paths are intentionally commented for now)
template <typename CBlockTensor,
typename AQBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
[[maybe_unused]] AQBlockTensor& aq_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as corresponding "
"C block tensor data type!");
// constexpr auto warp_size = get_warp_size();

static_for<0, KRepeat, 1>{}([&](auto kIter) {
if(get_thread_id() == 0 && get_block_id() == 0)
{
printf("kIter: %d\n", kIter.value);
}
LocalPrefetch<kIter.value>(a_block_window, b_block_window, a_load_tr, b_load_tr);
__builtin_amdgcn_sched_barrier(0);

if constexpr(kIter.value != 0 || KRepeat == 1)
{
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
}

static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kInnerIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() =
b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kInnerIter>{},
b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() =
c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

// Quantization accumulation path (AQPicker, qscale application) to be
// re-enabled later.

if constexpr(kIter.value == KRepeat - 1 &&
kInnerIter.value == KInnerLoopIter - 1 &&
mIter.value == MIterPerWarp - 1 &&
nIter.value == NIterPerWarp - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}

WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());

if constexpr(kInnerIter.value == 0 && mIter.value == 0 &&
nIter.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
});

__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_sched_barrier(0);
});
}
};

public:
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
Expand Down
Loading