From 952c923ce10b3219353d5a621eb0d6e5c1bffef8 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Mon, 27 Apr 2026 16:03:25 -0700 Subject: [PATCH 1/2] fix: use aiter mhc device fix for dsv4 atom --- .github/configs/amd-master.yaml | 18 +-- .../single_node/dsv4_fp4_mi355x_atom.sh | 120 +++++++++++++++--- perf-changelog.yaml | 9 ++ 3 files changed, 114 insertions(+), 33 deletions(-) diff --git a/.github/configs/amd-master.yaml b/.github/configs/amd-master.yaml index 1c431427e..012d0828d 100644 --- a/.github/configs/amd-master.yaml +++ b/.github/configs/amd-master.yaml @@ -1491,12 +1491,12 @@ dsv4-fp8-mi355x-sglang: - { tp: 8, conc-start: 4, conc-end: 64 } # Day-0 single-sequence marker for DeepSeek-V4 on ATOM (ROCm/ATOM#650). -# PR1 of the ATOM DSv4 series — single-sequence only (kv_cache[:1,...] -# hardcode), --enforce-eager required, ATOM_USE_TRITON_MOE=1 required on -# gfx950. Image is the standard atom0.1.2.post MI355X base (matching -# qwen3.5-fp8-mi355x-atom); the DSv4 PR is overlaid at runtime by -# benchmarks/single_node/dsv4_fp4_mi355x_atom.sh at a pinned SHA. Sweep -# will expand once ATOM PR3 (multi-request) and PR4 (CUDAGraph) land. +# PR1 of the ATOM DSv4 series still uses torch sparse-attention fallbacks +# that OOM once warmup/prefill batches multiple requests; keep CONC=1 until +# the AITER sparse-attention kernel / multi-request path lands upstream. +# --enforce-eager and ATOM_USE_TRITON_MOE=1 are required on gfx950. Image is +# the standard atom0.1.2.post MI355X base (matching qwen3.5-fp8-mi355x-atom); +# the DSv4 PR is overlaid at runtime by dsv4_fp4_mi355x_atom.sh at a pinned SHA. 
dsv4-fp4-mi355x-atom: image: rocm/atom:rocm7.2.2_ubuntu24.04_py3.12_pytorch_release_2.10.0_atom0.1.2.post model: deepseek-ai/DeepSeek-V4-Pro @@ -1510,13 +1510,7 @@ dsv4-fp4-mi355x-atom: osl: 1024 search-space: - { tp: 8, ep: 1, conc-start: 1, conc-end: 1 } - - { tp: 8, ep: 1, conc-start: 4, conc-end: 4 } - - { tp: 8, ep: 1, conc-start: 16, conc-end: 16 } - - { tp: 8, ep: 1, conc-start: 32, conc-end: 32 } - isl: 8192 osl: 1024 search-space: - { tp: 8, ep: 1, conc-start: 1, conc-end: 1 } - - { tp: 8, ep: 1, conc-start: 4, conc-end: 4 } - - { tp: 8, ep: 1, conc-start: 16, conc-end: 16 } - - { tp: 8, ep: 1, conc-start: 32, conc-end: 32 } diff --git a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh index 88b5f9580..f490a9112 100644 --- a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh @@ -19,13 +19,17 @@ fi echo "TP: $TP, CONC: $CONC, ISL: $ISL, OSL: $OSL, EP_SIZE: $EP_SIZE" -# EP_SIZE > 1 is still unvalidated by PR #650's repro (offline TP=8 EP=1 -# only). Keep the EP guard. The CONC guard was relaxed to empirically -# probe whether kv_cache[:1,...] in deepseek_v4.py actually corrupts at -# batch>1 in the server path: max-num-seqs=4 caps the running batch -# below the YAML's max conc (32), and per-sequence eval correctness will -# tell us if the hardcode bites. If gsm8k accuracy collapses at conc>1, -# put `if [ "$CONC" -ne 1 ]; then exit 1` back. +# ROCm/ATOM#650 is still a single-request marker for DSv4. Run +# 24953107645 showed CONC>1 fails in two ways: 1k warmup can exhaust the KV +# budget after sparse-attn temporaries raise peak memory, and 8k prefill OOMs +# in the torch sparse_attn fallback when two long requests are batched. Keep +# this fatal guard until ATOM lands the AITER sparse-attention / multi-request +# path for DeepSeek-V4. 
+if [ "$CONC" -ne 1 ]; then + echo "FATAL: ROCm/ATOM#650 DSv4 path is single-request only; CONC must be 1, got $CONC" >&2 + exit 1 +fi + if [ "$EP_SIZE" -ne 1 ]; then echo "FATAL: ROCm/ATOM#650 PR1 has not validated expert parallel serving; EP_SIZE must be 1, got $EP_SIZE" >&2 exit 1 @@ -43,6 +47,89 @@ export OMP_NUM_THREADS=1 export ATOM_USE_TRITON_MOE=1 export AITER_LOG_LEVEL=WARNING +# Apply the pure-Python part of ROCm/aiter#2916 over the image's installed +# aiter package. Rebuilding aiter inside the benchmark would churn compiled +# ROCm kernels and make the run noisy; the upstream fix only changes +# aiter/ops/mhc.py so mhc_pre intermediate tensors allocate on +# residual.device instead of the global default device. +export AITER_MHC_FIX_SHA="76ea1ed5b2a5f8176ed7a16b1640dd972546a925" +python3 - <<'PYEOF' +import importlib.util +import os +import sys +from pathlib import Path + +required_snippets = [ + " device = residual.device\n out_pad = torch.empty(", + "selected_splitk, m, (hc_mult3 + 31) // 32 * 32, dtype=dtypes.fp32, device=device", + "sqrsum = torch.empty(selected_splitk, m, dtype=dtypes.fp32, device=device)", + "post_mix = torch.empty(m, hc_mult, 1, dtype=dtypes.fp32, device=device)", + "comb_mix = torch.empty(m, hc_mult, hc_mult, dtype=dtypes.fp32, device=device)", + "layer_input = torch.empty(m, hidden_size, dtype=dtypes.bf16, device=device)", +] + +spec = importlib.util.find_spec("aiter.ops.mhc") +if spec is None or spec.origin is None: + sys.exit("FATAL: cannot locate installed aiter.ops.mhc for ROCm/aiter#2916 patch") + +mhc_path = Path(spec.origin) +source = mhc_path.read_text() + +if all(snippet in source for snippet in required_snippets): + print(f"aiter mhc device patch already present: {mhc_path}") + sys.exit(0) + +replacements = [ + ( + " out_pad = torch.empty(\n" + " selected_splitk, m, (hc_mult3 + 31) // 32 * 32, dtype=dtypes.fp32\n" + " )", + " device = residual.device\n" + " out_pad = torch.empty(\n" + " selected_splitk, m, (hc_mult3 + 
31) // 32 * 32, dtype=dtypes.fp32, device=device\n" + " )", + ), + ( + " sqrsum = torch.empty(selected_splitk, m, dtype=dtypes.fp32)", + " sqrsum = torch.empty(selected_splitk, m, dtype=dtypes.fp32, device=device)", + ), + ( + " post_mix = torch.empty(m, hc_mult, 1, dtype=dtypes.fp32)", + " post_mix = torch.empty(m, hc_mult, 1, dtype=dtypes.fp32, device=device)", + ), + ( + " comb_mix = torch.empty(m, hc_mult, hc_mult, dtype=dtypes.fp32)", + " comb_mix = torch.empty(m, hc_mult, hc_mult, dtype=dtypes.fp32, device=device)", + ), + ( + " layer_input = torch.empty(m, hidden_size, dtype=dtypes.bf16)", + " layer_input = torch.empty(m, hidden_size, dtype=dtypes.bf16, device=device)", + ), +] + +missing = [old for old, _ in replacements if old not in source] +if missing: + sys.exit( + f"FATAL: {mhc_path} does not match the expected pre-ROCm/aiter#2916 " + f"source; refusing to patch mhc_pre blindly. Missing patterns: " + f"{[m.splitlines()[0].strip() for m in missing]}" + ) + +patched = source +for old, new in replacements: + patched = patched.replace(old, new, 1) + +mhc_path.write_text(patched) +patched_source = mhc_path.read_text() +if not all(snippet in patched_source for snippet in required_snippets): + sys.exit(f"FATAL: ROCm/aiter#2916 mhc device patch failed verification for {mhc_path}") + +print( + f"applied ROCm/aiter#2916 ({os.environ['AITER_MHC_FIX_SHA']}) " + f"mhc device patch: {mhc_path}" +) +PYEOF + # Apply ROCm/ATOM#650 (DSv4 PR1 skeleton) over the image's wheel-installed # atom. The chosen base image ships atom as a built wheel, not editable, so # we overlay an editable install from the PR branch at a pinned SHA. 
Bump @@ -63,20 +150,11 @@ fi git checkout --force "$ATOM_PR_SHA" test "$(git rev-parse HEAD)" = "$ATOM_PR_SHA" - # WORKAROUND: PR #650 has no env-var toggle to disable the aiter MHC - # kernels, and on this image aiter's `mhc_pre_big_fuse` crashes with a - # HIPGuardImplMasqueradingAsCUDA INTERNAL ASSERT the first time the - # model executes the hc_pre path during prefill (a HIP/CUDA device-type - # mismatch inside aiter, not something we can fix from outside). SGLang's - # DSv4 recipe disables the same family explicitly - # (SGLANG_OPT_USE_TILELANG_MHC_PRE/POST=false, _DEEPGEMM_HC_PRENORM=false). - # Force only `mhc_pre` to torch-fallback; leave `mhc_post` on the aiter - # path since the crash stack only implicated mhc_pre and we'd like to - # recover the perf of half the MHC pipeline. If mhc_post crashes too on - # the next run, add the second sed back. - sed -i 's|mhc_pre = getattr(_aiter, "mhc_pre", None)|mhc_pre = None # patched out (HIP device-guard crash)|' atom/models/deepseek_v4.py - grep -c "patched out" atom/models/deepseek_v4.py | grep -q '^1$' \ - || { echo "FATAL: mhc_pre sed patch did not apply"; exit 1; } + # ROCm/aiter#2916 keeps ATOM's mhc_pre fast path usable. Fail if the + # pinned ATOM checkout no longer exposes that aiter hook; silently + # disabling it would hide the regression this benchmark is meant to catch. + grep -q 'mhc_pre = getattr(_aiter, "mhc_pre", None)' atom/models/deepseek_v4.py \ + || { echo "FATAL: ATOM DSv4 mhc_pre aiter hook not found"; exit 1; } # --no-deps: don't churn the image's pinned ROCm/torch/triton/aiter. # --force-reinstall: replace the wheel-installed atom with the editable copy. 
diff --git a/perf-changelog.yaml b/perf-changelog.yaml index a29c278f2..594522520 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -1918,3 +1918,12 @@ - "Three CONC bands: A=TP8 (1-8), B=TP4 (16-128), C=DP4 dp-attn (64-512); B/C overlap at conc 64,128" - "Configs: 1k1k and 8k1k, no validation.py / launcher / yaml-field changes (knob-free)" pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1180 + +- config-keys: + - dsv4-fp4-mi355x-atom + description: + - "Use ROCm/aiter#2916 mhc_pre device-allocation fix instead of disabling ATOM mhc_pre" + - "Patch installed aiter/ops/mhc.py at runtime to allocate mhc_pre intermediates on residual.device, preserving the aiter MHC fast path without rebuilding aiter" + - "Remove the ATOM deepseek_v4.py sed workaround that forced mhc_pre to torch fallback" + - "Keep dsv4-fp4-mi355x-atom at CONC=1 only; run 24953107645 showed high-concurrency DSv4 ATOM OOMs in PR #650 torch sparse-attention fallbacks before upstream AITER sparse-attention support lands" + pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1202 From 0d94067f6fac664a73b976d125c2db5556d4a5d4 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Tue, 28 Apr 2026 00:44:45 -0700 Subject: [PATCH 2/2] fix: restore DSv4 ATOM aiter mhc + perf stack to CI-proven state Matches the exact tree from 55fd191a (run 25027405568). Co-Authored-By: Claude Opus 4.6 --- .../single_node/dsv4_fp4_mi355x_atom.sh | 235 +++++++++++++++++- 1 file changed, 222 insertions(+), 13 deletions(-) diff --git a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh index f490a9112..21708ba1d 100644 --- a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh @@ -40,18 +40,150 @@ PORT=${PORT:-8888} export OMP_NUM_THREADS=1 -# DSv4-specific ATOM env vars (from ROCm/ATOM#650 repro command). 
-# The aiter fused_moe path is broken on gfx950 with a16w4+Swiglu, so PR1 -# requires the triton matmul_ogs path. AITER_LOG_LEVEL quiets the noisy -# warmup logs that otherwise drown out the server-ready signal. -export ATOM_USE_TRITON_MOE=1 +# DSv4-specific ATOM env vars. Prefer the native AITER MXFP4 MoE path after +# overlaying the AITER perf stack below. Set AITER_DSV4_FP4_MOE_BACKEND=triton +# to return to ROCm/ATOM#650's original triton_kernels matmul_ogs path. +if [ "${AITER_DSV4_PERF_STACK:-1}" = "1" ]; then + DEFAULT_AITER_DSV4_FP4_MOE_BACKEND=native +else + DEFAULT_AITER_DSV4_FP4_MOE_BACKEND=triton +fi +AITER_DSV4_FP4_MOE_BACKEND=${AITER_DSV4_FP4_MOE_BACKEND:-$DEFAULT_AITER_DSV4_FP4_MOE_BACKEND} +if [ "$AITER_DSV4_FP4_MOE_BACKEND" = "triton" ]; then + export ATOM_USE_TRITON_MOE=1 +else + unset ATOM_USE_TRITON_MOE + unset ATOM_USE_TRITON_GEMM +fi export AITER_LOG_LEVEL=WARNING -# Apply the pure-Python part of ROCm/aiter#2916 over the image's installed -# aiter package. Rebuilding aiter inside the benchmark would churn compiled -# ROCm kernels and make the run noisy; the upstream fix only changes -# aiter/ops/mhc.py so mhc_pre intermediate tensors allocate on -# residual.device instead of the global default device. +# Pull in the AITER pieces that matter for DSv4 FP4 on MI355X: +# * origin/main@dde1703e includes ROCm/aiter#2770 a16w4 MoE support. +# * ROCm/aiter#2822 speeds up batched MXFP4 GEMM on gfx950. +# * ROCm/aiter#2900 fixes MXFP4 scale padding for non-256 K. +# * ROCm/aiter#2642 enables/fixes TP=4/8 MXFP4 MoE dispatch. +# * sunway513/aiter@e450e4d adds DSv4 FP4 MoE tuned rows that route +# eligible token counts to FlyDSL FP4 MoE kernels instead of default CK +# heuristics when the image has the optional flydsl package. +# +# ROCm/aiter#2916 is intentionally not cherry-picked here. That PR branch is +# based on a divergent fork and can conflict in unrelated test files; the +# narrow mhc_pre device fix is applied directly to installed aiter below. 
+# The non-mHC PRs cherry-pick cleanly over the pinned main SHA as of 2026-04-27. +# Keep this as a runtime overlay until AMD publishes an ATOM image with these +# AITER changes baked in; then remove this block and pin that image instead. +if [ "${AITER_DSV4_PERF_STACK:-1}" = "1" ]; then + AITER_PERF_REPO=${AITER_PERF_REPO:-https://github.com/ROCm/aiter.git} + AITER_PERF_DIR=${AITER_PERF_DIR:-/tmp/aiter-dsv4-fp4-perf} + AITER_PERF_BASE_SHA=${AITER_PERF_BASE_SHA:-dde1703ebfc35d3724e07fc4e6e824023063494c} + AITER_PERF_PATCH_REFS=( + "${AITER_PERF_BATCHED_FP4_REF:-pull/2822/head}" + "${AITER_PERF_MXFP4_SCALE_REF:-pull/2900/head}" + "${AITER_PERF_MOE_REF:-pull/2642/head}" + ) + AITER_DSV4_TUNED_FMOE=${AITER_DSV4_TUNED_FMOE:-1} + AITER_DSV4_TUNED_FMOE_REPO=${AITER_DSV4_TUNED_FMOE_REPO:-https://github.com/sunway513/aiter.git} + AITER_DSV4_TUNED_FMOE_SHA=${AITER_DSV4_TUNED_FMOE_SHA:-e450e4deb992c5ecd9db5ef5ef79f1d40208bc9c} + AITER_DSV4_TUNED_FMOE_PATH=${AITER_DSV4_TUNED_FMOE_PATH:-aiter/configs/model_configs/dsv4_fp4_tuned_fmoe.csv} + + rm -rf "$AITER_PERF_DIR" + git clone --filter=blob:none "$AITER_PERF_REPO" "$AITER_PERF_DIR" + ( + cd "$AITER_PERF_DIR" + git fetch --depth=1 origin "$AITER_PERF_BASE_SHA" + git checkout --force "$AITER_PERF_BASE_SHA" + test "$(git rev-parse HEAD)" = "$AITER_PERF_BASE_SHA" + + for ref in "${AITER_PERF_PATCH_REFS[@]}"; do + # Do not use --depth=1 here. A shallow PR-head fetch hides the + # parent commit and makes git treat the cherry-pick as add/add + # conflicts across unrelated files. 
+ git fetch origin "$ref" + git cherry-pick --no-commit FETCH_HEAD + done + + if [ "$AITER_DSV4_TUNED_FMOE" = "1" ]; then + mkdir -p "$(dirname "$AITER_DSV4_TUNED_FMOE_PATH")" + git fetch --depth=1 "$AITER_DSV4_TUNED_FMOE_REPO" "$AITER_DSV4_TUNED_FMOE_SHA" + test "$(git rev-parse FETCH_HEAD)" = "$AITER_DSV4_TUNED_FMOE_SHA" + git show "FETCH_HEAD:$AITER_DSV4_TUNED_FMOE_PATH" > "$AITER_DSV4_TUNED_FMOE_PATH" + grep -q '7168,512,385,6,ActivationType.Silu' "$AITER_DSV4_TUNED_FMOE_PATH" \ + || { echo "FATAL: DSv4 FP4 tuned fMoE rows not found in $AITER_DSV4_TUNED_FMOE_PATH"; exit 1; } + fi + + if [ ! -d 3rdparty/composable_kernel/include ]; then + git submodule update --init --recursive --depth=1 3rdparty/composable_kernel \ + || git submodule update --init --recursive 3rdparty/composable_kernel + fi + + PREBUILD_KERNELS=${AITER_PREBUILD_KERNELS:-0} \ + python3 -m pip install --no-deps --no-build-isolation --force-reinstall -e . + ) + + if [ "$AITER_DSV4_TUNED_FMOE" = "1" ]; then + export AITER_DSV4_TUNED_FMOE_FILE="$AITER_PERF_DIR/$AITER_DSV4_TUNED_FMOE_PATH" + fi + if [ "$AITER_DSV4_TUNED_FMOE" = "1" ] && [ -z "${AITER_CONFIG_FMOE:-}" ]; then + export AITER_CONFIG_FMOE="$AITER_PERF_DIR/aiter/configs/tuned_fmoe.csv:$AITER_DSV4_TUNED_FMOE_FILE" + fi + + python3 - <<'PYEOF' +import importlib.util +import csv +import os +from pathlib import Path +import aiter + +root = Path(aiter.__file__).resolve().parent +moe = (root / "fused_moe.py").read_text() +fp4_utils = (root / "utility" / "fp4_utils.py").read_text() +dsv4_tuned_fmoe = Path(os.environ["AITER_DSV4_TUNED_FMOE_FILE"]) if os.environ.get("AITER_DSV4_TUNED_FMOE_FILE") else None +required = { + "native MXFP4 MoE skip_inter_quant": "skip_inter_quant" in moe, + "MXFP4 scaleN_pad fix": "scaleN_pad" in fp4_utils, + "DSv4 FP4 tuned fMoE config": dsv4_tuned_fmoe is None or dsv4_tuned_fmoe.exists(), +} +missing = [name for name, ok in required.items() if not ok] +if missing: + raise SystemExit(f"FATAL: AITER DSv4 perf stack 
verification failed: {missing}") + +if dsv4_tuned_fmoe is not None and dsv4_tuned_fmoe.exists(): + config_paths = os.environ.get("AITER_CONFIG_FMOE", "").split(":") + if str(dsv4_tuned_fmoe) not in config_paths: + print( + "WARN: AITER_CONFIG_FMOE was user-supplied and does not include " + f"{dsv4_tuned_fmoe}; DSv4 tuned fMoE rows may not be active." + ) + try: + from aiter.ops.flydsl import is_flydsl_available + except Exception as exc: + print(f"aiter DSv4 tuned fMoE installed; FlyDSL availability check failed: {exc!r}") + else: + flydsl_available = is_flydsl_available() + print(f"aiter FlyDSL available: {flydsl_available}") + if flydsl_available: + from aiter.ops.flydsl.moe_kernels import get_flydsl_kernel_params + + missing_kernels = set() + with dsv4_tuned_fmoe.open(newline="") as handle: + for row in csv.DictReader(handle): + for name in (row.get("kernelName1", ""), row.get("kernelName2", "")): + if name.startswith("flydsl_") and get_flydsl_kernel_params(name) is None: + missing_kernels.add(name) + if missing_kernels: + raise SystemExit( + "FATAL: DSv4 FP4 tuned fMoE references missing FlyDSL kernels: " + f"{sorted(missing_kernels)[:5]}" + ) +print(f"aiter DSv4 perf stack imported from: {root}") +PYEOF +else + echo "WARN: AITER_DSV4_PERF_STACK=0; using image-provided aiter" +fi + +# Ensure the pure-Python part of ROCm/aiter#2916 is present. The AITER perf +# stack above does not cherry-pick it, so patch whichever aiter is installed +# (perf stack or image with AITER_DSV4_PERF_STACK=0); no-op if already fixed. export AITER_MHC_FIX_SHA="76ea1ed5b2a5f8176ed7a16b1640dd972546a925" python3 - <<'PYEOF' import importlib.util @@ -156,6 +288,78 @@ fi grep -q 'mhc_pre = getattr(_aiter, "mhc_pre", None)' atom/models/deepseek_v4.py \ || { echo "FATAL: ATOM DSv4 mhc_pre aiter hook not found"; exit 1; } + # ROCm/ATOM#650 sparse_attn_v4.py is a correctness-first torch fallback. 
+ # Add two local mitigations while we wait for a serving-compatible AITER + # sparse-attention kernel: + # 1. chunk prefill over the M dimension to keep temporary scores under + # memory pressure, making higher-conc experiments less likely to OOM; + # 2. use a B=1,M=1 decode fast path that avoids the fallback's large + # broadcast/mask/concat intermediates on every generated token. + python3 - <<'PYEOF' +from pathlib import Path + +path = Path("atom/model_ops/sparse_attn_v4.py") +source = path.read_text() +marker = "ATOM_DSV4_SPARSE_ATTN_CHUNK_TOKENS" +if marker not in source: + source = source.replace( + "from typing import Tuple\n\nimport torch\n", + "from typing import Tuple\n\nimport os\n\nimport torch\n", + 1, + ) + old = """ out_dtype = q.dtype + device = q.device + + # ----- Gather KV per query position ----- +""" + new = """ out_dtype = q.dtype + device = q.device + + chunk_tokens = int(os.environ.get("ATOM_DSV4_SPARSE_ATTN_CHUNK_TOKENS", "0") or "0") + if B == 1 and chunk_tokens > 0 and M > chunk_tokens: + return torch.cat( + [ + sparse_attn( + q[:, start : start + chunk_tokens], + kv, + attn_sink, + topk_idxs[:, start : start + chunk_tokens], + softmax_scale, + ) + for start in range(0, M, chunk_tokens) + ], + dim=1, + ) + + if B == 1 and M == 1: + valid_1d = topk_idxs[0, 0] != -1 + if not bool(valid_1d.any()): + return torch.zeros_like(q) + idx_1d = topk_idxs[0, 0] + if bool(valid_1d.all()): + kv_f32 = kv[0].index_select(0, idx_1d.long()).float() + else: + kv_f32 = kv[0].index_select(0, idx_1d[valid_1d].long()).float() + q_f32 = q[0, 0].float() + scores = torch.matmul(q_f32, kv_f32.transpose(0, 1)) * float(softmax_scale) + sink = attn_sink.float().view(H, 1) + cmax = torch.maximum(scores.amax(dim=-1, keepdim=True), sink) + exp_scores = (scores - cmax).exp() + denom = exp_scores.sum(dim=-1, keepdim=True) + (sink - cmax).exp() + out = (exp_scores / denom.clamp(min=1e-30)).matmul(kv_f32) + return out.view(1, 1, H, D).to(out_dtype) + + # ----- Gather KV per 
query position ----- +""" + if old not in source: + raise SystemExit("FATAL: sparse_attn_v4.py did not match expected PR650 source") + source = source.replace(old, new, 1) + path.write_text(source) + print(f"applied DSv4 sparse_attn_v4 decode/chunk patch: {path}") +else: + print(f"DSv4 sparse_attn_v4 decode/chunk patch already present: {path}") +PYEOF + # --no-deps: don't churn the image's pinned ROCm/torch/triton/aiter. # --force-reinstall: replace the wheel-installed atom with the editable copy. pip install --no-deps --force-reinstall -e . @@ -260,14 +464,16 @@ PYEOF # there. Set 1k1k explicitly; 8k1k retains the existing 10240 cap that's # already running successfully. if [ "$ISL" = "1024" ] && [ "$OSL" = "1024" ]; then - CALCULATED_MAX_MODEL_LEN=" --max-model-len 2304 " + MAX_MODEL_LEN_VALUE=2304 else - CALCULATED_MAX_MODEL_LEN=" --max-model-len 10240 " + MAX_MODEL_LEN_VALUE=10240 fi +CALCULATED_MAX_MODEL_LEN=" --max-model-len $MAX_MODEL_LEN_VALUE " if [ "${EVAL_ONLY}" = "true" ]; then setup_eval_context - CALCULATED_MAX_MODEL_LEN=" --max-model-len $EVAL_MAX_MODEL_LEN " + MAX_MODEL_LEN_VALUE="$EVAL_MAX_MODEL_LEN" + CALCULATED_MAX_MODEL_LEN=" --max-model-len $MAX_MODEL_LEN_VALUE " fi if [ "$EP_SIZE" -gt 1 ]; then @@ -282,6 +488,7 @@ start_gpu_monitor set -x BLOCK_SIZE=${BLOCK_SIZE:-16} +export ATOM_DSV4_SPARSE_ATTN_CHUNK_TOKENS=${ATOM_DSV4_SPARSE_ATTN_CHUNK_TOKENS:-256} # --enforce-eager is required: ROCm/ATOM#650 (PR1 skeleton) has no CUDAGraph # support yet (deferred to a follow-up PR). max-num-seqs is sized to the # client concurrency with a floor at 4 — the ATOM default (512) makes the @@ -292,6 +499,7 @@ BLOCK_SIZE=${BLOCK_SIZE:-16} # deepseek_v4.py means any forward with batch>1 silently corrupts # non-slot-0 lanes; eval (gsm8k) at conc>1 is the canary. MAX_NUM_SEQS=$(( CONC < 4 ? 
4 : CONC )) +MAX_NUM_BATCHED_TOKENS=${MAX_NUM_BATCHED_TOKENS:-$MAX_MODEL_LEN_VALUE} python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ @@ -300,6 +508,7 @@ python3 -m atom.entrypoints.openai_server \ --block-size $BLOCK_SIZE \ --enforce-eager \ --max-num-seqs $MAX_NUM_SEQS \ + --max-num-batched-tokens $MAX_NUM_BATCHED_TOKENS \ --trust-remote-code > $SERVER_LOG 2>&1 & SERVER_PID=$!