Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tools/launcher/common/specdec/read_vllm_files.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/bin/bash
# Dump the torch inductor source regions that the FORCE_AF_V2 patching in
# vllm_smoke_test.sh / query.sh modifies (pattern_matcher.py and post_grad.py),
# so the expected patch sites can be checked against the installed torch build.
set -euo pipefail

# Resolve the installed torch package directory instead of hard-coding the
# python3.12 dist-packages path; fall back to the known container location
# if torch is not importable (e.g. when run outside the container).
TORCH_DIR=$(python3 -c 'import os, torch; print(os.path.dirname(torch.__file__))' 2>/dev/null \
    || echo /usr/local/lib/python3.12/dist-packages/torch)

# show_range RELATIVE_FILE START END
# Print a banner plus the requested line range, or "NOT FOUND" if the file
# is missing/unreadable (matches the original per-file fallback behavior).
show_range() {
    local file="$TORCH_DIR/$1"
    echo "=== $(basename "$file") lines $2-$3 ==="
    sed -n "$2,$3p" "$file" 2>/dev/null || echo "NOT FOUND"
}

show_range "_inductor/pattern_matcher.py" 305 325
show_range "_inductor/fx_passes/post_grad.py" 345 375
show_range "_inductor/fx_passes/post_grad.py" 1240 1260
echo "=== DONE ==="
523 changes: 503 additions & 20 deletions tools/launcher/common/specdec/vllm_smoke_test.sh

Large diffs are not rendered by default.

215 changes: 215 additions & 0 deletions tools/launcher/common/vllm/query.sh
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,221 @@ for arg in "$@"; do
fi
done

# B200: lift the per-user process cap. During CUDA graph capture every DP rank
# may popen(nvcc) for deepgemm JIT at the same time; with the default nproc
# limit those concurrent workers can exhaust the allowance. Best-effort: ignore
# failure (e.g. when the container lacks permission to raise the limit).
ulimit -u unlimited 2>/dev/null || :

# B200: point deepgemm's NVCC JIT at a writable+executable NFS scratch path.
# The container's /tmp tmpfs is too small and /dev/shm is mounted noexec.
# A dedicated DEEPGEMM_TMPDIR variable (rather than setting TMPDIR directly in
# the job environment) is used so enroot does not pick the value up while it is
# still initializing the container; we promote it to TMPDIR only here, at runtime.
if [ "${DEEPGEMM_TMPDIR:-}" != "" ]; then
    mkdir -p "$DEEPGEMM_TMPDIR"
    TMPDIR="$DEEPGEMM_TMPDIR"
    export TMPDIR
fi

# Copy the model to /dev/shm to avoid NFS stale-handle errors when many workers
# mmap the weights simultaneously during a long data synthesis run. An existing
# complete copy is reused.
#
# Fix vs. the naive `cp -r "$MODEL" "$TMPFS_MODEL"`: copying straight to the
# final path is not crash/race safe — an interrupted or concurrent copy can
# leave a partial tree in which config.json already exists, and every later run
# would then "reuse" the incomplete copy. We stage into a PID-suffixed temp dir
# and atomically rename it into place; if a concurrent process publishes first,
# we discard our staging copy and use theirs.
if [ "${COPY_MODEL_TO_TMPFS:-0}" = "1" ]; then
    MODEL_NAME=$(basename "$MODEL")
    TMPFS_MODEL="/dev/shm/${MODEL_NAME}"
    if [ -d "$TMPFS_MODEL" ] && [ -f "$TMPFS_MODEL/config.json" ]; then
        echo "Using existing tmpfs model copy: $TMPFS_MODEL"
    else
        MODEL_SIZE=$(du -sh "$MODEL" 2>/dev/null | cut -f1 || echo "?")
        AVAIL_SHM=$(df -h /dev/shm 2>/dev/null | tail -1 | awk '{print $4}' || echo "?")
        echo "Copying model to /dev/shm (${MODEL_SIZE}, available: ${AVAIL_SHM})..."
        STAGING_DIR="${TMPFS_MODEL}.staging.$$"
        rm -rf "$STAGING_DIR"
        cp -r "$MODEL" "$STAGING_DIR"
        # Atomic publish (mv -T renames the directory itself, not into it).
        if mv -T "$STAGING_DIR" "$TMPFS_MODEL" 2>/dev/null; then
            echo "Model copy done: $TMPFS_MODEL"
        else
            # Another process won the publish race; drop our staging copy.
            rm -rf "$STAGING_DIR"
            echo "Using tmpfs model copy published concurrently: $TMPFS_MODEL"
        fi
    fi
    MODEL="$TMPFS_MODEL"
    echo "Loading from tmpfs: $MODEL"
