Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 56 additions & 50 deletions ggml/src/ggml-cuda/fattn-mma-f16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1366,77 +1366,83 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
}

if (np > 1 && threadIdx.y % np == 0) {
// Combine the meta data for parallel warps via shared memory.
// Warps with threadIdx.y % np != 0 must NOT return early.
// All threads must return simultaneously to avoid race conditions with work on the next tile.

if (np > 1) {
constexpr int nmeta = np*cols_per_warp >= warp_size ? np*cols_per_warp/warp_size : 1;

const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
float2 meta[nmeta];
float KQ_cmn = 0.0f;
float KQ_crs = 0.0f;
float KQ_cms[nmeta] = {0.0f};

if (threadIdx.y % np == 0) {
// Combine the meta data for parallel warps via shared memory.
// Warps with threadIdx.y % np != 0 must NOT return early.
// All threads must return simultaneously to avoid race conditions with work on the next tile.

const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
float2 meta[nmeta];
#pragma unroll
for (int imeta = 0; imeta < nmeta; ++imeta) {
meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2];
}
for (int imeta = 0; imeta < nmeta; ++imeta) {
meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2];
}

float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
#pragma unroll
for (int imeta = 1; imeta < nmeta; ++imeta) {
KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x);
}
for (int imeta = 1; imeta < nmeta; ++imeta) {
KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x);
}
#pragma unroll
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
if (offset < warp_size) {
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size));
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
if (offset < warp_size) {
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size));
}
}
}

float KQ_cms[nmeta]; // KQ combine max scale per warp.
#pragma unroll
for (int imeta = 0; imeta < nmeta; ++imeta) {
KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn);
}
for (int imeta = 0; imeta < nmeta; ++imeta) {
KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn);
}

float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps.
KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps.
#pragma unroll
for (int imeta = 1; imeta < nmeta; ++imeta) {
KQ_crs += KQ_cms[imeta]*meta[imeta].y;
}
for (int imeta = 1; imeta < nmeta; ++imeta) {
KQ_crs += KQ_cms[imeta]*meta[imeta].y;
}
#pragma unroll
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
if (offset < warp_size) {
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size);
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
if (offset < warp_size) {
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size);
}
}
}

// All warps must hit the same barrier (not divergent if/else branches).
__syncthreads();

// Write back combined meta data:
if (threadIdx.y % np == 0) {
const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;

// Write back combined meta data:
#pragma unroll
for (int imeta = 0; imeta < nmeta; ++imeta) {
if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) {
// Combined KQ max scale + rowsum.
meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
for (int imeta = 0; imeta < nmeta; ++imeta) {
if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) {
// Combined KQ max scale + rowsum.
meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
}
}
}

// Combined KQ max + rowsum.
static_assert(cols_per_warp <= warp_size);
if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
}
if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
// Combined KQ max + rowsum.
static_assert(cols_per_warp <= warp_size);
if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
}
if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
}
}
} else if (np > 1) {
// Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch.
// Therefore, all other warps also need to execute a __syncthreads().
// Otherwise the points at which warps synchronize with each other would become misaligned.
__syncthreads();
}

#pragma unroll
Expand Down
30 changes: 22 additions & 8 deletions ggml/src/ggml-cuda/mmq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,26 @@
#include "quantize.cuh"
#include "mmid.cuh"

// stream-K is faster on Volta+ but unsafe for some batch shapes:
// - split-mode row uses parallel CUDA streams with src1_ncols != ne11 (pool fixup race)
// - llama-server --parallel packs multiple slots into one ubatch (ne12/ne13 > 1), which
// mis-partitions stream-K tiles and crashes the Q2_0 MMQ path (Turing+; see patches/).
static bool ggml_cuda_mmq_use_stream_k(const int cc, const int64_t ne12, const int64_t ne13,
const int64_t src1_ncols, const int64_t ne11) {
const bool arch_ok = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| GGML_CUDA_CC_IS_CDNA(cc);
if (!arch_ok) {
return false;
}
if (ne12 > 1 || ne13 > 1) {
return false;
}
if (src1_ncols != ne11) {
return false;
}
return true;
}

static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
switch (args.type_x) {
case GGML_TYPE_Q1_0:
Expand Down Expand Up @@ -121,8 +141,7 @@ void ggml_cuda_mul_mat_q(
const int64_t s03 = src0->nb[3] / ts_src0;
const int64_t s3 = dst->nb[3] / ts_dst;

const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| GGML_CUDA_CC_IS_CDNA(cc);
const bool use_stream_k = ggml_cuda_mmq_use_stream_k(cc, ne12, ne13, ne11, ne11);

// TODO: tighter pool buffer size vs q8 path
const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
Expand Down Expand Up @@ -250,12 +269,7 @@ void ggml_cuda_op_mul_mat_q(
// nrows_dst == nrows of the matrix that the kernel writes into
const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;

// The stream-k decomposition is only faster for recent NVIDIA GPUs.
// Also its fixup needs to allocate a temporary buffer in the memory pool.
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| GGML_CUDA_CC_IS_CDNA(cc))
&& src1_ncols == ne11;
const bool use_stream_k = ggml_cuda_mmq_use_stream_k(cc, 1, 1, src1_ncols, ne11);
const mmq_args args = {
src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
Expand Down