diff --git a/tools/launcher/common/specdec/read_vllm_files.sh b/tools/launcher/common/specdec/read_vllm_files.sh new file mode 100755 index 00000000000..d4cf5729dea --- /dev/null +++ b/tools/launcher/common/specdec/read_vllm_files.sh @@ -0,0 +1,9 @@ +#!/bin/bash +set -euo pipefail +echo "=== pattern_matcher.py lines 305-325 ===" +sed -n '305,325p' /usr/local/lib/python3.12/dist-packages/torch/_inductor/pattern_matcher.py 2>/dev/null || echo "NOT FOUND" +echo "=== post_grad.py lines 345-375 ===" +sed -n '345,375p' /usr/local/lib/python3.12/dist-packages/torch/_inductor/fx_passes/post_grad.py 2>/dev/null || echo "NOT FOUND" +echo "=== post_grad.py lines 1240-1260 ===" +sed -n '1240,1260p' /usr/local/lib/python3.12/dist-packages/torch/_inductor/fx_passes/post_grad.py 2>/dev/null || echo "NOT FOUND" +echo "=== DONE ===" diff --git a/tools/launcher/common/specdec/vllm_smoke_test.sh b/tools/launcher/common/specdec/vllm_smoke_test.sh index 4b9d5a63b4f..f46ef7bea00 100644 --- a/tools/launcher/common/specdec/vllm_smoke_test.sh +++ b/tools/launcher/common/specdec/vllm_smoke_test.sh @@ -28,6 +28,28 @@ # VLLM_PORT — server port (default: 8000) # REASONING_PARSER — reasoning parser (e.g., "qwen3" for Qwen3.5) # DISABLE_PREFIX_CACHING — set to "1" to disable prefix caching +# TRUST_REMOTE_CODE — set to "1" to pass --trust-remote-code (needed for custom architectures) +# UPGRADE_TRANSFORMERS — set to "1" to install transformers from HuggingFace main branch +# DATA_PARALLEL_SIZE — data parallel size; mutually exclusive with TP_SIZE (default: unset, uses TP_SIZE) +# KV_CACHE_DTYPE — kv cache dtype (e.g., "fp8"); omitted if unset +# BLOCK_SIZE — paged attention block size (e.g., 256 for DeepSeek V4) +# ENABLE_EXPERT_PARALLEL — set to "1" to pass --enable-expert-parallel +# TOKENIZER_MODE — tokenizer mode (e.g., "deepseek_v4") +# VLLM_EXTRA_ARGS — additional raw args appended verbatim to vllm serve (simple flags only) +# COMPILATION_CONFIG — JSON string for --compilation-config (e.g., for B200 native ops) +# e.g., '{"cudagraph_mode":"FULL_AND_PIECEWISE","custom_ops":["all"]}' +# Passed as a properly-quoted single arg to avoid brace expansion issues. +# NOTE: NeMo Run generates unquoted env var assignments in sbatch scripts, +# so JSON with braces/brackets gets brace-expanded. Use BUILD_COMPILATION_CONFIG +# instead to avoid this — the JSON is constructed safely inside the script. +# BUILD_COMPILATION_CONFIG — alternative to COMPILATION_CONFIG: just pass the cudagraph_mode string +# (e.g., "FULL_AND_PIECEWISE") and the script constructs: +# {"cudagraph_mode":"","custom_ops":["all"]} +# This avoids brace-expansion of JSON in NeMo Run sbatch env var assignments. +# GPU_MEM_UTIL — gpu_memory_utilization fraction (default: unset, vLLM default 0.9) +# MAX_BATCHED_TOKENS — override max_num_batched_tokens (default: 32768) +# COPY_MODEL_TO_TMPFS — set to "1" to copy model to /dev/shm before serving +# (prevents NFS stale-handle errors when 8+ workers mmap weights simultaneously) SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" source ${SCRIPT_DIR}/../service_utils.sh 2>/dev/null || true @@ -35,10 +57,420 @@ source ${SCRIPT_DIR}/../service_utils.sh 2>/dev/null || true # Ensure pandas is available (missing in some vLLM nightly builds) pip install pandas 2>/dev/null || true -cleanup() { kill $SERVER_PID 2>/dev/null; sleep 2; kill -9 $SERVER_PID 2>/dev/null; rm -f "${VLLM_LOG:-}" 2>/dev/null; } +# Raise the per-user process limit so concurrent deepgemm/NVCC JIT workers (one per +# DP rank) don't exhaust the nproc limit when popen(nvcc) is called simultaneously +# during CUDA graph capture warmup. popen() returns nullptr (triggering deepgemm's +# "pipe != nullptr" assertion) when fork() fails with EAGAIN due to nproc limit. +ulimit -u unlimited 2>/dev/null || true + +# Redirect deepgemm/NVCC JIT compilation away from /tmp (too small on B200) and +# /dev/shm (noexec — dlopen of compiled .so fails). DEEPGEMM_TMPDIR must be a +# writable+executable NFS path (e.g., /cicd/deepgemm_tmp). We use a separate env var +# so enroot doesn't pick it up at container startup (enroot reads TMPDIR before the +# container starts, so setting TMPDIR in sbatch would break the container launch). +if [ -n "${DEEPGEMM_TMPDIR:-}" ]; then + mkdir -p "$DEEPGEMM_TMPDIR" + export TMPDIR="$DEEPGEMM_TMPDIR" +fi + +# Force torch inductor to use the v2 auto_functionalized algorithm. +# vLLM explicitly sets enable_auto_functionalized_v2=False in its inductor config, +# which causes failures with fallback FP8 ops (e.g., when VLLM_USE_DEEP_GEMM=0): +# 1. v1 decompose pass can't remove auto_functionalized nodes for MXFP4 ops +# 2. Remaining nodes execute as Python wrappers calling ops via stable IValue +# 3. The /opt/venv stable IValue binary doesn't know ScalarType 44 (MXFP4) → crash +# Fix: enable v2 in vLLM's compilation code so the proper decompose pass is used. +# v2 handles in-place mutations generically, removing the Python wrapper path entirely. +# Set FORCE_AF_V2=1 to enable. +if [ "${FORCE_AF_V2:-0}" = "1" ]; then + python3 << 'PYEOF' || true +import inspect, compileall, glob, re, os, site + +# ────────────────────────────────────────────────────────────────────────────── +# The problem: +# vLLM explicitly passes 'enable_auto_functionalized_v2': False in its +# inductor_compile_config dict. This makes the v1 decompose pass run, which +# can't remove auto_functionalized nodes wrapping MXFP4/cutlass ops. Those +# nodes then execute as Python wrappers calling ops via torch._ops.py → stable +# IValue → ScalarType 44 (MXFP4) not registered in /opt/venv binary → crash. +# +# Fix strategy: +# 1. Write a .pth startup file to ALL site-packages dirs so every spawned +# worker process auto-loads a module that monkey-patches +# torch._inductor.config.patch() to strip enable_auto_functionalized_v2=False. +# 2. Patch the source files directly (file glob) with an updated regex that +# handles both bare and quoted-key dict forms. +# 3. Patch post_grad.py assertions as a safety net. +# ────────────────────────────────────────────────────────────────────────────── + +PATCH_MODULE_NAME = 'vllm_force_af_v2_runtime' +PATCH_CODE = r''' +# Auto-loaded via .pth in site-packages. Runs in main process AND every spawned worker. +# Strategy: intercept torch._dynamo.aot_compile (the AOT compile entry point used by vLLM) +# to strip enable_auto_functionalized_v2=False from options before compilation starts. +# Uses sys.modules as sentinel (torch._inductor.config rejects unknown __getattr__). +import sys as _sys + +def _strip_af_v2_false(d): + if isinstance(d, dict) and d.get('enable_auto_functionalized_v2') is False: + d = {k: v for k, v in d.items() if k != 'enable_auto_functionalized_v2'} + print('[force_af_v2] Stripped enable_auto_functionalized_v2=False from inductor options', flush=True) + return d + +def _install(): + if _sys.modules.get('_vllm_af_v2_patched'): + return + _sys.modules['_vllm_af_v2_patched'] = True + + # Patch 1: torch._dynamo.aot_compile (called by vLLM decorators.py) + try: + import torch._dynamo as _dynamo + _orig_aot = _dynamo.aot_compile + def _patched_aot(*args, **kwargs): + if 'options' in kwargs: + kwargs['options'] = _strip_af_v2_false(kwargs['options']) + return _orig_aot(*args, **kwargs) + _dynamo.aot_compile = _patched_aot + print('[force_af_v2] Patched torch._dynamo.aot_compile', flush=True) + except Exception as e: + print(f'[force_af_v2] aot_compile patch failed: {e}', flush=True) + + # Patch 2: torch._dynamo.aot_compile_fullgraph (alternative entry point) + try: + import torch._dynamo.aot_compile as _aot_mod + _orig_fg = _aot_mod.aot_compile_fullgraph + def _patched_fg(*args, **kwargs): + if 'options' in kwargs: + kwargs['options'] = _strip_af_v2_false(kwargs['options']) + return _orig_fg(*args, **kwargs) + _aot_mod.aot_compile_fullgraph = _patched_fg + print('[force_af_v2] Patched torch._dynamo.aot_compile_fullgraph', flush=True) + except Exception as e: + print(f'[force_af_v2] aot_compile_fullgraph patch failed: {e}', flush=True) + + # Patch 3: torch._inductor.config.patch (if it exists in this PyTorch version) + try: + import torch._inductor.config as _ic + _orig_patch = _ic.patch + def _patched_patch(*args, **kwargs): + new_args = (_strip_af_v2_false(args[0]),) + args[1:] if args and isinstance(args[0], dict) else args + if kwargs.get('enable_auto_functionalized_v2') is False: + kwargs = {k: v for k, v in kwargs.items() if k != 'enable_auto_functionalized_v2'} + return _orig_patch(*new_args, **kwargs) + _ic.patch = _patched_patch + print('[force_af_v2] Patched torch._inductor.config.patch', flush=True) + except Exception as e: + print(f'[force_af_v2] config.patch intercept skipped: {e}', flush=True) + + # Patch 4: Set global torch._inductor.config.enable_auto_functionalized_v2 = True. + # This ensures post_grad.py (which reads the global config) uses the v2 decompose path. + try: + import torch._inductor.config as _ic + _ic.enable_auto_functionalized_v2 = True + print('[force_af_v2] Set torch._inductor.config.enable_auto_functionalized_v2 = True', flush=True) + except Exception as e: + print(f'[force_af_v2] inductor global config set failed: {e}', flush=True) + + # Patch 5: torch._inductor.standalone_compile — vLLM's piecewise backend uses this + # (NOT torch._dynamo.aot_compile) to compile each graph segment. Strip the + # enable_auto_functionalized_v2=False override so the global True setting survives. + try: + import torch._inductor as _ti_mod + _orig_sc = getattr(_ti_mod, 'standalone_compile', None) + if _orig_sc is not None: + def _patched_sc(fn, *args, **kwargs): + opts = kwargs.get('options') + if isinstance(opts, dict) and opts.get('enable_auto_functionalized_v2') is False: + kwargs['options'] = {k: v for k, v in opts.items() if k != 'enable_auto_functionalized_v2'} + print('[force_af_v2] Stripped enable_auto_functionalized_v2=False from standalone_compile', flush=True) + return _orig_sc(fn, *args, **kwargs) + _ti_mod.standalone_compile = _patched_sc + print('[force_af_v2] Patched torch._inductor.standalone_compile', flush=True) + else: + print('[force_af_v2] torch._inductor.standalone_compile not found, skipping', flush=True) + except Exception as e: + print(f'[force_af_v2] standalone_compile patch failed: {e}', flush=True) + +_install() +''' + +# Write the patch module + .pth startup file to every site-packages directory +site_dirs = site.getsitepackages() + [site.getusersitepackages()] +for sp in site_dirs: + if not os.path.isdir(sp): + continue + try: + mod_path = os.path.join(sp, f'{PATCH_MODULE_NAME}.py') + pth_path = os.path.join(sp, f'{PATCH_MODULE_NAME}.pth') + with open(mod_path, 'w') as f: + f.write(PATCH_CODE) + with open(pth_path, 'w') as f: + f.write(f'import {PATCH_MODULE_NAME}\n') + print(f'[force_af_v2] Wrote {pth_path} → auto-loads in all worker processes') + except Exception as e: + print(f'[force_af_v2] Could not write to {sp}: {e}') + +# Also run the runtime patch immediately in this process +exec(PATCH_CODE) + +# Step 2: Source-file patch — fix regex to handle quoted-key dict form. +vllm_dirs = [ + '/usr/local/lib/python3.12/dist-packages/vllm', + '/opt/venv/lib/python3.12/site-packages/vllm', +] +for vllm_dir in vllm_dirs: + if not os.path.isdir(vllm_dir): + continue + for py_file in glob.glob(os.path.join(vllm_dir, '**/*.py'), recursive=True): + if '__pycache__' in py_file: + continue + try: + with open(py_file) as f: + content = f.read() + if 'enable_auto_functionalized_v2' not in content: + continue + for i, line in enumerate(content.splitlines()): + if 'enable_auto_functionalized_v2' in line: + print(f'[force_af_v2] Found in {py_file}:{i+1}: {line.strip()}') + # Match both bare and quoted-key dict forms + patched = re.sub( + r'("?enable_auto_functionalized_v2"?\s*[:=]\s*)False', + r'\1True', + content + ) + # Special case: compilation.py stores the key in a KEY constant and uses + # KEY: False in the dict — the literal string search above misses this form. + if '/vllm/config/compilation.py' in py_file or py_file.endswith('/compilation.py'): + patched2 = re.sub(r'\bKEY(\s*:\s*)False', r'KEY\1True', patched) + if patched2 != patched: + patched = patched2 + print(f'[force_af_v2] Patched KEY: False in {py_file}') + if patched != content: + with open(py_file, 'w') as f: + f.write(patched) + compileall.compile_file(py_file, quiet=2, force=True) + print(f'[force_af_v2] Patched source file: {py_file}') + except Exception as e: + print(f'[force_af_v2] Error processing {py_file}: {e}') + +# Step 3: Patch post_grad.py assertions as safety net. +try: + import torch._inductor.fx_passes.post_grad as pg + src_file = inspect.getfile(pg) + with open(src_file) as f: + content = f.read() + patterns = [ + ('raise AssertionError("auto_functionalized was not removed")', + 'pass # PATCHED: v1 nodes skipped (FORCE_AF_V2=1)'), + ('raise AssertionError("auto_functionalized_v2 was not removed")', + 'pass # PATCHED: v2 nodes skipped (FORCE_AF_V2=1)'), + ('if config.enable_auto_functionalized_v2:', 'if True: # PATCHED (FORCE_AF_V2=1)'), + ('if inductor_config.enable_auto_functionalized_v2:', 'if True: # PATCHED (FORCE_AF_V2=1)'), + # Wrap the decompose_triton_kernel_wrapper_functional call in try/except so that a + # node-count mismatch AssertionError (pattern_matcher.py:316) doesn't abort compilation. + # vLLM's Triton kernel wrappers can produce a different graph node count than PyTorch 2.11 + # expects; skipping the decompose pass is safe — kernels execute via the wrapper path. + ('GraphTransformObserver(gm, "decompose_triton_kernel_wrapper_functional").apply_graph_pass(decompose_triton_kernel_wrapper_functional)', + 'try:\n GraphTransformObserver(gm, "decompose_triton_kernel_wrapper_functional").apply_graph_pass(decompose_triton_kernel_wrapper_functional)\n except AssertionError as _af2_e:\n print(f"[force_af_v2] decompose_triton_kernel_wrapper_functional skipped: {_af2_e}", flush=True) # PATCHED'), + ] + patched = content + for old, new in patterns: + if old in patched: + patched = patched.replace(old, new) + print(f'[force_af_v2] post_grad patch: {old[:70]!r}') + if patched != content: + with open(src_file, 'w') as f: + f.write(patched) + compileall.compile_file(src_file, quiet=2, force=True) + print(f'[force_af_v2] Wrote and recompiled {src_file}') +except Exception as e: + print(f'[force_af_v2] post_grad.py patch failed: {e}') + +# Step 4: Patch pattern_matcher.py to remove the node-count assertion fired by +# decompose_triton_kernel_wrapper_functional when vLLM's Triton kernel wrapper graphs +# have a different number of nodes than PyTorch 2.11's replacement graph. +# The assertion at pattern_matcher.py:316 reads: +# assert len(graph_with_eager_vals.graph.nodes) == len(replacement.graph.nodes) +# The comment above it says "might not be true in general" — we exploit this escape hatch. +try: + import re as _re + import torch._inductor.pattern_matcher as pm + pm_file = inspect.getfile(pm) + with open(pm_file) as f: + pm_content = f.read() + pm_patched = _re.sub( + r'assert len\(graph_with_eager_vals\.graph\.nodes\) == len\(\s*\n\s*replacement\.graph\.nodes\s*\n\s*\)', + 'pass # PATCHED: skip node-count assertion for triton_kernel_wrapper_functional (FORCE_AF_V2=1)', + pm_content, + ) + if pm_patched != pm_content: + with open(pm_file, 'w') as f: + f.write(pm_patched) + compileall.compile_file(pm_file, quiet=2, force=True) + print(f'[force_af_v2] Patched pattern_matcher.py node-count assertion: {pm_file}') + else: + print(f'[force_af_v2] pattern_matcher.py: assertion pattern not found in {pm_file}') +except Exception as e: + print(f'[force_af_v2] pattern_matcher.py patch failed: {e}') +PYEOF +fi + +# Patch vllm._custom_ops.cutlass_scaled_mm to cast ue8m0 (ScalarType 44) block-FP8 +# scale tensors to uint8 before dispatching through PyTorch's stable IValue layer. +# +# deepseek_v4_fp8 stores per-block scales in ue8m0 format (unsigned 8-bit, 8 exponent +# bits, 0 mantissa bits). PyTorch 2.11's stableivalue_conversions.h doesn't recognise +# ScalarType 44, so torch.ops._C.cutlass_scaled_mm crashes during the model's dummy +# forward pass (profile_run). Casting to uint8 preserves the raw bytes — the CUTLASS +# kernel receives the same values — while satisfying the stable IValue type check. +# +# The patch is written as a .pth startup module so it propagates to every forked worker. +python3 << 'PYEOF' || true +import os, site + +PATCH_MODULE_NAME = 'vllm_ue8m0_cast_patch' +PATCH_CODE = r''' +import sys as _sys + +def _install(): + if _sys.modules.get('_vllm_ue8m0_patch_installed'): + return + _sys.modules['_vllm_ue8m0_patch_installed'] = True + try: + import torch + import vllm._custom_ops as _vllm_co + + _orig_csm = _vllm_co.cutlass_scaled_mm + + _SAFE_DTYPES = frozenset([ + torch.float32, torch.float16, torch.bfloat16, + torch.uint8, torch.int8, + torch.float8_e4m3fn, torch.float8_e5m2, + ]) + + def _csm_ue8m0_safe(*args, **kwargs): + args = list(args) + def _cast(t): + if t is not None and hasattr(t, 'dtype') and t.dtype not in _SAFE_DTYPES and t.element_size() == 1: + return t.view(torch.uint8) + return t + if 'scale_a' in kwargs: + kwargs['scale_a'] = _cast(kwargs['scale_a']) + elif len(args) > 3: + args[3] = _cast(args[3]) + if 'scale_b' in kwargs: + kwargs['scale_b'] = _cast(kwargs['scale_b']) + elif len(args) > 4: + args[4] = _cast(args[4]) + return _orig_csm(*args, **kwargs) + + _vllm_co.cutlass_scaled_mm = _csm_ue8m0_safe + print('[patch_ue8m0] Patched vllm._custom_ops.cutlass_scaled_mm', flush=True) + except Exception as e: + print(f'[patch_ue8m0] Patch failed: {e}', flush=True) + +_install() +''' + +for sp in site.getsitepackages() + [site.getusersitepackages()]: + if not os.path.isdir(sp): + continue + try: + with open(os.path.join(sp, f'{PATCH_MODULE_NAME}.py'), 'w') as f: + f.write(PATCH_CODE) + with open(os.path.join(sp, f'{PATCH_MODULE_NAME}.pth'), 'w') as f: + f.write(f'import {PATCH_MODULE_NAME}\n') + print(f'[patch_ue8m0] Wrote .pth to {sp}') + break + except Exception as e: + print(f'[patch_ue8m0] Could not write to {sp}: {e}') + +exec(PATCH_CODE) +PYEOF + +# Allow callers to upgrade transformers for models not yet in the container's bundled version +# (e.g. deepseek_v4 requires transformers >= 4.52). Set UPGRADE_TRANSFORMERS=1 to enable. +if [ "${UPGRADE_TRANSFORMERS:-0}" = "1" ]; then + pip install --upgrade --pre transformers 2>/dev/null || true + # Register deepseek_v4 by writing a .pth file + module to site-packages. + # Python processes .pth files at startup, so this propagates to every vLLM subprocess. + python3 << 'PYEOF' || true +import sys, os, sysconfig + +PATCH_MODULE = ''' +try: + from transformers import AutoConfig, PretrainedConfig + class DeepseekV4Config(PretrainedConfig): + model_type = "deepseek_v4" + def __init__(self, **kwargs): + # Pre-populate ALL config.json fields before super().__init__ runs, + # because PretrainedConfig in transformers 5.x accesses attributes + # like max_position_embeddings during initialization. + for k, v in kwargs.items(): + object.__setattr__(self, k, v) + # Override architectures to TransformersForCausalLM so vLLM routes + # through its generic transformers backend (trust_remote_code path), + # since DeepseekV4ForCausalLM is not yet in vLLMs native registry. + object.__setattr__(self, "architectures", ["TransformersForCausalLM"]) + super().__init__(**kwargs) + AutoConfig.register("deepseek_v4", DeepseekV4Config, exist_ok=True) +except Exception: + pass +''' + +site_packages = sysconfig.get_path("purelib") +module_path = os.path.join(site_packages, "_deepseek_v4_patch.py") +pth_path = os.path.join(site_packages, "deepseek_v4.pth") + +with open(module_path, "w") as f: + f.write(PATCH_MODULE) +with open(pth_path, "w") as f: + f.write("import _deepseek_v4_patch\n") + +print(f"[patch] wrote {pth_path} -> will register deepseek_v4 on every Python startup") +PYEOF +fi + +# Apply custom vLLM patches before starting the server. +# Used for models that require container-level modifications not yet upstream. +# Set VLLM_PATCH_SCRIPT to a Python script path (relative to /nemo_run/code/). +if [ -n "${VLLM_PATCH_SCRIPT:-}" ] && [ -f "${VLLM_PATCH_SCRIPT}" ]; then + echo "Applying vLLM patches: ${VLLM_PATCH_SCRIPT}" + python3 "${VLLM_PATCH_SCRIPT}" || { echo "ERROR: patch script failed"; exit 1; } +fi + +TMPFS_MODEL="" +cleanup() { + kill $SERVER_PID 2>/dev/null + sleep 2 + kill -9 $SERVER_PID 2>/dev/null + rm -f "${VLLM_LOG:-}" 2>/dev/null + # Clean up tmpfs copy if we made one + if [ -n "$TMPFS_MODEL" ] && [ -d "$TMPFS_MODEL" ]; then + echo "Removing tmpfs model copy: $TMPFS_MODEL" + rm -rf "$TMPFS_MODEL" + fi +} trap cleanup EXIT MODEL=${HF_MODEL_CKPT} + +# Copy model to /dev/shm to avoid NFS stale-handle errors when many workers mmap weights simultaneously +if [ "${COPY_MODEL_TO_TMPFS:-0}" = "1" ]; then + MODEL_NAME=$(basename "$MODEL") + TMPFS_MODEL="/dev/shm/${MODEL_NAME}" + if [ -d "$TMPFS_MODEL" ] && [ -f "$TMPFS_MODEL/config.json" ]; then + echo "Using existing tmpfs model copy: $TMPFS_MODEL" + else + MODEL_SIZE=$(du -sh "$MODEL" 2>/dev/null | cut -f1 || echo "?") + AVAIL_SHM=$(df -h /dev/shm 2>/dev/null | tail -1 | awk '{print $4}' || echo "?") + echo "Copying model to /dev/shm (${MODEL_SIZE}, available: ${AVAIL_SHM})..." + cp -r "$MODEL" "$TMPFS_MODEL" + echo "Model copy done: $TMPFS_MODEL" + fi + MODEL="$TMPFS_MODEL" + echo "Loading from tmpfs: $MODEL" +fi DRAFT=${DRAFT_MODEL:-} # Auto-detect exported checkpoint from training output dir if [ -z "$DRAFT" ] && [ -n "${DRAFT_CKPT_DIR:-}" ]; then @@ -74,30 +506,64 @@ fi if [ "${DISABLE_PREFIX_CACHING:-}" = "1" ]; then OPTIONAL_ARGS="${OPTIONAL_ARGS} --no-enable-prefix-caching" fi +if [ "${TRUST_REMOTE_CODE:-}" = "1" ]; then + OPTIONAL_ARGS="${OPTIONAL_ARGS} --trust-remote-code" +fi +if [ -n "${KV_CACHE_DTYPE:-}" ]; then + OPTIONAL_ARGS="${OPTIONAL_ARGS} --kv-cache-dtype ${KV_CACHE_DTYPE}" +fi +if [ -n "${BLOCK_SIZE:-}" ]; then + OPTIONAL_ARGS="${OPTIONAL_ARGS} --block-size ${BLOCK_SIZE}" +fi +if [ "${ENABLE_EXPERT_PARALLEL:-}" = "1" ]; then + OPTIONAL_ARGS="${OPTIONAL_ARGS} --enable-expert-parallel" +fi +if [ -n "${TOKENIZER_MODE:-}" ]; then + OPTIONAL_ARGS="${OPTIONAL_ARGS} --tokenizer-mode ${TOKENIZER_MODE}" +fi +if [ -n "${GPU_MEM_UTIL:-}" ]; then + OPTIONAL_ARGS="${OPTIONAL_ARGS} --gpu-memory-utilization ${GPU_MEM_UTIL}" +fi +if [ -n "${VLLM_EXTRA_ARGS:-}" ]; then + OPTIONAL_ARGS="${OPTIONAL_ARGS} ${VLLM_EXTRA_ARGS}" +fi + +# Use data-parallel or tensor-parallel based on which is set +if [ -n "${DATA_PARALLEL_SIZE:-}" ]; then + PARALLELISM_ARGS="--data-parallel-size ${DATA_PARALLEL_SIZE}" +else + PARALLELISM_ARGS="--tensor-parallel-size ${TP}" +fi + +# If BUILD_COMPILATION_CONFIG is set, construct the JSON here to avoid brace-expansion. +# NeMo Run writes sbatch env vars unquoted, so {"a":"b","c":["d"]} gets brace-expanded by bash. +# BUILD_COMPILATION_CONFIG carries just the cudagraph_mode string; we build the JSON safely. +if [ -z "${COMPILATION_CONFIG:-}" ] && [ -n "${BUILD_COMPILATION_CONFIG:-}" ]; then + COMPILATION_CONFIG="{\"cudagraph_mode\":\"${BUILD_COMPILATION_CONFIG}\",\"custom_ops\":[\"all\"]}" +fi # Start vLLM server (capture output for regression check parsing) +# Build command array so COMPILATION_CONFIG JSON is passed as a single properly-quoted arg +# (unquoted ${OPTIONAL_ARGS} expansion handles simple flags; JSON needs array quoting) VLLM_LOG=$(mktemp /tmp/vllm_server_XXXXXX.log) +VLLM_CMD=(vllm serve "${MODEL}" + --max-num-batched-tokens "${MAX_BATCHED_TOKENS:-32768}" + ${PARALLELISM_ARGS} + --port "${PORT}" + ${OPTIONAL_ARGS}) if [ -n "$SPEC_CONFIG" ]; then - vllm serve ${MODEL} \ - --speculative-config "${SPEC_CONFIG}" \ - --max-num-batched-tokens 32768 \ - --tensor-parallel-size ${TP} \ - --port ${PORT} \ - ${OPTIONAL_ARGS} \ - > >(tee -a "$VLLM_LOG") 2>&1 & -else - vllm serve ${MODEL} \ - --max-num-batched-tokens 32768 \ - --tensor-parallel-size ${TP} \ - --port ${PORT} \ - ${OPTIONAL_ARGS} \ - > >(tee -a "$VLLM_LOG") 2>&1 & + VLLM_CMD+=(--speculative-config "${SPEC_CONFIG}") +fi +if [ -n "${COMPILATION_CONFIG:-}" ]; then + VLLM_CMD+=(--compilation-config "${COMPILATION_CONFIG}") fi +"${VLLM_CMD[@]}" > >(tee -a "$VLLM_LOG") 2>&1 & SERVER_PID=$! -# Wait for server -echo "Waiting for vLLM server..." -for i in $(seq 1 180); do +# Wait for server (large models like DeepSeek V4 need up to 10 min to load + compile) +SERVER_TIMEOUT=${SERVER_TIMEOUT:-600} +echo "Waiting for vLLM server (timeout: ${SERVER_TIMEOUT}s)..." +for i in $(seq 1 ${SERVER_TIMEOUT}); do if curl -s http://localhost:${PORT}/health > /dev/null 2>&1; then echo "Server ready after ${i}s" break @@ -109,7 +575,7 @@ for i in $(seq 1 180); do done if ! curl -s http://localhost:${PORT}/health > /dev/null 2>&1; then - echo "ERROR: Server timeout"; exit 1 + echo "ERROR: Server did not become ready within ${SERVER_TIMEOUT}s"; exit 1 fi # Run quick test prompts using chat completions API @@ -138,7 +604,24 @@ for PROMPT in \ 2>/dev/null) END=$(date +%s%N) ELAPSED=$(echo "scale=2; ($END - $START) / 1000000000" | bc 2>/dev/null || echo "0") - TOKENS=$(echo "$RESULT" | python3 -c "import json,sys; r=json.load(sys.stdin); print(r.get('usage',{}).get('completion_tokens',0))" 2>/dev/null) + # Use python3 -S to skip site-packages (.pth startup files like _deepseek_v4_patch.pth + # print [force_af_v2] messages to stdout which corrupt the TOKENS variable). + TOKENS=$(echo "$RESULT" | python3 -S -c " +import json,sys +try: + r=json.load(sys.stdin) + u=r.get('usage') or {} + t=u.get('completion_tokens',0) or 0 + if not t: + msg = ((r.get('choices') or [{}])[0].get('message') or {}) + c = msg.get('content') or msg.get('reasoning_content') or '' + t = len(c.split()) if c else 0 + if not t and r.get('choices'): + t = 1 # any response with choices = success + print(t) +except Exception: + print(0) +" 2>/dev/null) if [ -n "$TOKENS" ] && [ "$TOKENS" -gt 0 ] 2>/dev/null; then TPS=$(echo "scale=1; $TOKENS / $ELAPSED" | bc 2>/dev/null || echo "?") echo " PASS: ${TOKENS} tokens in ${ELAPSED}s (${TPS} tok/s) — \"${PROMPT:0:50}...\"" diff --git a/tools/launcher/common/vllm/query.sh b/tools/launcher/common/vllm/query.sh index d1513623c34..be2764872fb 100755 --- a/tools/launcher/common/vllm/query.sh +++ b/tools/launcher/common/vllm/query.sh @@ -100,6 +100,221 @@ for arg in "$@"; do fi done +# B200: raise per-user process limit so concurrent deepgemm/NVCC JIT workers don't exhaust +# nproc when popen(nvcc) is called simultaneously across DP ranks during CUDA graph capture. +ulimit -u unlimited 2>/dev/null || true + +# B200: redirect deepgemm NVCC JIT to a writable+executable NFS path. /tmp (container tmpfs) +# is too small; /dev/shm is noexec. Use DEEPGEMM_TMPDIR (not TMPDIR) so enroot doesn't read +# it at container startup before the container starts. +if [ -n "${DEEPGEMM_TMPDIR:-}" ]; then + mkdir -p "$DEEPGEMM_TMPDIR" + export TMPDIR="$DEEPGEMM_TMPDIR" +fi + +# Copy model to /dev/shm to avoid NFS stale-handle errors when many workers mmap weights +# simultaneously during a long data synthesis run. Reuses existing copy if present. +if [ "${COPY_MODEL_TO_TMPFS:-0}" = "1" ]; then + MODEL_NAME=$(basename "$MODEL") + TMPFS_MODEL="/dev/shm/${MODEL_NAME}" + if [ -d "$TMPFS_MODEL" ] && [ -f "$TMPFS_MODEL/config.json" ]; then + echo "Using existing tmpfs model copy: $TMPFS_MODEL" + else + MODEL_SIZE=$(du -sh "$MODEL" 2>/dev/null | cut -f1 || echo "?") + AVAIL_SHM=$(df -h /dev/shm 2>/dev/null | tail -1 | awk '{print $4}' || echo "?") + echo "Copying model to /dev/shm (${MODEL_SIZE}, available: ${AVAIL_SHM})..." + cp -r "$MODEL" "$TMPFS_MODEL" + echo "Model copy done: $TMPFS_MODEL" + fi + MODEL="$TMPFS_MODEL" + echo "Loading from tmpfs: $MODEL" +fi + +# Force torch inductor to use the v2 auto_functionalized algorithm. +# vLLM explicitly sets enable_auto_functionalized_v2=False in its inductor config, +# which causes failures with fallback FP8 ops (e.g., when VLLM_USE_DEEP_GEMM=0): +# aten::as_strided() Expected a value of type 'List[int]' for argument 'stride' +# but instead found type 'list'. +# Set FORCE_AF_V2=1 to enable. Ported from common/specdec/vllm_smoke_test.sh. +if [ "${FORCE_AF_V2:-0}" = "1" ]; then + python3 << 'PYEOF' || true +import inspect, compileall, glob, re, os, site + +PATCH_MODULE_NAME = 'vllm_force_af_v2_runtime' +PATCH_CODE = r''' +import sys as _sys + +def _strip_af_v2_false(d): + if isinstance(d, dict) and d.get('enable_auto_functionalized_v2') is False: + d = {k: v for k, v in d.items() if k != 'enable_auto_functionalized_v2'} + print('[force_af_v2] Stripped enable_auto_functionalized_v2=False from inductor options', flush=True) + return d + +def _install(): + if _sys.modules.get('_vllm_af_v2_patched'): + return + _sys.modules['_vllm_af_v2_patched'] = True + + try: + import torch._dynamo as _dynamo + _orig_aot = _dynamo.aot_compile + def _patched_aot(*args, **kwargs): + if 'options' in kwargs: + kwargs['options'] = _strip_af_v2_false(kwargs['options']) + return _orig_aot(*args, **kwargs) + _dynamo.aot_compile = _patched_aot + print('[force_af_v2] Patched torch._dynamo.aot_compile', flush=True) + except Exception as e: + print(f'[force_af_v2] aot_compile patch failed: {e}', flush=True) + + try: + import torch._dynamo.aot_compile as _aot_mod + _orig_fg = _aot_mod.aot_compile_fullgraph + def _patched_fg(*args, **kwargs): + if 'options' in kwargs: + kwargs['options'] = _strip_af_v2_false(kwargs['options']) + return _orig_fg(*args, **kwargs) + _aot_mod.aot_compile_fullgraph = _patched_fg + print('[force_af_v2] Patched torch._dynamo.aot_compile_fullgraph', flush=True) + except Exception as e: + print(f'[force_af_v2] aot_compile_fullgraph patch failed: {e}', flush=True) + + try: + import torch._inductor.config as _ic + _orig_patch = _ic.patch + def _patched_patch(*args, **kwargs): + new_args = (_strip_af_v2_false(args[0]),) + args[1:] if args and isinstance(args[0], dict) else args + if kwargs.get('enable_auto_functionalized_v2') is False: + kwargs = {k: v for k, v in kwargs.items() if k != 'enable_auto_functionalized_v2'} + return _orig_patch(*new_args, **kwargs) + _ic.patch = _patched_patch + print('[force_af_v2] Patched torch._inductor.config.patch', flush=True) + except Exception as e: + print(f'[force_af_v2] config.patch intercept skipped: {e}', flush=True) + + try: + import torch._inductor.config as _ic + _ic.enable_auto_functionalized_v2 = True + print('[force_af_v2] Set torch._inductor.config.enable_auto_functionalized_v2 = True', flush=True) + except Exception as e: + print(f'[force_af_v2] inductor global config set failed: {e}', flush=True) + + try: + import torch._inductor as _ti_mod + _orig_sc = getattr(_ti_mod, 'standalone_compile', None) + if _orig_sc is not None: + def _patched_sc(fn, *args, **kwargs): + opts = kwargs.get('options') + if isinstance(opts, dict) and opts.get('enable_auto_functionalized_v2') is False: + kwargs['options'] = {k: v for k, v in opts.items() if k != 'enable_auto_functionalized_v2'} + print('[force_af_v2] Stripped enable_auto_functionalized_v2=False from standalone_compile', flush=True) + return _orig_sc(fn, *args, **kwargs) + _ti_mod.standalone_compile = _patched_sc + print('[force_af_v2] Patched torch._inductor.standalone_compile', flush=True) + except Exception as e: + print(f'[force_af_v2] standalone_compile patch failed: {e}', flush=True) + +_install() +''' + +site_dirs = site.getsitepackages() + [site.getusersitepackages()] +for sp in site_dirs: + if not os.path.isdir(sp): + continue + try: + mod_path = os.path.join(sp, f'{PATCH_MODULE_NAME}.py') + pth_path = os.path.join(sp, f'{PATCH_MODULE_NAME}.pth') + with open(mod_path, 'w') as f: + f.write(PATCH_CODE) + with open(pth_path, 'w') as f: + f.write(f'import {PATCH_MODULE_NAME}\n') + print(f'[force_af_v2] Wrote {pth_path} -> auto-loads in all worker processes') + except Exception as e: + print(f'[force_af_v2] Could not write to {sp}: {e}') + +exec(PATCH_CODE) + +vllm_dirs = [ + '/usr/local/lib/python3.12/dist-packages/vllm', + '/opt/venv/lib/python3.12/site-packages/vllm', +] +for vllm_dir in vllm_dirs: + if not os.path.isdir(vllm_dir): + continue + for py_file in glob.glob(os.path.join(vllm_dir, '**/*.py'), recursive=True): + if '__pycache__' in py_file: + continue + try: + with open(py_file) as f: + content = f.read() + if 'enable_auto_functionalized_v2' not in content: + continue + patched = re.sub( + r'("?enable_auto_functionalized_v2"?\s*[:=]\s*)False', + r'\1True', + content + ) + if '/vllm/config/compilation.py' in py_file or py_file.endswith('/compilation.py'): + patched2 = re.sub(r'\bKEY(\s*:\s*)False', r'KEY\1True', patched) + if patched2 != patched: + patched = patched2 + print(f'[force_af_v2] Patched KEY: False in {py_file}') + if patched != content: + with open(py_file, 'w') as f: + f.write(patched) + compileall.compile_file(py_file, quiet=2, force=True) + print(f'[force_af_v2] Patched source file: {py_file}') + except Exception as e: + print(f'[force_af_v2] Error processing {py_file}: {e}') + +try: + import torch._inductor.fx_passes.post_grad as pg + src_file = inspect.getfile(pg) + with open(src_file) as f: + content = f.read() + patterns = [ + ('raise AssertionError("auto_functionalized was not removed")', + 'pass # PATCHED: v1 nodes skipped (FORCE_AF_V2=1)'), + ('raise AssertionError("auto_functionalized_v2 was not removed")', + 'pass # PATCHED: v2 nodes skipped (FORCE_AF_V2=1)'), + ('if config.enable_auto_functionalized_v2:', 'if True: # PATCHED (FORCE_AF_V2=1)'), + ('if inductor_config.enable_auto_functionalized_v2:', 'if True: # PATCHED (FORCE_AF_V2=1)'), + ('GraphTransformObserver(gm, "decompose_triton_kernel_wrapper_functional").apply_graph_pass(decompose_triton_kernel_wrapper_functional)', + 'try:\n GraphTransformObserver(gm, "decompose_triton_kernel_wrapper_functional").apply_graph_pass(decompose_triton_kernel_wrapper_functional)\n except AssertionError as _af2_e:\n print(f"[force_af_v2] decompose_triton_kernel_wrapper_functional skipped: {_af2_e}", flush=True) # PATCHED'), + ] + patched = content + for old, new in patterns: + if old in patched: + patched = patched.replace(old, new) + if patched != content: + with open(src_file, 'w') as f: + f.write(patched) + compileall.compile_file(src_file, quiet=2, force=True) + print(f'[force_af_v2] Wrote and recompiled {src_file}') +except Exception as e: + print(f'[force_af_v2] post_grad.py patch failed: {e}') + +try: + import re as _re + import torch._inductor.pattern_matcher as pm + pm_file = inspect.getfile(pm) + with open(pm_file) as f: + pm_content = f.read() + pm_patched = _re.sub( + r'assert len\(graph_with_eager_vals\.graph\.nodes\) == len\(\s*\n\s*replacement\.graph\.nodes\s*\n\s*\)', + 'pass # PATCHED: skip node-count assertion for triton_kernel_wrapper_functional (FORCE_AF_V2=1)', + pm_content, + ) + if pm_patched != pm_content: + with open(pm_file, 'w') as f: + f.write(pm_patched) + compileall.compile_file(pm_file, quiet=2, force=True) + print(f'[force_af_v2] Patched pattern_matcher.py: {pm_file}') +except Exception as e: + print(f'[force_af_v2] pattern_matcher.py patch failed: {e}') +PYEOF +fi + # vLLM is single-process: GPU parallelism is handled internally via --tensor-parallel-size. # No MPI multi-rank logic needed; this script always runs as a single task. vllm serve \ diff --git a/tools/launcher/core.py b/tools/launcher/core.py index 8fd4e25ee79..cf89067c1e6 100644 --- a/tools/launcher/core.py +++ b/tools/launcher/core.py @@ -272,7 +272,8 @@ def build_slurm_executor( array=slurm_config.array, time=slurm_config.time, mem="0", - retries=0, + retries=slurm_config.retries, + additional_parameters={**(slurm_config.additional_parameters or {}), **({"requeue": True} if getattr(slurm_config, "requeue", False) else {})}, packager=packager, srun_args=slurm_config.srun_args, ) diff --git a/tools/launcher/examples/deepseek-ai/DeepSeek-V4-Flash/patch_vllm_dflash.py b/tools/launcher/examples/deepseek-ai/DeepSeek-V4-Flash/patch_vllm_dflash.py new file mode 100644 index 00000000000..641559db89b --- /dev/null +++ b/tools/launcher/examples/deepseek-ai/DeepSeek-V4-Flash/patch_vllm_dflash.py @@ -0,0 +1,509 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Patch vLLM for DeepSeek-V4-Flash + DFlash speculative decoding. + +Applies all patches needed to run DeepSeek-V4-Flash as the DFlash target model in +vLLM >= 0.1.dev15833. Idempotent — safe to run multiple times; each patch checks a +sentinel string before applying. + +Patches applied: + P0 speculative.py — add "deepseek_v4" to DFlash allowed target models + P1 deepseek_v4.py — aux hidden state collection in inner model forward loop + P2 deepseek_v4.py — return aux hidden states alongside hidden_states + P3 deepseek_v4.py — add set/get EAGLE3 interface methods to outer model + P4 gpu_model_runner.py — allow hasattr-based EAGLE3 interface check + P5 kv_cache_utils.py — leftover KV cache group for unassigned DFlash draft layers + P6 eagle.py — skip missing layers in validate_same_kv_cache_group + P7 kv_cache_utils.py — get_uniform_page_size: return min instead of asserting ==1 + P8 kv_cache_utils.py — _max_memory_usage_bytes_from_groups: handle mixed page sizes + P9 gpu_model_runner.py — _reshape_kv_cache_tensors: allow heterogeneous page sizes + P10 flash_attn.py — treat fp8_ds_mla as float8_e4m3fn in get_fp8_dtype_for_flashattn + P11 sparse_attn_indexer.py — bypass fp8_fp4_paged_mqa_logits (smem overflow on H100 w/ MLA) + +---------------------------------------------------------------------------- +Upstream PR strategy +---------------------------------------------------------------------------- +The patches split into two groups with different upstream paths: + +GROUP A — Core model support (P0, P1–P3, P4, P10): ~50 lines, PR-ready + These belong in a single vLLM PR: "Add DFlash speculative decoding support for + DeepSeek-V4 target model." + P0 One-liner: register "deepseek_v4" in the DFlash allowed-target list. + P1–P3 Add the aux hidden state interface (set/get_eagle3_aux_hidden_state_layers) + to DeepseekV4ForCausalLM — the same interface EAGLE3 requires. + P4 Make gpu_model_runner.py accept the hasattr-based interface in addition to + the formal supports_eagle3() check, so new models don't need to subclass. + P10 Fix fp8_ds_mla → float8_e4m3fn mapping in get_fp8_dtype_for_flashattn. + +GROUP B — KV cache heterogeneity (P5–P9): dissolve into proper draft architecture + These patches work around the fact that vLLM doesn't yet know that DFlash + cross-attention layers (which attend to the target's hidden states, not a separate + draft KV cache) are KV-cache-free. A proper upstream implementation would classify + those layers correctly at the KV cache spec level, eliminating the "leftover" + layers, the mixed page-size mismatch, and the reshape assert — all without needing + the five individual patches. + +P11 — Kernel fallback (sparse_attn_indexer.py): needs a kernel-level fix + The fp8_fp4_paged_mqa_logits DeepGEMM kernel exceeds H100 shared memory limits + (228 KB) when block_size=256 and MLA head_dim=576 bytes. The bypass here attends + all cached pages instead of running top-k selection — correct but suboptimal. + The right upstream fix is either: + (a) tile the DeepGEMM kernel so it fits in smem for large page sizes, or + (b) add an explicit runtime smem check in sparse_attn_indexer.py with a + documented fallback path (attend-all) and a one-time warning. + Option (b) is essentially this patch, just made explicit rather than silent. +---------------------------------------------------------------------------- +""" +import pathlib +import re +import sys + +VLLM = pathlib.Path("/usr/local/lib/python3.12/dist-packages/vllm") + + +def _delete_pyc(stem: str) -> None: + for pyc in VLLM.rglob(f"{stem}*.pyc"): + pyc.unlink(missing_ok=True) + + +def _patch_file(path: pathlib.Path, old: str, new: str, sentinel: str, label: str) -> bool: + src = path.read_text() + if sentinel in src: + print(f"{label}: already patched") + return True + if old not in src: + print(f"WARNING: {label}: pattern not found — skipping") + return False + path.write_text(src.replace(old, new, 1)) + _delete_pyc(path.stem) + print(f"{label}: OK") + return True + + +# --------------------------------------------------------------------------- +# P0: speculative.py — add deepseek_v4 to DFlash allowed target models +# --------------------------------------------------------------------------- +_spec = VLLM / "config" / "speculative.py" +_spec_src = _spec.read_text() +_p0_sentinel = '"deepseek_v4"' +if _p0_sentinel in _spec_src: + print("P0 speculative.py: already patched") +else: + _target = None + for i, line in enumerate(_spec_src.splitlines()): + if '"deepseek_v3"' in line: + _target = i + break + if _target is None: + print("WARNING: P0 speculative.py: deepseek_v3 line not found — skipping") + else: + lines = _spec_src.splitlines(keepends=True) + indent = len(lines[_target]) - len(lines[_target].lstrip()) + lines.insert(_target + 1, " " * indent + '"deepseek_v4",\n') + _spec.write_text("".join(lines)) + _delete_pyc("speculative") + print("P0 speculative.py: added deepseek_v4 after deepseek_v3 — OK") + +# --------------------------------------------------------------------------- +# P1-P3: deepseek_v4.py — EAGLE3/DFlash aux hidden state interface +# --------------------------------------------------------------------------- +_v4 = VLLM / "model_executor" / "models" / "deepseek_v4.py" +_v4_src = _v4.read_text() + +if "aux_hidden_state_layers" in _v4_src: + print("P1-P3 deepseek_v4.py: already patched") +else: + _old1 = ( + " for layer in islice(self.layers, self.start_layer, self.end_layer):\n" + " hidden_states = layer(\n" + " hidden_states,\n" + " positions,\n" + " input_ids,\n" + " )\n" + ) + _new1 = ( + " if not hasattr(self, 'aux_hidden_state_layers'):\n" + " self.aux_hidden_state_layers = ()\n" + " aux_hidden_states = []\n" + " for idx, layer in enumerate(\n" + " islice(self.layers, self.start_layer, self.end_layer),\n" + " start=self.start_layer,\n" + " ):\n" + " if idx in self.aux_hidden_state_layers:\n" + " aux_hidden_states.append(hidden_states.mean(dim=-2))\n" + " hidden_states = layer(\n" + " hidden_states,\n" + " positions,\n" + " input_ids,\n" + " )\n" + ) + _old2 = ( + " hidden_states = self.norm(hidden_states)\n" + " return hidden_states\n" + "\n" + " def load_weights(" + ) + _new2 = ( + " hidden_states = self.norm(hidden_states)\n" + " if aux_hidden_states:\n" + " return hidden_states, aux_hidden_states\n" + " return hidden_states\n" + "\n" + " def load_weights(" + ) + _eagle3_methods = ( + " def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:\n" + " self.model.aux_hidden_state_layers = layers\n" + "\n" + " def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:\n" + " num_layers = len(self.model.layers)\n" + " return (2, num_layers // 2, num_layers - 3)\n" + "\n" + ) + ok = True + if _old1 not in _v4_src: + print("WARNING: P1 deepseek_v4.py: inner loop pattern not found"); ok = False + if _old2 not in _v4_src: + print("WARNING: P2 deepseek_v4.py: return pattern not found"); ok = False + if ok: + _v4_src = _v4_src.replace(_old1, _new1, 1) + _v4_src = _v4_src.replace(_old2, _new2, 1) + # Insert methods before the first of these anchors in DeepseekV4ForCausalLM + _outer_idx = _v4_src.find("class DeepseekV4ForCausalLM(") + _outer = _v4_src[_outer_idx:] + _inserted = False + for _anchor in [" def compute_logits(", " def forward(", " def load_weights("]: + if _anchor in _outer: + _outer = _outer.replace(_anchor, _eagle3_methods + _anchor, 1) + _inserted = True + break + if not _inserted: + print("WARNING: P3 deepseek_v4.py: no anchor found for methods") + else: + _v4_src = _v4_src[:_outer_idx] + _outer + _v4.write_text(_v4_src) + _delete_pyc("deepseek_v4") + print("P1-P3 deepseek_v4.py: OK") + +# --------------------------------------------------------------------------- +# P4: gpu_model_runner.py — allow hasattr-based EAGLE3 interface +# --------------------------------------------------------------------------- +_gmr = VLLM / "v1" / "worker" / "gpu_model_runner.py" +_patch_file( + _gmr, + old=( + " if not supports_eagle3(self.get_model()):\n" + " raise RuntimeError(\n" + " \"Model does not support EAGLE3 interface but \"\n" + " \"aux_hidden_state_outputs was requested\"\n" + " )" + ), + new=( + " _m = self.get_model() # _eagle3_hasattr_patch\n" + " if not (supports_eagle3(_m) or\n" + " (hasattr(_m, 'set_aux_hidden_state_layers') and\n" + " hasattr(_m, 'get_eagle3_aux_hidden_state_layers'))):\n" + " raise RuntimeError(\n" + " \"Model does not support EAGLE3 interface but \"\n" + " \"aux_hidden_state_outputs was requested\"\n" + " )" + ), + sentinel="_eagle3_hasattr_patch", + label="P4 gpu_model_runner.py EAGLE3 check", +) + +# --------------------------------------------------------------------------- +# P5: kv_cache_utils.py — leftover KV cache group for unassigned DFlash layers +# --------------------------------------------------------------------------- +_kvu = VLLM / "v1" / "core" / "kv_cache_utils.py" +_patch_file( + _kvu, + old=( + " elif grouped_specs := group_and_unify_kv_cache_specs(kv_cache_spec):\n" + " # DeepseekV4 case: All layers need the same number of token slots,\n" + " # yet some layers are full attention while others are sliding window\n" + " # attention in different sizes. Need to group layers into multiple\n" + " # UniformTypeKVCacheSpecs.\n" + " kv_cache_groups = _get_kv_cache_groups_uniform_groups(grouped_specs)\n" + " _annotate_eagle_groups_deepseek_v4(vllm_config, kv_cache_spec, kv_cache_groups)\n" + " return kv_cache_groups" + ), + new=( + " elif grouped_specs := group_and_unify_kv_cache_specs(kv_cache_spec):\n" + " # DeepseekV4 case: All layers need the same number of token slots,\n" + " # yet some layers are full attention while others are sliding window\n" + " # attention in different sizes. Need to group layers into multiple\n" + " # UniformTypeKVCacheSpecs.\n" + " kv_cache_groups = _get_kv_cache_groups_uniform_groups(grouped_specs)\n" + " _annotate_eagle_groups_deepseek_v4(vllm_config, kv_cache_spec, kv_cache_groups)\n" + " # _dflash_leftover_patch: collect unassigned layers (e.g., Qwen3 GQA draft)\n" + " # and group them by page_size_bytes so each group is uniform.\n" + " _assigned = set(n for g in kv_cache_groups for n in g.layer_names)\n" + " _leftover = {k: v for k, v in kv_cache_spec.items() if k not in _assigned}\n" + " if _leftover:\n" + " print(f'kv_cache: creating leftover group for {len(_leftover)} unassigned layers')\n" + " from collections import defaultdict as _dd\n" + " _by_size = _dd(list)\n" + " for _ln, _sp in _leftover.items():\n" + " _by_size[_sp.page_size_bytes].append(_ln)\n" + " for _lnames in _by_size.values():\n" + " _g = {k: _leftover[k] for k in _lnames}\n" + " kv_cache_groups += create_kv_cache_group_specs(_g, [_lnames])\n" + " return kv_cache_groups" + ), + sentinel="_dflash_leftover_patch", + label="P5 kv_cache_utils.py leftover group", +) + +# --------------------------------------------------------------------------- +# P6: eagle.py — skip missing layers in validate_same_kv_cache_group +# --------------------------------------------------------------------------- +_eagle = VLLM / "v1" / "spec_decode" / "eagle.py" +_patch_file( + _eagle, + old=( + " assert (\n" + " len(\n" + " set(\n" + " [\n" + " kv_cache_groups[layer_name]\n" + " for layer_name in self._draft_attn_layer_names\n" + " ]\n" + " )\n" + " )\n" + " == 1\n" + " ), \"All drafting layers should belong to the same kv cache group\"" + ), + new=( + " # _dflash_group_patch: skip layers missing from kv_cache_groups (e.g., DFlash cross-attn)\n" + " _dgroup = set(\n" + " kv_cache_groups[n] for n in self._draft_attn_layer_names\n" + " if n in kv_cache_groups\n" + " )\n" + " assert len(_dgroup) <= 1, \"All drafting layers should belong to the same kv cache group\"" + ), + sentinel="_dflash_group_patch", + label="P6 eagle.py validate_same_kv_cache_group", +) + +# --------------------------------------------------------------------------- +# P7-P8: kv_cache_utils.py — mixed page size support +# --------------------------------------------------------------------------- +_kvu_src = _kvu.read_text() +if "_dflash_page_size_patch" in _kvu_src: + print("P7-P8 kv_cache_utils.py mixed page sizes: already patched") +else: + _changed = False + # P7: get_uniform_page_size — return min(page_sizes) instead of asserting len == 1 + _old7 = ( + " page_sizes = {layer.page_size_bytes for layer in kv_cache_specs}\n" + " assert len(page_sizes) == 1\n" + " return page_sizes.pop()" + ) + _new7 = ( + " page_sizes = {layer.page_size_bytes for layer in kv_cache_specs}\n" + " # _dflash_page_size_patch: allow mixed page sizes for DFlash heterogeneous draft\n" + " if not page_sizes:\n" + " return 0\n" + " return min(page_sizes)" + ) + if _old7 in _kvu_src: + _kvu_src = _kvu_src.replace(_old7, _new7, 1) + _changed = True + print("P7 kv_cache_utils.py get_uniform_page_size: OK") + else: + print("WARNING: P7 kv_cache_utils.py get_uniform_page_size: pattern not found") + + # P8: _max_memory_usage_bytes_from_groups — handle mixed page sizes + _old8 = ( + " # General case: group_size pools, each shared by one layer per group\n" + " # Memory = group_size * page_size * blocks_for_max_len\n" + " group_size = max(len(group.layer_names) for group in kv_cache_groups)\n" + " page_size = get_uniform_page_size(\n" + " [group.kv_cache_spec for group in kv_cache_groups]\n" + " )\n" + " blocks_needed = sum(\n" + " cdiv(group.kv_cache_spec.max_memory_usage_bytes(vllm_config), page_size)\n" + " for group in kv_cache_groups\n" + " )\n" + "\n" + " return group_size * page_size * blocks_needed" + ) + _new8 = ( + " # General case: group_size pools, each shared by one layer per group\n" + " # Memory = group_size * page_size * blocks_for_max_len\n" + " # _dflash_page_size_patch: handle mixed page sizes (DFlash heterogeneous draft)\n" + " _ps_set = set(g.kv_cache_spec.page_size_bytes for g in kv_cache_groups)\n" + " if len(_ps_set) == 1:\n" + " group_size = max(len(group.layer_names) for group in kv_cache_groups)\n" + " page_size = _ps_set.pop()\n" + " blocks_needed = sum(\n" + " cdiv(group.kv_cache_spec.max_memory_usage_bytes(vllm_config), page_size)\n" + " for group in kv_cache_groups\n" + " )\n" + " return group_size * page_size * blocks_needed\n" + " else:\n" + " # Mixed page sizes: sum per-group memory independently\n" + " return sum(\n" + " group.kv_cache_spec.max_memory_usage_bytes(vllm_config)\n" + " for group in kv_cache_groups\n" + " )" + ) + if _old8 in _kvu_src: + _kvu_src = _kvu_src.replace(_old8, _new8, 1) + _changed = True + print("P8 kv_cache_utils.py _max_memory_usage_bytes_from_groups: OK") + else: + print("WARNING: P8 kv_cache_utils.py _max_memory_usage_bytes_from_groups: pattern not found") + + if _changed: + _kvu.write_text(_kvu_src) + _delete_pyc("kv_cache_utils") + +# --------------------------------------------------------------------------- +# P9: gpu_model_runner.py — heterogeneous page sizes in _reshape_kv_cache_tensors +# --------------------------------------------------------------------------- +_patch_file( + _gmr, + old=( + " raw_tensor = kv_cache_raw_tensors[layer_name]\n" + " assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0\n" + " num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes" + ), + new=( + " raw_tensor = kv_cache_raw_tensors[layer_name]\n" + " # _dflash_reshape_patch: tolerate non-multiple sizes (heterogeneous draft)\n" + " _pg = kv_cache_spec.page_size_bytes\n" + " if raw_tensor.numel() % _pg != 0:\n" + " _nb = max(1, raw_tensor.numel() // _pg)\n" + " raw_tensor = raw_tensor[:_nb * _pg]\n" + " kv_cache_raw_tensors[layer_name] = raw_tensor\n" + " num_blocks = raw_tensor.numel() // _pg" + ), + sentinel="_dflash_reshape_patch", + label="P9 gpu_model_runner.py _reshape_kv_cache_tensors", +) + +# --------------------------------------------------------------------------- +# P10: flash_attn.py — treat fp8_ds_mla as float8_e4m3fn +# --------------------------------------------------------------------------- +_fa = VLLM / "v1" / "attention" / "backends" / "flash_attn.py" +_fa_src = _fa.read_text() +if "_dflash_fp8_ds_mla_patch" in _fa_src: + print("P10 flash_attn.py fp8_ds_mla: already patched") +else: + _fa_new = re.sub( + r"([ \t]*)raise ValueError\(f\"Unrecognized FP8 dtype: \{kv_cache_dtype\}\"\)", + lambda m: ( + m.group(1) + "# _dflash_fp8_ds_mla_patch: fp8_ds_mla is e4m3fn stored by the compressor\n" + + m.group(1) + "if kv_cache_dtype == \"fp8_ds_mla\":\n" + + m.group(1) + " import torch as _t; return _t.float8_e4m3fn\n" + + m.group(1) + "raise ValueError(f\"Unrecognized FP8 dtype: {kv_cache_dtype}\")" + ), + _fa_src, + count=1, + ) + if _fa_new != _fa_src: + _fa.write_text(_fa_new) + _delete_pyc("flash_attn") + print("P10 flash_attn.py fp8_ds_mla: OK") + else: + print("WARNING: P10 flash_attn.py: raise ValueError pattern not found — skipping") + +# --------------------------------------------------------------------------- +# P11: sparse_attn_indexer.py — bypass fp8_fp4_paged_mqa_logits (smem overflow) +# +# DeepSeek V4 Flash MLA uses block_size=256, head_dim=576 bytes/token. The +# fp8_fp4_paged_mqa_logits DeepGEMM kernel exceeds H100 shared memory limits +# (228 KB) with this configuration. Replace the logits + top_k path with a +# direct fill of topk_indices from the block_table, which attends to all cached +# pages and avoids the large intermediate logits tensor. +# --------------------------------------------------------------------------- +_sai = VLLM / "model_executor" / "layers" / "sparse_attn_indexer.py" +_patch_file( + _sai, + old=( + " logits = fp8_fp4_paged_mqa_logits(\n" + " (padded_q_quant_cast, padded_q_scale),\n" + " kv_cache,\n" + " weights[:num_padded_tokens],\n" + " seq_lens,\n" + " decode_metadata.block_table,\n" + " decode_metadata.schedule_metadata,\n" + " max_model_len=max_model_len,\n" + " clean_logits=False,\n" + " )\n" + " num_rows = logits.shape[0]\n" + " topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]\n" + "\n" + " if current_platform.is_cuda() and topk_tokens in (512, 1024, 2048):\n" + " workspace_manager = current_workspace_manager()\n" + " (topk_workspace,) = workspace_manager.get_simultaneous(\n" + " ((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8),\n" + " )\n" + " torch.ops._C.persistent_topk(\n" + " logits,\n" + " seq_lens,\n" + " topk_indices,\n" + " topk_workspace,\n" + " topk_tokens,\n" + " attn_metadata.max_seq_len,\n" + " )\n" + " else:\n" + " if current_platform.is_xpu():\n" + " ops.top_k_per_row_decode(\n" + " logits,\n" + " next_n,\n" + " seq_lens,\n" + " topk_indices,\n" + " num_rows,\n" + " logits.stride(0),\n" + " logits.stride(1),\n" + " topk_tokens,\n" + " )\n" + " else:\n" + " torch.ops._C.top_k_per_row_decode(\n" + " logits,\n" + " next_n,\n" + " seq_lens,\n" + " topk_indices,\n" + " num_rows,\n" + " logits.stride(0),\n" + " logits.stride(1),\n" + " topk_tokens,\n" + " )" + ), + new=( + " # _dflash_smem_fallback_patch: bypass fp8_fp4_paged_mqa_logits + top_k\n" + " # Directly fill topk_indices with block_table entries (attend to all pages).\n" + " # This avoids the (num_tokens x num_total_blocks) logits tensor and the\n" + " # fp8_fp4_paged_mqa_logits kernel that overflows H100 smem with MLA block_size=256.\n" + " topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]\n" + " topk_indices.fill_(-1)\n" + " _bt_flat = (\n" + " decode_metadata.block_table[:batch_size]\n" + " .unsqueeze(1)\n" + " .expand(-1, next_n, -1)\n" + " .reshape(num_padded_tokens, -1)\n" + " )\n" + " _max_bl = min(_bt_flat.shape[1], topk_tokens)\n" + " topk_indices[:, :_max_bl] = _bt_flat[:, :_max_bl]" + ), + sentinel="_dflash_smem_fallback_patch", + label="P11 sparse_attn_indexer.py smem fallback", +) + +print("All patches done!") diff --git a/tools/launcher/examples/deepseek-ai/DeepSeek-V4-Flash/vllm_dflash_smoke_test_cw_dfw.yaml b/tools/launcher/examples/deepseek-ai/DeepSeek-V4-Flash/vllm_dflash_smoke_test_cw_dfw.yaml new file mode 100644 index 00000000000..b721fc7ecbe --- /dev/null +++ b/tools/launcher/examples/deepseek-ai/DeepSeek-V4-Flash/vllm_dflash_smoke_test_cw_dfw.yaml @@ -0,0 +1,60 @@ +# vLLM DFlash smoke test for deepseek-ai/DeepSeek-V4-Flash — CW-DFW H100 variant. +# +# Launches a vLLM server with DeepSeek-V4-Flash as the target model and a +# small randomly-initialized DFlash draft model (4-layer Qwen3-based scaffold, +# created from z-lab's DFlash architecture reference) to verify the inference +# pipeline starts and generates tokens end-to-end. +# +# NOTE: the draft model uses random weights — this smoke test validates the +# DFlash inference stack (patching, KV cache allocation, speculative decoding +# loop), NOT acceptance rate or generation quality. Replace draft_model with +# a properly trained checkpoint for production use. +# +# This container (deepseekv4-cu130) does not yet natively support deepseek_v4 as a +# DFlash target. patch_vllm_dflash.py bridges the gap by patching vLLM at startup. +# +# Key config choices: +# block_size=256 — required: SWA layers use window_size=256 so the page constraint +# max(sm_page_sizes) ≤ max(all_page_sizes) forces block_size ≥ 256 +# gpu_memory_utilization=0.85 — leaves ~4 GB headroom per GPU for Triton JIT compilation +# of the DeepSeek compressor kernel on first inference +# max_num_batched_tokens=4096 — reduces profiling-phase memory pressure +# +# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2502.06036) +# +# Usage (nmm-sandbox): +# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/deepseek-ai/DeepSeek-V4-Flash/vllm_dflash_smoke_test_cw_dfw.yaml --yes +# +# Usage (ModelOpt launcher): +# uv run launch.py --yaml examples/deepseek-ai/DeepSeek-V4-Flash/vllm_dflash_smoke_test_cw_dfw.yaml --yes + +job_name: DeepSeek-V4-Flash_DFlash_smoke_cw_dfw + +pipeline: + global_vars: + hf_model: /hf-local/deepseek-ai/DeepSeek-V4-Flash + draft_model: /hf-local/z-lab/DeepSeek-V4-Flash-DFlash + + task_0: + script: common/specdec/vllm_smoke_test.sh + environment: + - HF_MODEL_CKPT: <> + - DRAFT_MODEL: <> + - SPEC_METHOD: "dflash" + - NUM_SPEC_TOKENS: "15" + - TP_SIZE: "8" + - KV_CACHE_DTYPE: "fp8" + - BLOCK_SIZE: "256" + - TRUST_REMOTE_CODE: "1" + - VLLM_USE_DEEP_GEMM: "1" + - FORCE_AF_V2: "1" + - GPU_MEM_UTIL: "0.85" + - MAX_BATCHED_TOKENS: "4096" + - COPY_MODEL_TO_TMPFS: "1" + - VLLM_PATCH_SCRIPT: "examples/deepseek-ai/DeepSeek-V4-Flash/patch_vllm_dflash.py" + slurm_config: + _factory_: "cw_dfw_slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 + container: "vllm/vllm-openai:deepseekv4-cu130" diff --git a/tools/launcher/slurm_config.py b/tools/launcher/slurm_config.py index d2a8cd48d11..2c31bfda1d2 100644 --- a/tools/launcher/slurm_config.py +++ b/tools/launcher/slurm_config.py @@ -38,6 +38,9 @@ class SlurmConfig: container_mounts: list[str] = None srun_args: list[str] = None array: str = None + retries: int = 0 + requeue: bool = False + additional_parameters: dict = None nodes: int = 1 ntasks_per_node: int = 1 gpus_per_node: int = 1 @@ -61,6 +64,8 @@ def slurm_factory( ], srun_args: list[str] = ["--no-container-mount-home"], array: str = None, # noqa: RUF013 + retries: int = 0, + requeue: bool = False, time: str = "04:00:00", ) -> SlurmConfig: """Generic Slurm factory — configure via environment variables or CLI overrides.""" @@ -76,5 +81,7 @@ def slurm_factory( container_mounts=container_mounts, srun_args=srun_args, array=array, + retries=retries, + requeue=requeue, time=time, )