fi

# Force torch inductor to use the v2 auto_functionalized algorithm.
# vLLM explicitly sets enable_auto_functionalized_v2=False in its inductor config,
# which causes failures with fallback FP8 ops (e.g., when VLLM_USE_DEEP_GEMM=0):
# aten::as_strided() Expected a value of type 'List[int]' for argument 'stride'
# but instead found type 'list'.
# Set FORCE_AF_V2=1 to enable. Ported from common/specdec/vllm_smoke_test.sh.
#
# The patch is applied through four complementary mechanisms so it reaches both
# this process and every vLLM worker subprocess:
#   1. a patch module + .pth loader written into site-packages, auto-imported
#      by the `site` module at startup of every future Python process
#   2. exec() of the same patch code in the current process
#   3. on-disk rewrite of enable_auto_functionalized_v2=False in vLLM sources
#   4. on-disk patches to torch inductor's post_grad.py / pattern_matcher.py
#      relaxing assertions that fire once AF v2 is forced on
if [ "${FORCE_AF_V2:-0}" = "1" ]; then
# `|| true`: all of this is best-effort; never abort the launch over it.
python3 << 'PYEOF' || true
import inspect, compileall, glob, re, os, site

# Name of the module dropped into site-packages; imported in every Python
# process via the matching .pth file (see mechanism 1 below).
PATCH_MODULE_NAME = 'vllm_force_af_v2_runtime'
# NOTE: PATCH_CODE is written verbatim to disk and also exec()'d here; it must
# stay self-contained (underscore-private names, no imports beyond sys/torch).
PATCH_CODE = r'''
import sys as _sys

def _strip_af_v2_false(d):
    if isinstance(d, dict) and d.get('enable_auto_functionalized_v2') is False:
        d = {k: v for k, v in d.items() if k != 'enable_auto_functionalized_v2'}
        print('[force_af_v2] Stripped enable_auto_functionalized_v2=False from inductor options', flush=True)
    return d

def _install():
    if _sys.modules.get('_vllm_af_v2_patched'):
        return
    _sys.modules['_vllm_af_v2_patched'] = True

    try:
        import torch._dynamo as _dynamo
        _orig_aot = _dynamo.aot_compile
        def _patched_aot(*args, **kwargs):
            if 'options' in kwargs:
                kwargs['options'] = _strip_af_v2_false(kwargs['options'])
            return _orig_aot(*args, **kwargs)
        _dynamo.aot_compile = _patched_aot
        print('[force_af_v2] Patched torch._dynamo.aot_compile', flush=True)
    except Exception as e:
        print(f'[force_af_v2] aot_compile patch failed: {e}', flush=True)

    try:
        import torch._dynamo.aot_compile as _aot_mod
        _orig_fg = _aot_mod.aot_compile_fullgraph
        def _patched_fg(*args, **kwargs):
            if 'options' in kwargs:
                kwargs['options'] = _strip_af_v2_false(kwargs['options'])
            return _orig_fg(*args, **kwargs)
        _aot_mod.aot_compile_fullgraph = _patched_fg
        print('[force_af_v2] Patched torch._dynamo.aot_compile_fullgraph', flush=True)
    except Exception as e:
        print(f'[force_af_v2] aot_compile_fullgraph patch failed: {e}', flush=True)

    try:
        import torch._inductor.config as _ic
        _orig_patch = _ic.patch
        def _patched_patch(*args, **kwargs):
            new_args = (_strip_af_v2_false(args[0]),) + args[1:] if args and isinstance(args[0], dict) else args
            if kwargs.get('enable_auto_functionalized_v2') is False:
                kwargs = {k: v for k, v in kwargs.items() if k != 'enable_auto_functionalized_v2'}
            return _orig_patch(*new_args, **kwargs)
        _ic.patch = _patched_patch
        print('[force_af_v2] Patched torch._inductor.config.patch', flush=True)
    except Exception as e:
        print(f'[force_af_v2] config.patch intercept skipped: {e}', flush=True)

    try:
        import torch._inductor.config as _ic
        _ic.enable_auto_functionalized_v2 = True
        print('[force_af_v2] Set torch._inductor.config.enable_auto_functionalized_v2 = True', flush=True)
    except Exception as e:
        print(f'[force_af_v2] inductor global config set failed: {e}', flush=True)

    try:
        import torch._inductor as _ti_mod
        _orig_sc = getattr(_ti_mod, 'standalone_compile', None)
        if _orig_sc is not None:
            def _patched_sc(fn, *args, **kwargs):
                opts = kwargs.get('options')
                if isinstance(opts, dict) and opts.get('enable_auto_functionalized_v2') is False:
                    kwargs['options'] = {k: v for k, v in opts.items() if k != 'enable_auto_functionalized_v2'}
                    print('[force_af_v2] Stripped enable_auto_functionalized_v2=False from standalone_compile', flush=True)
                return _orig_sc(fn, *args, **kwargs)
            _ti_mod.standalone_compile = _patched_sc
            print('[force_af_v2] Patched torch._inductor.standalone_compile', flush=True)
    except Exception as e:
        print(f'[force_af_v2] standalone_compile patch failed: {e}', flush=True)

_install()
'''

