Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions hopper/flash_fwd_launch_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap);
static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap);
static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS);
static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS);
// Consumer Blackwell parts expose ~100KB of shared memory, so reduce the forward pipeline depth
// to keep the shared memory footprint under the device limit while retaining the SM90 depth on
// H100-class GPUs.
static constexpr int kStages = Arch >= 120 ? 1 : (Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS));
static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS);

using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
Expand Down Expand Up @@ -190,7 +193,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {

dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
dim3 block_dims = AttnKernel::get_block_shape();
int smem_size = AttnKernel::SharedStorageSize;
static constexpr int kSmemSize = AttnKernel::SharedStorageSize;
static_assert(Arch < 120 || kSmemSize <= 101376,
"SM120 forward kernel requires more shared memory than the consumer budget; reduce tile sizes.");
int smem_size = kSmemSize;
int max_threads_per_block = 0;
int smem_limit_optin = 0;
int smem_limit = 0;
Expand Down
76 changes: 68 additions & 8 deletions hopper/tile_size.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,97 @@

#include <tuple>

constexpr int smem_estimate_bytes(int block_m, int block_n, int headdim, int headdim_v, int element_size) {
    // Rough shared-memory model: Q/K/V tiles and accumulators are double-buffered, so charge
    // two copies of a (block_m + block_n) x (headdim + headdim_v) element footprint.
    int const tile_rows = block_m + block_n;
    int const tile_cols = headdim + headdim_v;
    return 2 * tile_rows * tile_cols * element_size;
}

constexpr int clamp_block_n_for_smem(int block_m, int block_n, int headdim, int headdim_v,
                                     int element_size, int smem_limit) {
    // Fast path: the requested tile already fits the shared-memory budget.
    if (smem_estimate_bytes(block_m, block_n, headdim, headdim_v, element_size) <= smem_limit) {
        return block_n;
    }
    // Invert the estimate — solve 2*(m+n)*(d+dv)*es <= limit for n — to get the widest BlockN.
    int const bytes_per_row = 2 * element_size * (headdim + headdim_v);
    int widest = block_n;
    if (bytes_per_row > 0) {
        widest = smem_limit / bytes_per_row - block_m;
    }
    // Floor at 8 and snap down to a multiple of 8 to match block-shape granularity.
    widest = widest < 8 ? 8 : (widest / 8) * 8;
    return widest > 0 ? widest : 8;
}

constexpr std::tuple<int, int> enforce_smem_limit(int block_m, int block_n, int headdim, int headdim_v,
                                                  int element_size, int smem_limit) {
    // Pass 1: shrink BlockN with BlockM held fixed.
    int bn = clamp_block_n_for_smem(block_m, block_n, headdim, headdim_v, element_size, smem_limit);
    int usage = smem_estimate_bytes(block_m, bn, headdim, headdim_v, element_size);

    // Pass 2: still over budget and BlockM is large — drop BlockM to 64 and re-shrink BlockN.
    if (usage > smem_limit && block_m > 64) {
        block_m = 64;
        bn = clamp_block_n_for_smem(block_m, bn, headdim, headdim_v, element_size, smem_limit);
        usage = smem_estimate_bytes(block_m, bn, headdim, headdim_v, element_size);
    }

    // Pass 3: last resort — solve the estimate for BlockM (floor 8, multiple of 8), then
    // give BlockN one final chance to shrink under the new BlockM.
    if (usage > smem_limit) {
        int const bytes_per_row = 2 * element_size * (headdim + headdim_v);
        int bm = bytes_per_row > 0 ? smem_limit / bytes_per_row - bn : block_m;
        if (bm < 8) { bm = 8; }
        bm = (bm / 8) * 8;
        block_m = bm > 0 ? bm : 8;
        bn = clamp_block_n_for_smem(block_m, bn, headdim, headdim_v, element_size, smem_limit);
    }
    return {block_m, bn};
}

// Return {kBlockM, kBlockN, MmaPV_is_RS, IntraWGOverlap}
constexpr std::tuple<int, int, bool, bool> tile_size_fwd_sm90(
int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2,
bool v_colmajor=false, bool paged_kv_non_TMA=false, bool softcap=false) {
constexpr int kSm120ConsumerSmemLimit = 101376;
if (element_size == 2) {
if (headdim <= 64) {
// return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, same_hdim};
// With this workaround in Cutlass 3.8, tile size 192 x 128 got slower for non-causal, idk why
// https://github.com/NVIDIA/cutlass/blob/833f6990e031b48b4cd2fcf55e0849c51ef6bac2/include/cute/container/tuple.hpp#L131
if (headdim_v == 512) {
return {64, 64, false, false};
// Keep the tile narrow to avoid blowing past the consumer shared-memory budget when values are very wide.
auto const [block_m, block_n] = enforce_smem_limit(64, 64, headdim, headdim_v, element_size, kSm120ConsumerSmemLimit);
return {block_m, block_n, false, false};
} else if (headdim_v == 256) {
return {128, 96, true, false};
auto const [block_m, block_n] = enforce_smem_limit(64, 80, headdim, headdim_v, element_size, kSm120ConsumerSmemLimit);
return {block_m, block_n, true, true};
} else {
// Switch to tile size 192 x 192 for now
bool const use_blockN_128 = is_causal || is_local || paged_kv_non_TMA;
return {192, use_blockN_128 ? 128 : 192, use_blockN_128, true};
auto const [block_m, block_n] = enforce_smem_limit(192, use_blockN_128 ? 128 : 192, headdim, headdim_v, element_size, kSm120ConsumerSmemLimit);
return {block_m, block_n, use_blockN_128, true};
}
// Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen
// return {192, is_causal || is_local ? 192 : 176, true, false};
} else if (headdim <= 96) {
return {192, is_local || paged_kv_non_TMA ? 128 : 144, false, true};
// Large value dimensions inflate smem usage even at modest head sizes, so bias toward smaller tiles for dv >= 256.
int const block_n = headdim_v >= 256 ? 96 : (is_local || paged_kv_non_TMA ? 128 : 144);
auto const [block_m, block_n_capped] = enforce_smem_limit(block_n == 96 ? 128 : 192, block_n, headdim, headdim_v, element_size, kSm120ConsumerSmemLimit);
return {block_m, block_n_capped, false, true};
} else if (headdim <= 128) {
bool const use_blockN_128 = is_causal || is_local || paged_kv_non_TMA;
return {128, use_blockN_128 ? 128 : 176, true, true};
// Shared memory on consumer parts tops out at ~100KB, so prefer a BlockM=64 path that stays under that limit while
// keeping BlockN as large as possible for throughput.
int const block_n = paged_kv_non_TMA || is_local ? 80 : (headdim_v <= 128 ? 96 : 80);
auto const [block_m, block_n_capped] = enforce_smem_limit(64, block_n, headdim, headdim_v, element_size, kSm120ConsumerSmemLimit);
return {block_m, block_n_capped, true, true};
// {128, 192, true, false} and {192, 128, false, true} are quite good too
// 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS
} else if (headdim <= 192) {
return {128, paged_kv_non_TMA || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true}; // 128 x 112 hits the limit of smem
// The 128x128 / 128x112 tiles exceed the ~100KB shared memory limit of consumer GPUs (for example, when running on
// devices without the larger H100 shared memory carve‑out). Use smaller tiles for all value dims to guarantee we
// stay below the per-block cap across head dimensions up to 192.
int const block_n = paged_kv_non_TMA || is_local ? 64 : (headdim <= 160 ? 80 : 64);
auto const [block_m, block_n_capped] = enforce_smem_limit(64, block_n, headdim, headdim_v, element_size, kSm120ConsumerSmemLimit);
return {block_m, block_n_capped, true, true};
} else {
return {128, is_local ? 64 : 80, true, true}; // 128 x 80 hits the limit of smem
// For head dims above 192 the shared-memory footprint grows quickly with BlockM, so stick to 64xN tiles even though
// they are smaller than the H100-optimized 128xN shapes. Favor narrower BlockN when value dims are large to stay
// under the ~100KB cap on consumer GPUs.
int const block_n = paged_kv_non_TMA || is_local ? 48 : (headdim <= 256 ? 64 : 48);
auto const [block_m, block_n_capped] = enforce_smem_limit(64, block_n, headdim, headdim_v, element_size, kSm120ConsumerSmemLimit);
return {block_m, block_n_capped, true, true};
}
} else {
if (headdim <= 64) {
Expand Down
103 changes: 103 additions & 0 deletions tests/hopper/test_tile_size_shared_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import ctypes
import itertools
import os
import shutil
import subprocess
import sys
import tempfile
import unittest
from pathlib import Path


REPO_ROOT = Path(__file__).resolve().parents[2]
SMEM_LIMIT_BYTES = 101_376


def _build_tile_size_bridge(tmpdir: Path) -> Path:
    """Compile a tiny extern-C wrapper around tile_size_fwd_sm90 into a shared library.

    Parameters:
        tmpdir: scratch directory that receives the generated source and library.

    Returns:
        Path to the built ``libtile_size_bridge.so`` inside ``tmpdir``.

    Raises:
        unittest.SkipTest: when no ``g++`` is on PATH, so the suite stays
            runnable (skipped, not errored) on machines without a C++ compiler.
        subprocess.CalledProcessError: when compilation fails; compiler output
            is attached to the exception instead of interleaving with pytest's
            own stream.
    """
    # Skip (pytest honors unittest.SkipTest) rather than crash with FileNotFoundError.
    if shutil.which("g++") is None:
        raise unittest.SkipTest("g++ is required to build the tile_size bridge")
    src = tmpdir / "tile_size_bridge.cpp"
    lib = tmpdir / "libtile_size_bridge.so"
    src.write_text(
        r'''
#include <tuple>
#include "hopper/tile_size.h"

extern "C" void tile_size_fwd_sm90_bridge(
    int headdim, int headdim_v, bool is_causal, bool is_local, int element_size,
    bool v_colmajor, bool paged_kv_non_TMA, bool softcap,
    int* block_m, int* block_n) {
    auto result = tile_size_fwd_sm90(headdim, headdim_v, is_causal, is_local,
                                     element_size, v_colmajor, paged_kv_non_TMA, softcap);
    *block_m = std::get<0>(result);
    *block_n = std::get<1>(result);
}
''',
        encoding="utf-8",
    )
    compile_cmd = [
        "g++",
        "-std=c++20",
        "-fPIC",
        "-shared",
        "-O2",
        f"-I{REPO_ROOT}",
        str(src),
        "-o",
        str(lib),
    ]
    # check=True surfaces compile failures; capture_output attaches stderr to the error.
    subprocess.run(compile_cmd, check=True, capture_output=True, text=True)
    return lib


def _load_bridge() -> ctypes.CDLL:
    """Build the bridge library in a scratch directory and load it with ctypes.

    The shared object is loaded before the temporary directory is cleaned up;
    the dlopen'd mapping keeps the handle usable after the file goes away
    (presumably Linux semantics — TODO confirm on other platforms).
    """
    with tempfile.TemporaryDirectory() as scratch:
        lib_path = _build_tile_size_bridge(Path(scratch))
        handle = ctypes.CDLL(str(lib_path))
    return handle


def estimate_smem_bytes(block_m: int, block_n: int, headdim: int, headdim_v: int, element_size: int) -> int:
    """Python mirror of the double-buffered shared-memory estimate in hopper/tile_size.h."""
    rows = block_m + block_n
    cols = headdim + headdim_v
    return 2 * rows * cols * element_size


def test_tile_sizes_stay_within_blackwell_smem_budget():
    """Every fp16/bf16 forward tile shape must fit under SMEM_LIMIT_BYTES.

    Sweeps head/value dims and the boolean knobs through the compiled
    tile_size_fwd_sm90 bridge and checks the mirrored shared-memory estimate
    against the ~100KB consumer-GPU budget.
    """
    bridge = _load_bridge()
    fn = bridge.tile_size_fwd_sm90_bridge
    fn.argtypes = [
        ctypes.c_int,
        ctypes.c_int,
        ctypes.c_bool,
        ctypes.c_bool,
        ctypes.c_int,
        ctypes.c_bool,
        ctypes.c_bool,
        ctypes.c_bool,
        ctypes.POINTER(ctypes.c_int),
        ctypes.POINTER(ctypes.c_int),
    ]
    fn.restype = None  # the bridge returns void; default c_int restype would be misleading

    head_dims = (64, 96, 128, 160, 192, 256, 320)
    value_dims = (64, 96, 128, 160, 192, 256, 512)
    bools = (False, True)

    # itertools.product flattens the five nested loops into one readable sweep.
    for headdim, headdim_v, is_causal, is_local, paged_kv_non_tma in itertools.product(
        head_dims, value_dims, bools, bools, bools
    ):
        if is_causal and is_local:
            continue  # invalid combination
        block_m = ctypes.c_int()
        block_n = ctypes.c_int()
        fn(
            headdim,
            headdim_v,
            is_causal,
            is_local,
            2,  # fp16/bf16 element size
            False,  # v_colmajor
            paged_kv_non_tma,
            False,  # softcap
            ctypes.byref(block_m),
            ctypes.byref(block_n),
        )
        smem_bytes = estimate_smem_bytes(block_m.value, block_n.value, headdim, headdim_v, 2)
        assert smem_bytes <= SMEM_LIMIT_BYTES, (
            f"SMEM overrun for d={headdim}, dv={headdim_v}, causal={is_causal}, "
            f"local={is_local}, paged={paged_kv_non_tma}: {smem_bytes}B"
        )