diff --git a/.github/workflows/aiter-release.yaml b/.github/workflows/aiter-release.yaml index 1807fbf440..0255edb264 100644 --- a/.github/workflows/aiter-release.yaml +++ b/.github/workflows/aiter-release.yaml @@ -67,6 +67,16 @@ on: description: 'Use pytorch/manylinux2_28-builder ROCm image (AlmaLinux 8 + devtoolset, glibc 2.28). Produces wheels ABI-compatible with vLLM/Ubuntu 22 containers.' type: boolean default: false + torch_pin: + description: 'Optional torch version pin for the manylinux build (e.g. 2.10.0+rocm7.1). Empty = latest available for the detected ROCm flavor.' + type: string + required: false + default: '' + torch_index_url: + description: 'Optional override for the torch wheel index URL. Empty = auto-derive from the manylinux builder image tag (https://download.pytorch.org/whl/rocmX.Y).' + type: string + required: false + default: '' workflow_call: inputs: release_type: @@ -111,6 +121,16 @@ on: type: boolean required: false default: false + torch_pin: + description: 'Optional torch version pin for the manylinux build (e.g. 2.10.0+rocm7.1). Empty = latest.' + type: string + required: false + default: '' + torch_index_url: + description: 'Optional torch index URL override. Empty = auto-derive from builder image tag.' 
+ type: string + required: false + default: '' outputs: wheel_names: description: 'Space-separated list of built wheel filenames' @@ -145,6 +165,8 @@ jobs: RELEASE_TYPE: ${{ inputs.release_type || github.event.inputs.release_type }} ADD_DATE_STAMP: ${{ inputs.add_date_stamp || github.event.inputs.add_date_stamp }} USE_MANYLINUX: ${{ inputs.use_manylinux || github.event.inputs.use_manylinux || (startsWith(matrix.docker_image, 'pytorch/manylinux') && 'true') || 'false' }} + TORCH_PIN: ${{ inputs.torch_pin || github.event.inputs.torch_pin }} + TORCH_INDEX_URL: ${{ inputs.torch_index_url || github.event.inputs.torch_index_url }} steps: - name: Checkout aiter repo @@ -301,17 +323,32 @@ jobs: IMG="${BUILD_DOCKER_IMAGE}" ROCM_TAG="${IMG##*:}" # rocm7.2 / rocm7.1 / rocm7.0 ROCM_NUM="${ROCM_TAG#rocm}" # 7.2 - TORCH_INDEX="https://download.pytorch.org/whl/rocm${ROCM_NUM}" + # Allow caller to override the torch wheel index (e.g. pin a release + # to a specific ROCm flavor's PyTorch ABI). Defaults preserve the + # legacy auto-derived behavior. + if [ -n "${TORCH_INDEX_URL}" ]; then + TORCH_INDEX="${TORCH_INDEX_URL}" + else + TORCH_INDEX="https://download.pytorch.org/whl/rocm${ROCM_NUM}" + fi + # Optional torch version pin (e.g. 2.10.0+rocm7.1). Empty = latest. + if [ -n "${TORCH_PIN}" ]; then + TORCH_SPEC="torch==${TORCH_PIN}" + else + TORCH_SPEC="torch" + fi echo "Torch index: ${TORCH_INDEX}" + echo "Torch spec: ${TORCH_SPEC}" docker exec \ -w /workspace \ -e PYBIN="${PYBIN}" \ -e TORCH_INDEX="${TORCH_INDEX}" \ + -e TORCH_SPEC="${TORCH_SPEC}" \ aiter_build_${{ matrix.python_version }} \ bash -c ' set -e ${PYBIN}/pip install --upgrade --timeout=60 --retries=10 pip - ${PYBIN}/pip install --timeout=60 --retries=10 --index-url "${TORCH_INDEX}" torch + ${PYBIN}/pip install --timeout=60 --retries=10 --index-url "${TORCH_INDEX}" "${TORCH_SPEC}" # flydsl publishes only manylinux_2_35 wheels which cannot install # on AlmaLinux 8 (glibc 2.28). 
FlyDSL AOT pre-compilation in # setup.py is wrapped in try/except and is skipped gracefully when diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel index 5348b577ed..fdf4bb7fcc 160000 --- a/3rdparty/composable_kernel +++ b/3rdparty/composable_kernel @@ -1 +1 @@ -Subproject commit 5348b577ed7a5d88d350d88fd720b882176466ae +Subproject commit fdf4bb7fcc984811cef48ce817d89aac064b984a diff --git a/aiter/configs/model_configs/gptoss_bf16_tuned_gemm.csv b/aiter/configs/model_configs/gptoss_bf16_tuned_gemm.csv index 5587bed4c7..1deaa684f9 100644 --- a/aiter/configs/model_configs/gptoss_bf16_tuned_gemm.csv +++ b/aiter/configs/model_configs/gptoss_bf16_tuned_gemm.csv @@ -1,58 +1,58 @@ gfx,cu_num,M,N,K,bias,dtype,outdtype,scaleAB,bpreshuffle,libtype,solidx,splitK,us,kernelName,err_ratio,tflops,bw -gfx950,256,1,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,9,4.6262,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0078,0.16,160.67 -gfx950,256,2,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,9,4.6019,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0156,0.32,162.83 -gfx950,256,4,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,15,4.7202,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0137,0.62,161.29 -gfx950,256,8,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,15,4.7061,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0273,1.25,166.89 -gfx950,256,16,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,15,4.7651,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0273,2.48,174.93 -gfx950,256,32,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,42,15,4.6463,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k15_block_m_warp2_block_n_warp2_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0237,5.08,200.11 
-gfx950,256,48,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,39,15,4.7651,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k15_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0262,7.43,215.33 -gfx950,256,64,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,188,15,4.7406,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k15_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0243,9.95,236.74 -gfx950,256,80,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,239,15,4.8623,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k15_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0253,12.13,250.61 -gfx950,256,96,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,242,15,5.3055,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k15_block_m_warp2_block_n_warp2_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0247,13.34,247.82 -gfx950,256,112,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,240,15,5.2015,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k15_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0246,15.88,271.28 -gfx950,256,128,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,205,9,5.6343,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0177,16.75,267.53 -gfx950,256,256,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,370,9,6.5478,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k9_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.018,28.83,347.81 
-gfx950,256,1,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,243,6,8.5347,flydsl_gemm2_abf16_wbf16_bf16_t16x64x160_split_k6_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0141,1.73,1729.0 -gfx950,256,2,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,243,6,7.904,flydsl_gemm2_abf16_wbf16_bf16_t16x64x160_split_k6_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0152,3.73,1868.34 -gfx950,256,4,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,245,15,7.9144,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k15_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0266,7.45,1868.63 -gfx950,256,8,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,245,15,8.2005,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k15_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0251,14.39,1808.75 -gfx950,256,16,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,268,9,8.614,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0182,27.39,1732.03 -gfx950,256,32,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,10.0478,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0099,46.96,1502.2 -gfx950,256,48,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,10.5933,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0073,66.81,1441.27 -gfx950,256,64,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,10.7011,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0074,88.19,1443.02 -gfx950,256,128,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,11.7861,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0074,160.14,1369.26 
-gfx950,256,256,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,2,15.3912,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0041,245.26,1139.02 -gfx950,256,1,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,139,16,6.0968,flydsl_gemm2_abf16_wbf16_bf16_t16x192x128_split_k16_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe4_ur16_gfx950,0.0212,1.93,1936.48 -gfx950,256,1,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,388,16,8.3638,flydsl_gemm2_abf16_wbf16_bf16_t16x192x64_split_k16_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe2_ur8_gfx950,0.0316,2.82,2822.51 -gfx950,256,2,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,138,16,6.0853,flydsl_gemm2_abf16_wbf16_bf16_t16x192x128_split_k16_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe2_ur8_gfx950,0.0224,3.88,1941.76 -gfx950,256,2,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,139,16,8.2195,flydsl_gemm2_abf16_wbf16_bf16_t16x192x128_split_k16_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe4_ur16_gfx950,0.0269,5.74,2873.76 -gfx950,256,4,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,140,16,6.1308,flydsl_gemm2_abf16_wbf16_bf16_t16x192x128_split_k16_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe4_ur8_gfx950,0.0224,7.7,1930.56 -gfx950,256,4,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,138,16,8.2605,flydsl_gemm2_abf16_wbf16_bf16_t16x192x128_split_k16_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe2_ur8_gfx950,0.0292,11.42,2862.87 
-gfx950,256,8,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,771,8,6.726,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k8_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe4_ur8_gfx950,0.0161,14.03,1765.59 -gfx950,256,8,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,792,8,8.6182,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k8_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_ur16_gfx950,0.0195,21.9,2750.53 -gfx950,256,16,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,750,4,6.9944,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k4_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0103,26.98,1709.11 -gfx950,256,16,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,798,8,9.0841,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k8_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0205,41.55,2621.74 -gfx950,256,32,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,145,4,8.0743,flydsl_gemm2_abf16_wbf16_bf16_t32x64x256_split_k4_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0104,46.75,1500.05 -gfx950,256,32,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,164,4,10.4391,flydsl_gemm2_abf16_wbf16_bf16_t32x64x256_split_k4_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0132,72.32,2302.83 -gfx950,256,48,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,353,4,9.5624,flydsl_gemm2_abf16_wbf16_bf16_t64x64x256_split_k4_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0103,59.21,1283.11 
-gfx950,256,48,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,5,12.1652,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0136,93.09,1994.43 -gfx950,256,64,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,10.045,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0066,75.16,1237.16 -gfx950,256,64,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,5,12.6691,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0136,119.18,1932.73 -gfx950,256,80,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,5,12.8662,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0137,146.7,1920.47 -gfx950,256,96,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,5,13.1474,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0136,172.27,1896.37 -gfx950,256,112,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,17.299,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0048,152.75,1454.16 -gfx950,256,128,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,11.5282,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0035,130.98,1132.7 -gfx950,256,128,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,17.0438,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0048,177.18,1489.04 -gfx950,256,256,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,6,1,15.42,_ZN5aiter37bf16gemm_fp32bf16_tn_64x64_pf3_splitkE,0.0,195.84,928.64 -gfx950,256,1,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,247,9,10.4898,flydsl_gemm2_abf16_wbf16_bf16_t16x64x160_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0182,2.81,2812.94 -gfx950,256,2,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,247,9,9.6817,flydsl_gemm2_abf16_wbf16_bf16_t16x64x160_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0194,6.09,3049.38 
-gfx950,256,4,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,271,9,10.2067,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0201,11.56,2895.67 -gfx950,256,8,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,271,9,11.2498,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0189,20.97,2632.86 -gfx950,256,16,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,271,9,11.9689,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0191,39.42,2485.37 -gfx950,256,32,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,12.2241,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0077,77.2,2454.43 -gfx950,256,48,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,3,12.9649,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0076,109.19,2333.93 -gfx950,256,64,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,13.6686,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0077,138.09,2232.5 -gfx950,256,80,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,3,13.957,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0076,169.04,2204.71 -gfx950,256,96,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,3,14.3466,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0076,197.34,2162.69 -gfx950,256,112,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,4,1,19.3332,_ZN5aiter37bf16gemm_fp32bf16_tn_48x64_pf3_splitkE,0.0,170.85,1618.11 -gfx950,256,128,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,4,1,19.5488,_ZN5aiter37bf16gemm_fp32bf16_tn_48x64_pf3_splitkE,0.0,193.1,1613.36 
+gfx950,256,1,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,14,4.9558,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0234,0.15,149.99 +gfx950,256,2,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,13,4.9466,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0195,0.3,151.48 +gfx950,256,4,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,14,4.9687,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0156,0.59,153.23 +gfx950,256,8,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,14,4.9927,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0176,1.18,157.31 +gfx950,256,16,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,13,5.031,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0171,2.34,165.68 +gfx950,256,32,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,13,4.6354,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0203,5.09,200.59 +gfx950,256,48,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,14,5.2547,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0212,6.73,195.26 +gfx950,256,64,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,13,5.3561,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.021,8.81,209.54 +gfx950,256,80,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,13,5.6419,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0218,10.45,215.98 +gfx950,256,96,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,9,5.7166,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0163,12.38,230.0 +gfx950,256,112,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,9,5.9183,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0165,13.95,238.43 +gfx950,256,128,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,9,5.97,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0172,15.81,252.48 
+gfx950,256,256,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,6,7.0187,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0126,26.89,324.47 +gfx950,256,1,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.6772,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0105,1.52,1524.87 +gfx950,256,2,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.8371,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0107,3.0,1501.19 +gfx950,256,4,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.8551,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0118,5.98,1500.66 +gfx950,256,8,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.9035,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0101,11.91,1497.72 +gfx950,256,16,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,10.0897,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.01,23.38,1478.7 +gfx950,256,32,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,10.2307,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.01,46.12,1475.34 +gfx950,256,48,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,10.7334,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0073,65.94,1422.46 +gfx950,256,64,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,10.6753,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0074,88.4,1446.51 +gfx950,256,128,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,11.6504,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0074,162.01,1385.21 +gfx950,256,256,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,2,15.3742,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0041,245.53,1140.28 +gfx950,256,1,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.0917,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0083,1.3,1298.58 
+gfx950,256,1,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,10.0871,auto,0.0,2.34,2340.31 +gfx950,256,2,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,9.2607,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0056,2.55,1275.95 +gfx950,256,2,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,9.6663,auto,0.0,4.88,2443.63 +gfx950,256,4,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.1797,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0091,5.14,1289.36 +gfx950,256,4,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,10.1637,auto,0.0,9.29,2326.79 +gfx950,256,8,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.2315,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0092,10.22,1286.39 +gfx950,256,8,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,5,11.4653,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0132,16.46,2067.51 +gfx950,256,16,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,9.3253,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0062,20.24,1281.91 +gfx950,256,16,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,5,11.591,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0136,32.57,2054.71 +gfx950,256,32,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,9.2671,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0065,40.73,1306.98 +gfx950,256,32,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,5,11.7821,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0137,64.08,2040.33 +gfx950,256,48,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,4,9.8145,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.009,57.69,1250.15 +gfx950,256,48,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,5,12.1519,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0136,93.19,1996.61 
+gfx950,256,64,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,4,10.1075,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.009,74.69,1229.51 +gfx950,256,64,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,5,12.6689,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0137,119.19,1932.76 +gfx950,256,80,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,5,13.1447,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0137,143.59,1879.78 +gfx950,256,96,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,5,13.5787,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0136,166.8,1836.14 +gfx950,256,112,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,17.4374,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0048,151.54,1442.62 +gfx950,256,128,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,11.3635,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0035,132.88,1149.12 +gfx950,256,128,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,17.3233,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0048,174.33,1465.01 +gfx950,256,256,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,6,1,15.5453,_ZN5aiter37bf16gemm_fp32bf16_tn_64x64_pf3_splitkE,0.0,194.26,921.15 +gfx950,256,1,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,11.5513,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0078,2.55,2554.45 +gfx950,256,2,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,11.4061,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0075,5.17,2588.37 +gfx950,256,4,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,12.0536,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0075,9.79,2451.98 +gfx950,256,8,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,11.96,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0077,19.73,2476.52 
+gfx950,256,16,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,12.2044,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0077,38.66,2437.42 +gfx950,256,32,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,12.3457,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0078,76.44,2430.26 +gfx950,256,48,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,3,13.0203,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0076,108.72,2324.0 +gfx950,256,64,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,13.4435,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0077,140.4,2269.89 +gfx950,256,80,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,3,14.2374,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0076,165.71,2161.29 +gfx950,256,96,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,3,14.4747,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0076,195.59,2143.55 +gfx950,256,112,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,4,1,19.1801,_ZN5aiter37bf16gemm_fp32bf16_tn_48x64_pf3_splitkE,0.0,172.21,1631.02 +gfx950,256,128,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,4,1,19.3026,_ZN5aiter37bf16gemm_fp32bf16_tn_48x64_pf3_splitkE,0.0,195.56,1633.94 diff --git a/csrc/cpp_itfs/mha_fwd_batch_prefill.cu b/csrc/cpp_itfs/mha_fwd_batch_prefill.cu index 2c5da43ef2..7994e7b2d9 100644 --- a/csrc/cpp_itfs/mha_fwd_batch_prefill.cu +++ b/csrc/cpp_itfs/mha_fwd_batch_prefill.cu @@ -47,7 +47,13 @@ float mha_batch_prefill(mha_batch_prefill_args args, int head_size_q = args.hdim_q; int head_size_v = args.hdim_v; bool has_dropout = args.p_drop > 0.f; - auto traits = get_mha_batch_prefill_traits(head_size_q, + + // The kUseGlobalLoad decision (>2GB KV cache → use `global_load_lds_*` + // instead of SRD `buffer_load_*`) is made per-arm inside the auto-generated + // dispatcher in fmha_batch_prefill_api.cpp, where each arm knows its own + // compile-time bn0 and dtype 
element size. The wrapper just forwards args; + // no runtime trait field for it. + auto traits = get_mha_batch_prefill_traits(head_size_q, head_size_v, q_dtype_str, is_group_mode, diff --git a/csrc/kernels/topk_per_row_kernels.cu b/csrc/kernels/topk_per_row_kernels.cu index 89331c52df..6edf377ca8 100644 --- a/csrc/kernels/topk_per_row_kernels.cu +++ b/csrc/kernels/topk_per_row_kernels.cu @@ -1435,6 +1435,18 @@ __global__ void radix_topk_one_block_kernel(T const* in, return; } + // Long-row path: kernel internally treats in[0..row_len) as the valid + // window. Shift `in` (and `in_idx`) up by `rowStart` so that the radix + // pipeline reads the actual valid columns rather than the masked-out + // [0, rowStart) prefix that fp8_mqa_logits fills with -inf. Internal + // indices i are then relative to rowStart; we add rowStart back to + // out_idx at the end of this branch to get absolute column indices. + in += rowStart; + if(in_idx) + { + in_idx += rowStart; + } + const IdxT buf_len = calc_buf_len(len); bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); @@ -1522,6 +1534,23 @@ __global__ void radix_topk_one_block_kernel(T const* in, break; } } + + // Long-row path was using rowStart-relative indices inside the radix + // pipeline (because we shifted `in` by rowStart above). Translate them + // back to absolute column indices for downstream consumers. Sentinels + // (-1, written when fewer than k valid candidates exist) are preserved. 
+ if(rowStart > 0) + { + __syncthreads(); + for(int i = threadIdx.x; i < k; i += BlockSize) + { + IdxT v = out_idx[i]; + if(v >= 0) + { + out_idx[i] = v + rowStart; + } + } + } } inline size_t calc_aligned_size(std::vector const& sizes) diff --git a/csrc/py_itfs_ck/mha_batch_prefill_kernels.cu b/csrc/py_itfs_ck/mha_batch_prefill_kernels.cu index cd8dd1c531..15a4878ed9 100644 --- a/csrc/py_itfs_ck/mha_batch_prefill_kernels.cu +++ b/csrc/py_itfs_ck/mha_batch_prefill_kernels.cu @@ -817,7 +817,13 @@ mha_batch_prefill(at::Tensor& q, // [total_q, hq, d] has_lse, qscale_type, false); - TORCH_CHECK(t >= 0, "invalid argument for batch_prefill"); + TORCH_CHECK(t >= 0, + "invalid argument for batch_prefill: no matching kernel found. " + "page_size=", args.page_block_size, + ", num_pages=", args.num_total_pages, + ", dtype=", dtype_str, + ". If KV cache exceeds 2GB (INT32_MAX byte offset) with page_size < kN0, " + "CDNA3+ GPU (MI300/MI350) is required."); } else { diff --git a/op_tests/test_batch_prefill.py b/op_tests/test_batch_prefill.py index ab99206988..ee5489fec4 100644 --- a/op_tests/test_batch_prefill.py +++ b/op_tests/test_batch_prefill.py @@ -1705,133 +1705,319 @@ def reference_attention_kv_blockscale( return output.to(torch.bfloat16) -@pytest.mark.parametrize( - "num_blocks,page_size", - [ - (5000, 1024), # ~10GB KV cache - (10000, 1024), # ~20GB KV cache - ], -) -@pytest.mark.parametrize("num_kv_heads", [8]) +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("kv_cache_size_gb", [4.5]) +@pytest.mark.parametrize("page_size", [1, 16, 1024]) +@pytest.mark.parametrize("num_qo_heads,num_kv_heads", [(8, 8), (16, 8)]) @pytest.mark.parametrize("head_dim", [128]) -@pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("input_dtype", ["bf16", "fp8"]) +# scatter_pages=True: adjacent logical tokens map to physically distant pages, +# stress-testing the paged KV cache addressing when pages 
span large physical distances. +@pytest.mark.parametrize("scatter_pages", [False, True]) +@pytest.mark.parametrize("kv_layout", ["linear", "vectorized"]) def test_batch_prefill_large_kvcache( - num_blocks, + batch_size, + kv_cache_size_gb, page_size, + num_qo_heads, num_kv_heads, head_dim, causal, input_dtype, + scatter_pages, + kv_layout, ): """ Test that batch prefill produces correct results with large KV caches - whose element offsets exceed the INT32_MAX boundary (~4GB for bf16). + whose element offsets exceed the INT32_MAX boundary. + + Uses the full KV cache for attention with pages spanning the overflow + boundary, and compares kernel output against SDPA reference. + For page_size < kN0 (128), this validates the per-tile SRD rebase path. + + Args: + batch_size: Number of sequences. >1 partitions the >2GB page pool + across batches, exercising the per-sequence SRD rebase path. + scatter_pages: If True, interleave page indices so adjacent logical + tokens map to physically distant pages (stress-tests rebase). + kv_layout: "linear" or "vectorized" KV cache memory layout. 
""" + # page_size=1 only supports linear layout (3D tensor) + if page_size == 1 and kv_layout == "vectorized": + pytest.skip("page_size=1 does not support vectorized layout") + torch.manual_seed(42) + torch.cuda.empty_cache() is_fp8 = input_dtype == "fp8" dtype = torch.bfloat16 - num_qo_heads = num_kv_heads # MHA (no GQA) for simplicity - stride_per_page = page_size * num_kv_heads * head_dim # elements per page block + # Compute num_blocks from target KV cache size + elem_size = 1 if is_fp8 else 2 # fp8=1 byte, bf16=2 bytes + elements_per_block = page_size * num_kv_heads * head_dim + target_bytes = int(kv_cache_size_gb * 1024**3) + num_blocks = target_bytes // (elements_per_block * elem_size) + + # Verify this config triggers overflow + stride_per_page = elements_per_block max_offset = (num_blocks - 1) * stride_per_page INT32_MAX = 2**31 - 1 - if max_offset <= INT32_MAX: pytest.skip( f"max_offset {max_offset} doesn't exceed INT32_MAX, not an overflow test" ) - # Check available GPU memory -- skip if not enough + # Check available GPU memory free_mem = torch.cuda.mem_get_info()[0] - elem_size = 1 if is_fp8 else 2 # fp8=1 byte, bf16=2 bytes - required_mem = 2 * num_blocks * page_size * num_kv_heads * head_dim * elem_size - if free_mem < required_mem * 1.1: # 10% headroom + # Per-batch page partition: uniform split, remainder absorbed by the last + # sequence to keep all kv_indptr deltas > 0 (zero-length sequences would be + # skipped by the kernel's per-batch dispatch and hide any rebase bug). + blocks_per_seq = [num_blocks // batch_size] * batch_size + blocks_per_seq[-1] += num_blocks % batch_size + kv_lens_per_seq = [bps * page_size for bps in blocks_per_seq] + max_kv_len_per_seq = max(kv_lens_per_seq) + # Causal with attn_mask forces SDPA math backend which materializes + # [H_q, qo_len, kv_len] score + mask tensors. 
Magnitudes empirically chosen: + # non-causal: 1024 -- flash backend, no full score matrix, headroom is large + # causal: 128 -- math backend cliff: 3x [H_q, qo, kv] fp32 buffers must + # fit alongside K/V cache (kv_len up to ~5GB at this scale) + # qo_len is per-batch; total qo tokens = batch_size * qo_len. + qo_len = min(128, max_kv_len_per_seq) if causal else min(1024, max_kv_len_per_seq) + total_qo_len = batch_size * qo_len + # SDPA causal with attn_mask forces math backend: expanded mask + score matrix + # + softmax intermediates, each [1, H_q, qo, kv_per_batch] fp32. ~3x overhead. + # The per-batch SDPA loop allocates one batch's worth at a time (kv_len + # divided by batch_size), then frees before the next iteration. + sdpa_causal_mem = ( + 3 * num_qo_heads * qo_len * max_kv_len_per_seq * 4 if causal else 0 + ) + # GQA expands K/V from H_kv to H_q heads for SDPA reference + gqa_ratio = num_qo_heads // num_kv_heads + # Sequential pages reuse K/V directly; scattered need a gathered copy + gathered_mem = 2 * num_blocks * elements_per_block * 2 if scatter_pages else 0 + required_mem = ( + 2 * num_blocks * elements_per_block * 2 # K/V bf16 + + 2 * num_blocks * elements_per_block * elem_size # kernel K/V (fp8 or bf16) + + gathered_mem + + 2 * num_blocks * elements_per_block * 2 * (gqa_ratio - 1) # GQA K/V expansion + + sdpa_causal_mem + ) + if free_mem < required_mem * 1.1: pytest.skip( f"Not enough GPU memory: need {required_mem / 1e9:.1f}GB, " f"have {free_mem / 1e9:.1f}GB" ) - # Allocate KV caches in linear layout: [num_blocks, page_size, num_kv_heads, head_dim] - k_cache_bf16 = torch.randn( - num_blocks, page_size, num_kv_heads, head_dim, device="cuda", dtype=dtype - ) - v_cache_bf16 = torch.randn( - num_blocks, page_size, num_kv_heads, head_dim, device="cuda", dtype=dtype + # Allocate KV caches in bf16 + # page_size=1 uses 3D linear layout [num_tokens, num_kv_heads, head_dim] + # page_size>1 uses 4D paged layout [num_blocks, page_size, num_kv_heads, head_dim] 
+ if page_size == 1: + kv_shape = (num_blocks, num_kv_heads, head_dim) + else: + kv_shape = (num_blocks, page_size, num_kv_heads, head_dim) + + k_cache_bf16 = torch.randn(*kv_shape, device="cuda", dtype=dtype) + if scatter_pages: + # Use page-dependent V values to detect address wrapping bugs. + # With random V, wrong addresses read statistically similar data -> false pass. + # With V[page] proportional to page_index, wrapped addresses (low pages) give ~0 instead of + # the correct ~1 for high pages, making the error detectable. + page_vals = ( + torch.arange(num_blocks, device="cuda", dtype=torch.float32) / num_blocks + ) + if page_size == 1: + v_cache_bf16 = page_vals.view(-1, 1, 1).expand(*kv_shape).to(dtype) + else: + v_cache_bf16 = page_vals.view(-1, 1, 1, 1).expand(*kv_shape).to(dtype) + else: + v_cache_bf16 = torch.randn(*kv_shape, device="cuda", dtype=dtype) + + # Query: flat [total_qo_len, H_q, D] layout matching mha_batch_prefill_func + # input contract. Per-batch slices recovered via cu_seqlens_q in the loop below. + q_bf16 = torch.randn( + total_qo_len, num_qo_heads, head_dim, device="cuda", dtype=dtype ) - if is_fp8: - k_cache, k_descale = per_tensor_quant(k_cache_bf16, quant_dtype=dtypes.fp8) - v_cache, v_descale = per_tensor_quant(v_cache_bf16, quant_dtype=dtypes.fp8) + # Page indices: since the buffer exceeds INT32_MAX elements, these pages + # naturally span the overflow boundary. + overflow_page = INT32_MAX // stride_per_page + + if scatter_pages: + # Interleave: [0, N-1, 2, N-3, 4, N-5, ...] so adjacent logical tokens + # map to physically distant pages (low <-> high, spanning >2GB gap).
+ lo = torch.arange(0, num_blocks, 2, dtype=torch.int32) + hi = torch.arange(num_blocks - 1, -1, -2, dtype=torch.int32) + page_indices = torch.zeros(num_blocks, dtype=torch.int32) + page_indices[0::2] = lo[: (num_blocks + 1) // 2] + page_indices[1::2] = hi[: num_blocks // 2] else: - k_cache = k_cache_bf16 - v_cache = v_cache_bf16 + # Sequential: [0, 1, 2, ..., N-1] + page_indices = torch.arange(num_blocks, dtype=torch.int32) - # Test pages that span the overflow boundary - qo_len = 1 - kv_len = page_size # one full page + # --- Step 1: Compute SDPA reference FIRST (while bf16 data is alive) --- + # Per-batch loop: each iteration gathers its slice of pages, runs SDPA, + # and frees intermediates before the next batch. Keeps peak memory at + # one batch's worth (vs. materializing the full multi-batch score tensor). + o_ref_list = [] + page_offset = 0 + for b in range(batch_size): + n_blocks_b = blocks_per_seq[b] + page_slice_b = page_indices[page_offset : page_offset + n_blocks_b] + page_offset += n_blocks_b + kv_len_b = kv_lens_per_seq[b] + + # Always gather: even sequential pages need a per-batch slice to keep + # the multi-batch SDPA references aligned with the kernel's per-batch + # SRD rebase. (For batch_size=1 + sequential, this is just an alias + # of the full cache via the index slice.) 
+ if page_size == 1: + k_ref_b = k_cache_bf16[page_slice_b.long()] + v_ref_b = v_cache_bf16[page_slice_b.long()] + else: + k_ref_b = k_cache_bf16[page_slice_b.long()].reshape( + -1, num_kv_heads, head_dim + ) + v_ref_b = v_cache_bf16[page_slice_b.long()].reshape( + -1, num_kv_heads, head_dim + ) - q_bf16 = torch.randn(qo_len, num_qo_heads, head_dim, device="cuda", dtype=dtype) - if is_fp8: - q, q_descale = per_tensor_quant(q_bf16, quant_dtype=dtypes.fp8) - else: - q = q_bf16 - cu_seqlens_q = torch.tensor([0, qo_len], device="cuda", dtype=torch.int32) + q_b = q_bf16[b * qo_len : (b + 1) * qo_len] - # Test at several page indices: before, at, and after the overflow boundary - overflow_page = INT32_MAX // stride_per_page - test_pages = [ - 0, - overflow_page - 1, - overflow_page, - overflow_page + 1, - num_blocks - 1, - ] - test_pages = [p for p in test_pages if 0 <= p < num_blocks] - # Remove duplicates while preserving order - test_pages = list(dict.fromkeys(test_pages)) - - threshold = 0.055 if is_fp8 else 0.01 - - for page_idx in test_pages: - offset = page_idx * stride_per_page - label = "OVERFLOW" if offset > INT32_MAX else "safe" - - kv_indptr = torch.tensor([0, 1], device="cuda", dtype=torch.int32) - kv_page_indices = torch.tensor([page_idx], device="cuda", dtype=torch.int32) - kv_last_page_lens = torch.tensor([page_size], device="cuda", dtype=torch.int32) - - extra_kwargs = {} - if is_fp8: - extra_kwargs = dict( - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale + # SDPA expects [batch, heads, seq, dim] + q_sdpa = q_b.unsqueeze(0).transpose(1, 2) + k_sdpa = k_ref_b.unsqueeze(0).transpose(1, 2) + v_sdpa = v_ref_b.unsqueeze(0).transpose(1, 2) + del k_ref_b, v_ref_b + + # GQA: manual K/V head expansion (see comment in non-multi-batch + # equivalent removed in this commit -- using enable_gqa=True with + # causal attn_mask forces SDPA math backend and OOMs for large kv_len). 
+ if num_qo_heads != num_kv_heads: + ratio = num_qo_heads // num_kv_heads + k_sdpa = k_sdpa.repeat_interleave(ratio, dim=1) + v_sdpa = v_sdpa.repeat_interleave(ratio, dim=1) + + sdpa_kwargs = {} + if causal: + # CK batch prefill causal: Q is at the END of the KV context. + # Q[i] can see K[j] where j <= (kv_len_b - qo_len) + i. + offset = kv_len_b - qo_len + row_idx = torch.arange(qo_len, device="cuda").unsqueeze(1) + col_idx = torch.arange(kv_len_b, device="cuda").unsqueeze(0) + sdpa_kwargs["attn_mask"] = col_idx <= (offset + row_idx) + + o_b = ( + torch.nn.functional.scaled_dot_product_attention( + q_sdpa, k_sdpa, v_sdpa, **sdpa_kwargs ) + .squeeze(0) + .transpose(0, 1) + ) + o_ref_list.append(o_b) + del q_sdpa, k_sdpa, v_sdpa, sdpa_kwargs + torch.cuda.empty_cache() - result = aiter.mha_batch_prefill_func( - q, - k_cache, - v_cache, - cu_seqlens_q, - kv_indptr, - kv_page_indices, - qo_len, - kv_len, - causal=causal, - kv_last_page_lens=kv_last_page_lens, - **extra_kwargs, + o_ref = torch.cat(o_ref_list, dim=0) + del o_ref_list + torch.cuda.empty_cache() + + # --- Step 2: Prepare kernel inputs (quantize for FP8, free bf16 after) --- + if is_fp8: + k_cache_kernel, k_descale = per_tensor_quant( + k_cache_bf16, quant_dtype=dtypes.fp8 + ) + v_cache_kernel, v_descale = per_tensor_quant( + v_cache_bf16, quant_dtype=dtypes.fp8 + ) + q_kernel, q_descale = per_tensor_quant(q_bf16, quant_dtype=dtypes.fp8) + del k_cache_bf16, v_cache_bf16, q_bf16 + torch.cuda.empty_cache() + else: + k_cache_kernel = k_cache_bf16 + v_cache_kernel = v_cache_bf16 + q_kernel = q_bf16 + + # Apply vectorized layout transformation if needed + if kv_layout == "vectorized" and page_size > 1: + kv_vector_size = 16 // k_cache_kernel.element_size() + k_cache_kernel, v_cache_kernel = apply_kv_layout( + k_cache_kernel, + v_cache_kernel, + num_kv_heads, + head_dim, + page_size, + kv_vector_size, + "vectorized", ) - out = result[0] if isinstance(result, (list, tuple)) else result - # Reference: direct 
attention on the original bf16 data - k_page = k_cache_bf16[page_idx] # [page_size, num_kv_heads, head_dim] - v_page = v_cache_bf16[page_idx] - o_ref = ref_masked_attention(q_bf16, k_page, v_page, causal=causal) + # Multi-batch indptrs: cu_seqlens_q is the cumulative qo offset per batch + # (uniform qo_len), kv_indptr is the cumulative page count per batch. + cu_seqlens_q = torch.tensor( + [0] + [(i + 1) * qo_len for i in range(batch_size)], + device="cuda", + dtype=torch.int32, + ) + kv_indptr = torch.tensor( + [0] + list(itertools.accumulate(blocks_per_seq)), + device="cuda", + dtype=torch.int32, + ) + # +256 padding is a batch_prefill ABI requirement: the kernel may speculatively + # read up to 256 entries past the last valid page index (one bn0=256 tile worth) + # before the bounds check kicks in. Padding with 0 keeps reads in-bounds; the + # values are masked out by causal/length logic and never affect the output. + kv_page_indices = torch.nn.functional.pad(page_indices, (0, 256), value=0).to( + "cuda" + ) + kv_last_page_lens = torch.tensor( + [page_size] * batch_size, device="cuda", dtype=torch.int32 + ) + + # --- Step 3: Run CK kernel --- + extra_kwargs = {} + if is_fp8: + extra_kwargs = dict( + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale + ) - max_diff = (out - o_ref).abs().max().item() - assert max_diff < threshold, ( - f"[{input_dtype}] page {page_idx} (offset={offset}, {label}): " - f"max_diff={max_diff} exceeds threshold {threshold}" + result = aiter.mha_batch_prefill_func( + q_kernel, + k_cache_kernel, + v_cache_kernel, + cu_seqlens_q, + kv_indptr, + kv_page_indices, + qo_len, + max_kv_len_per_seq, + causal=causal, + kv_last_page_lens=kv_last_page_lens, + **extra_kwargs, + ) + # Synchronize immediately to catch async GPU faults from CK kernel before + # they cascade. 
Without this sync, an async fault can surface inside the + # next test's torch.cuda.empty_cache() (or any other CUDA call), causing + # the failure to be misattributed to that unrelated test -- and on bad + # faults the cascade can trigger a GPU reset that wipes out subsequent + # test results too. + torch.cuda.synchronize() + out = result[0] if isinstance(result, (list, tuple)) else result + + # Compare kernel output vs SDPA reference + if is_fp8: + verify_fp8_output(out, o_ref, threshold=0.055) + else: + rtol, atol = get_tolerances(dtype) + torch.testing.assert_close( + out, + o_ref, + rtol=rtol, + atol=atol, + msg=lambda msg: ( + f"[{input_dtype}] batch_size={batch_size} " + f"page_size={page_size} num_pages={num_blocks} " + f"(overflow at page {overflow_page}): {msg}" + ), )