# Mechanism 1: write the patch module plus a .pth loader into every writable
# site-packages dir, so every future interpreter (vLLM workers) self-patches.
site_dirs = site.getsitepackages() + [site.getusersitepackages()]
for sp in site_dirs:
    if not os.path.isdir(sp):
        continue
    try:
        mod_path = os.path.join(sp, f'{PATCH_MODULE_NAME}.py')
        pth_path = os.path.join(sp, f'{PATCH_MODULE_NAME}.pth')
        with open(mod_path, 'w') as f:
            f.write(PATCH_CODE)
        with open(pth_path, 'w') as f:
            f.write(f'import {PATCH_MODULE_NAME}\n')
        print(f'[force_af_v2] Wrote {pth_path} -> auto-loads in all worker processes')
    except Exception as e:
        print(f'[force_af_v2] Could not write to {sp}: {e}')

# Mechanism 2: apply the same runtime patch to this process right away.
exec(PATCH_CODE)

# Mechanism 3: rewrite enable_auto_functionalized_v2=False -> True in the
# installed vLLM sources (both known install prefixes), then force recompile
# so stale __pycache__ .pyc files cannot shadow the edit.
vllm_dirs = [
    '/usr/local/lib/python3.12/dist-packages/vllm',
    '/opt/venv/lib/python3.12/site-packages/vllm',
]
for vllm_dir in vllm_dirs:
    if not os.path.isdir(vllm_dir):
        continue
    for py_file in glob.glob(os.path.join(vllm_dir, '**/*.py'), recursive=True):
        if '__pycache__' in py_file:
            continue
        try:
            with open(py_file) as f:
                content = f.read()
            if 'enable_auto_functionalized_v2' not in content:
                continue
            # Matches both dict-style ("...": False) and assignment (= False).
            patched = re.sub(
                r'("?enable_auto_functionalized_v2"?\s*[:=]\s*)False',
                r'\1True',
                content
            )
            # compilation.py also carries the flag under a KEY constant.
            if '/vllm/config/compilation.py' in py_file or py_file.endswith('/compilation.py'):
                patched2 = re.sub(r'\bKEY(\s*:\s*)False', r'KEY\1True', patched)
                if patched2 != patched:
                    patched = patched2
                    print(f'[force_af_v2] Patched KEY: False in {py_file}')
            if patched != content:
                with open(py_file, 'w') as f:
                    f.write(patched)
                compileall.compile_file(py_file, quiet=2, force=True)
                print(f'[force_af_v2] Patched source file: {py_file}')
        except Exception as e:
            print(f'[force_af_v2] Error processing {py_file}: {e}')

