diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index b613ae61fb89..9b4548f03677 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1366,77 +1366,83 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } - if (np > 1 && threadIdx.y % np == 0) { - // Combine the meta data for parallel warps via shared memory. - // Warps with threadIdx.y % np != 0 must NOT return early. - // All threads must return simultaneously to avoid race conditions with work on the next tile. - + if (np > 1) { constexpr int nmeta = np*cols_per_warp >= warp_size ? np*cols_per_warp/warp_size : 1; - const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); - float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2; - float2 meta[nmeta]; + float KQ_cmn = 0.0f; + float KQ_crs = 0.0f; + float KQ_cms[nmeta] = {0.0f}; + + if (threadIdx.y % np == 0) { + // Combine the meta data for parallel warps via shared memory. + // Warps with threadIdx.y % np != 0 must NOT return early. + // All threads must return simultaneously to avoid race conditions with work on the next tile. + + const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); + float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2; + float2 meta[nmeta]; #pragma unroll - for (int imeta = 0; imeta < nmeta; ++imeta) { - meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2]; - } + for (int imeta = 0; imeta < nmeta; ++imeta) { + meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2]; + } - float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps. + KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps. 
#pragma unroll - for (int imeta = 1; imeta < nmeta; ++imeta) { - KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x); - } + for (int imeta = 1; imeta < nmeta; ++imeta) { + KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x); + } #pragma unroll - for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { - if (offset < warp_size) { - KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size)); + for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { + if (offset < warp_size) { + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size)); + } } - } - float KQ_cms[nmeta]; // KQ combine max scale per warp. #pragma unroll - for (int imeta = 0; imeta < nmeta; ++imeta) { - KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn); - } + for (int imeta = 0; imeta < nmeta; ++imeta) { + KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn); + } - float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps. + KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps. #pragma unroll - for (int imeta = 1; imeta < nmeta; ++imeta) { - KQ_crs += KQ_cms[imeta]*meta[imeta].y; - } + for (int imeta = 1; imeta < nmeta; ++imeta) { + KQ_crs += KQ_cms[imeta]*meta[imeta].y; + } #pragma unroll - for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { - if (offset < warp_size) { - KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size); + for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { + if (offset < warp_size) { + KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size); + } } } + // All warps must hit the same barrier (not divergent if/else branches). __syncthreads(); - // Write back combined meta data: + if (threadIdx.y % np == 0) { + const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? 
threadIdx.x % (np*cols_per_warp) : threadIdx.x); + float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2; + + // Write back combined meta data: #pragma unroll - for (int imeta = 0; imeta < nmeta; ++imeta) { - if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) { - // Combined KQ max scale + rowsum. - meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs); + for (int imeta = 0; imeta < nmeta; ++imeta) { + if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) { + // Combined KQ max scale + rowsum. + meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs); + } } - } - // Combined KQ max + rowsum. - static_assert(cols_per_warp <= warp_size); - if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) { - float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; - dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); - } - if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) { - float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; - dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + // Combined KQ max + rowsum. + static_assert(cols_per_warp <= warp_size); + if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + } + if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + } } - } else if (np > 1) { - // Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch. 
- // Therefore, all other warps also need to execute a __syncthreads(). - // Otherwise the points at which warps synchronize with each other would become misaligned. - __syncthreads(); } #pragma unroll diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 152d61ae1ec7..3c38c8127d44 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -3,6 +3,26 @@ #include "quantize.cuh" #include "mmid.cuh" +// stream-K is enabled for NVIDIA (highest compiled arch >= Volta) and CDNA, but is unsafe for some batch shapes: +// - src1_ncols != ne11: split-mode row uses parallel CUDA streams, which would race on the fixup's pool buffer +// - ne12 > 1 or ne13 > 1: e.g. llama-server --parallel packing multiple slots into one ubatch, which reportedly +// mis-partitions stream-K tiles and crashes the Q2_0 MMQ path (Turing+; see patches/). +static bool ggml_cuda_mmq_use_stream_k(const int cc, const int64_t ne12, const int64_t ne13, + const int64_t src1_ncols, const int64_t ne11) { + const bool arch_ok = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) + || GGML_CUDA_CC_IS_CDNA(cc); + if (!arch_ok) { + return false; + } + if (ne12 > 1 || ne13 > 1) { + return false; + } + if (src1_ncols != ne11) { + return false; + } + return true; +} + static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { switch (args.type_x) { case GGML_TYPE_Q1_0: @@ -121,8 +141,7 @@ void ggml_cuda_mul_mat_q( const int64_t s03 = src0->nb[3] / ts_src0; const int64_t s3 = dst->nb[3] / ts_dst; - const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) - || GGML_CUDA_CC_IS_CDNA(cc); + const bool use_stream_k = ggml_cuda_mmq_use_stream_k(cc, ne12, ne13, ne11, ne11); // TODO: tighter pool buffer size vs q8 path const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4; @@ -250,12 +269,7 @@ void ggml_cuda_op_mul_mat_q( // nrows_dst == nrows of the matrix that the kernel writes into const 
int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; - // The stream-k decomposition is only faster for recent NVIDIA GPUs. - // Also its fixup needs to allocate a temporary buffer in the memory pool. - // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer. - const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) - || GGML_CUDA_CC_IS_CDNA(cc)) - && src1_ncols == ne11; + const bool use_stream_k = ggml_cuda_mmq_use_stream_k(cc, 1, 1, src1_ncols, ne11); const mmq_args args = { src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,