# Mechanism 4a: relax post_grad.py assertions that fire once v2 is forced,
# and unconditionally take the v2 decomposition branch.
try:
    import torch._inductor.fx_passes.post_grad as pg
    src_file = inspect.getfile(pg)
    with open(src_file) as f:
        content = f.read()
    # NOTE(review): the inline indentation inside the try/except replacement
    # below must match the call site's indentation in the installed torch's
    # post_grad.py — verify with read_vllm_files.sh against the actual build.
    patterns = [
        ('raise AssertionError("auto_functionalized was not removed")',
         'pass # PATCHED: v1 nodes skipped (FORCE_AF_V2=1)'),
        ('raise AssertionError("auto_functionalized_v2 was not removed")',
         'pass # PATCHED: v2 nodes skipped (FORCE_AF_V2=1)'),
        ('if config.enable_auto_functionalized_v2:', 'if True: # PATCHED (FORCE_AF_V2=1)'),
        ('if inductor_config.enable_auto_functionalized_v2:', 'if True: # PATCHED (FORCE_AF_V2=1)'),
        ('GraphTransformObserver(gm, "decompose_triton_kernel_wrapper_functional").apply_graph_pass(decompose_triton_kernel_wrapper_functional)',
         'try:\n        GraphTransformObserver(gm, "decompose_triton_kernel_wrapper_functional").apply_graph_pass(decompose_triton_kernel_wrapper_functional)\n    except AssertionError as _af2_e:\n        print(f"[force_af_v2] decompose_triton_kernel_wrapper_functional skipped: {_af2_e}", flush=True) # PATCHED'),
    ]
    patched = content
    for old, new in patterns:
        if old in patched:
            patched = patched.replace(old, new)
    if patched != content:
        with open(src_file, 'w') as f:
            f.write(patched)
        compileall.compile_file(src_file, quiet=2, force=True)
        print(f'[force_af_v2] Wrote and recompiled {src_file}')
except Exception as e:
    print(f'[force_af_v2] post_grad.py patch failed: {e}')

# Mechanism 4b: drop pattern_matcher.py's node-count assertion, which trips on
# triton_kernel_wrapper_functional replacement graphs under forced AF v2.
try:
    import re as _re
    import torch._inductor.pattern_matcher as pm
    pm_file = inspect.getfile(pm)
    with open(pm_file) as f:
        pm_content = f.read()
    pm_patched = _re.sub(
        r'assert len\(graph_with_eager_vals\.graph\.nodes\) == len\(\s*\n\s*replacement\.graph\.nodes\s*\n\s*\)',
        'pass # PATCHED: skip node-count assertion for triton_kernel_wrapper_functional (FORCE_AF_V2=1)',
        pm_content,
    )
    if pm_patched != pm_content:
        with open(pm_file, 'w') as f:
            f.write(pm_patched)
        compileall.compile_file(pm_file, quiet=2, force=True)
        print(f'[force_af_v2] Patched pattern_matcher.py: {pm_file}')
except Exception as e:
    print(f'[force_af_v2] pattern_matcher.py patch failed: {e}')
PYEOF
fi

# vLLM is single-process: GPU parallelism is handled internally via --tensor-parallel-size.
# No MPI multi-rank logic needed; this script always runs as a single task.
vllm serve \
Expand Down
3 changes: 2 additions & 1 deletion tools/launcher/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ def build_slurm_executor(
array=slurm_config.array,
time=slurm_config.time,
mem="0",
retries=0,
retries=slurm_config.retries,
additional_parameters={**(slurm_config.additional_parameters or {}), **({"requeue": True} if getattr(slurm_config, "requeue", False) else {})},
packager=packager,
srun_args=slurm_config.srun_args,
)
Expand Down
Loading
Loading