diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 952144a8..00000000 --- a/.gitmodules +++ /dev/null @@ -1,4 +0,0 @@ -[submodule "third_party/TransformerEngine"] - path = third_party/TransformerEngine - url = https://github.com/ROCm/TransformerEngine.git - branch = dev diff --git a/amd/.gitignore b/amd/.gitignore deleted file mode 100644 index 39ea67da..00000000 --- a/amd/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -assets/ -runs/ -logs/ -*.log diff --git a/amd/README.md b/amd/README.md deleted file mode 100644 index 23100ea4..00000000 --- a/amd/README.md +++ /dev/null @@ -1,14 +0,0 @@ -# AMD ROCm Validation - -This directory tracks the AMD/ROCm validation work for Relax. - -## Current Target - -- Task: Qwen3.5-9B DAPO-Math -- Hardware: AMD Instinct MI355X, single node, 8 GPUs -- Launch script: `scripts/training/text/run-qwen35-9B-8xgpu-async.sh` -- Mode: fully async - -## Files - -- `qwen35-9b-dapo-math.md`: runbook and experiment notes for the current validation. diff --git a/amd/qwen3-4b-dapo-math.md b/amd/qwen3-4b-dapo-math.md deleted file mode 100644 index 9ab0f925..00000000 --- a/amd/qwen3-4b-dapo-math.md +++ /dev/null @@ -1,258 +0,0 @@ -# Qwen3-4B DAPO-Math on AMD ROCm - -## Goal - -Validate a small dense Qwen3 model on AMD ROCm before returning to Qwen3.5 backend enablement. - -## Choice - -- Model: `Qwen/Qwen3-4B` -- Task: DAPO-Math -- Mode: fully async -- GPUs: 4 -- Base recipe: `scripts/training/text/run-qwen3-4B-4xgpu-async.sh` - -This is preferred over Qwen3.5 for the first AMD smoke because it is dense and does not require Qwen3.5 GatedDeltaNet / experimental attention support. - -## Assets - -Assets are stored under `/data/models/minimax/Relax/amd/assets/exps`: - -```text -Qwen3-4B/ -dapo-math-17k/dapo-math-17k.jsonl -aime-2024/aime-2024.jsonl -``` - -## Runner - -Use: - -```bash -/data/models/minimax/Relax/amd/run-qwen3-4b-dapo-math-direct.sh -``` - -The runner connects directly to a Ray cluster through `RAY_ADDRESS`, avoiding Ray Jobs API issues observed during the Qwen3.5 run. - -## Experiment Log - -### 2026-04-26 - -- Downloaded `Qwen/Qwen3-4B` to `amd/assets/exps/Qwen3-4B`. -- Verified config and tokenizer load as `Qwen3Config`. -- Added `amd/run-qwen3-4b-dapo-math-direct.sh`. -- Patched `relax.utils.utils.get_serve_url()` to respect `RELAX_SERVE_PORT` instead of hard-coding `8000`. -- Started an independent 4-GPU Ray cluster: - - GCS: `10.235.26.199:6380` - - Dashboard/job server: `http://10.235.26.199:8266` - - Dashboard agent HTTP/gRPC: `8267` / `8268` - - Worker ports: `30000-65000` - - Visible GPUs: `0,1,2,3` -- Verified Ray reports `4.0 GPU` and `ray job list` is empty. -- Launched `amd/run-qwen3-4b-dapo-math-direct.sh`. -- The run progressed through: - - Ray initialization - - Ray Serve startup on `RELAX_SERVE_PORT=18081` - - DCSCoordinator deployment - - 4-GPU resource validation - - placement groups for actor, rollout, reference, actor_fwd - - streaming index build for `dapo-math-17k` (`17398` lines) -- The run failed during rollout service deployment: - - `ModuleNotFoundError: No module named 'sgl_kernel'` - - The failure happens while SGLang imports its quantization/MoE runner modules during `ServerArgs` / model config initialization. -- Cleanup: - - Stopped the independent 4-GPU Ray cluster on `10.235.26.199:6380`. - - Existing 2-GPU Ray cluster on `127.0.0.1:6379` is still running. - - GPU utilization returned to 0%. - -## Current Blocker - -Qwen3-4B avoids the Qwen3.5 Megatron experimental-attention blocker. The current blocker is the missing SGLang runtime package `sgl_kernel`. - -`pip index versions sgl-kernel` shows the package is available, latest `0.3.21`, but it is not installed in this ROCm image. - -## Dockerfile Update - -`docker/Dockerfile.rocm` has been updated to follow SGLang's ROCm build flow for the minimum required kernel package: - -1. Build `sgl-kernel` from the checked-out SGLang source. -2. Use `sgl-kernel/pyproject_rocm.toml`. -3. Run `AMDGPU_TARGET= python setup_rocm.py install`. -4. Install SGLang Python from `python[srt_hip]` after the kernel build. - -This should be rebuilt into the image before rerunning Qwen3-4B. - -## Current Container Hotfix - -Applied the same minimal SGLang ROCm kernel flow directly in the current container: - -```bash -cd /sgl-workspace/sglang -python -m pip install --upgrade pip setuptools wheel setuptools_scm scikit-build-core pybind11 -cd sgl-kernel -rm -f pyproject.toml -cp pyproject_rocm.toml pyproject.toml -AMDGPU_TARGET=gfx950 python setup_rocm.py install -cd .. -cp python/pyproject_other.toml python/pyproject.toml -pip install -e "python[srt_hip]" --no-build-isolation -``` - -Verified: - -- `sgl-kernel`: `0.3.21.post1` -- `sglang`: `0.5.6.post3.dev2790+g476d371a4` -- `sgl_kernel` imports from `/opt/venv/lib/python3.10/site-packages/sgl_kernel/__init__.py` -- SGLang ROCm runtime env vars are exported from `/root/.bashrc` -- Reran Qwen3-4B after installing `sgl-kernel`; this passed the previous SGLang import failure. -- New failure: - - `ValueError: apply_rope_fusion is not available. Please install TE >= 1.4.` - - Cause: Megatron Bridge provider enabled RoPE fusion, but TransformerEngine is unavailable in this ROCm image. - - Fix: added `--no-rope-fusion` to `amd/run-qwen3-4b-dapo-math-direct.sh`. -- Reran with `--no-rope-fusion`; RoPE fusion error passed. -- New failure: - - `NameError: name 'TESpecProvider' is not defined` - - Cause: Bridge still selected TransformerEngine layer spec while TE is unavailable. - - Fix: added `--transformer-impl local` to the runner. -- Reran with `--transformer-impl local`; local spec was selected successfully. -- New failures: - - Actor side: `TypeError: '>=' not supported between instances of 'NoneType' and 'Version'` from Megatron optimizer TE version check. - - Rollout side: `ModuleNotFoundError: No module named 'aiter'`. - - Fixes: - - Removed `--use-precision-aware-optimizer` from the runner. - - Install AITER in the current container following SGLang ROCm Dockerfile. -- Installed AITER in the current container: - - Source: `https://github.com/ROCm/aiter.git` - - Commit/tag: `v0.1.11.post1` - - Build env: `PREBUILD_KERNELS=1 GPU_ARCHS=gfx950` - - Verified `import aiter` from `/sgl-workspace/aiter/aiter/__init__.py` -- Reran after AITER install; AITER import worked. -- New validation failure: - - `--optimizer-cpu-offload` requires `--use-precision-aware-optimizer` - - But `--use-precision-aware-optimizer` is incompatible with this no-TE environment. - - Fix: removed `--optimizer-cpu-offload` and `--overlap-cpu-optimizer-d2h-h2d` from the 4B smoke runner. -- Reran after removing optimizer CPU offload: - - SGLang rollout service deployed successfully. - - SGLang `/health`, `/server_info`, and `/model_info` responded. - - DAPO streaming dataset and rollout manager initialized. -- Current blocker: - - Actor/reference/actor_fwd fail while loading HF weights through Megatron Bridge. - - Error: `AttributeError: 'NoneType' object has no attribute 'megatron_module'`. - - Preceded by repeated warnings like `No mapping found for megatron_param: decoder.layers.*.pre_mlp_layernorm.weight`. - - This indicates a Megatron Bridge conversion mapping issue for Qwen3-4B under the no-TE local spec path. -- Cleanup: - - Stopped the independent 4-GPU Ray cluster. - -## ROCm TransformerEngine Submodule - -Added ROCm TransformerEngine as a submodule: - -```text -third_party/TransformerEngine -> https://github.com/ROCm/TransformerEngine.git (branch: dev) -``` - -ROCm TE docs describe two install paths. - -Wheel install for ROCm 7.2: - -```bash -wget -r -l1 -nd -A 'transformer_engine*' https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/ -pip install ./transformer_engine* --no-build-isolation -``` - -Source install from the submodule: - -```bash -cd third_party/TransformerEngine -git submodule update --init --recursive -export NVTE_FRAMEWORK=pytorch -export NVTE_ROCM_ARCH=gfx950 -export NVTE_USE_ROCM=1 -pip install --no-build-isolation . -``` - -If the HIP compiler cannot detect the platform, also export: - -```bash -export HIP_PLATFORM=amd -``` - -## Current Container TransformerEngine Install - -Installed ROCm TE wheels for ROCm 7.2 into the current container: - -```bash -mkdir -p /tmp/te-rocm72 -cd /tmp/te-rocm72 -wget -r -l1 -nd -A 'transformer_engine*' https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/ -pip install ./transformer_engine-2.4.0-py3-none-any.whl \ - ./transformer_engine_rocm-2.4.0-py3-none-manylinux_2_28_x86_64.whl \ - ./transformer_engine_torch-2.4.0.tar.gz \ - --no-build-isolation -``` - -Verified: - -- `transformer_engine`: `2.4.0` -- `transformer_engine_torch`: `2.4.0` -- `transformer_engine.pytorch.LayerNormLinear`: available -- `transformer_engine.pytorch.RMSNorm`: available -- `transformer_engine.pytorch.DotProductAttention`: available - -After TE install, removed the no-TE runner flags: - -- `--no-rope-fusion` -- `--transformer-impl local` -- Reran Qwen3-4B with TE installed: - - Megatron Bridge mapping issue passed. - - `actor`, `actor_fwd`, `reference`, `rollout`, and `advantages` all registered successfully. - - Rollout completed and transferred data. - - Actor-to-rollout weight update completed. - - Training reached actor/reference log-prob and actor train step. -- New blocker: - - `ValueError: No dot product attention backend is available for the provided inputs.` - - Raised from `transformer_engine.pytorch.attention.dot_product_attention.DotProductAttention.forward`. - - Reference service became unhealthy and triggered global restart. - - Next debugging direction: run with `NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2` to see why TE disables all attention backends, or switch training attention backend away from TE fused attention for this smoke. -- NVTE debug / static backend selection showed the likely backend issue: - - Current training args use `qkv_format=thd`. - - TE ROCm selector reports `thd_thd_thd + causal` has no backend. - - `thd` disables `UnfusedDotProductAttention`. - - `sbhd` / `bshd` layouts have available fused or unfused backends. - - Fix under test: set `--qkv-format bshd` in the 4B smoke runner. -- Reran with `--qkv-format bshd`: - - The TE dot-product attention backend error did not recur during initial log-prob/training. - - All 5 services registered successfully. - - Rollout generation completed the first 16 samples and transferred rollout batches. - - Actor training entered `MegatronTrainRayActor.train_async`. - - Run was still active at the time of this note. -- Adjusted `relax/backends/megatron/arguments.py` so `qkv_format=bshd` keeps `variable_seq_lengths=False`. -- Reran `bshd + variable_seq_lengths=False`: - - CLI args confirmed `qkv_format=bshd` and `variable_seq_lengths=False`. - - This run ended early with Ray/GCS disconnect (`Failed to connect to GCS within 60 seconds`) before producing a useful training-side result. - - GPU utilization returned to 0% after cleanup. -- Added temporary TransformerEngine attention diagnostics in the current container: - - `dot_product_attention/utils.py` now warns when backend selection ends in `NoBackend`, including `run_config` and backend candidates. - - `dot_product_attention/dot_product_attention.py` now logs q/k/v tensor metadata, qkv layout, mask type, sequence lengths, and selected backend flags before raising the `No dot product attention backend` error. - - The direct runner now defaults `NVTE_DEBUG=1`, `NVTE_DEBUG_LEVEL=2`, and `RAY_DEDUP_LOGS=0` so TE and Ray preserve the backend rejection details in logs. -- Added `amd/run_qwen3-4b.sh` as the one-command smoke runner: - - Kills stale Relax, Ray, SGLang, Megatron worker, and Ray dashboard/worker processes at startup. - - Recreates `/tmp/ray-qwen3-4b` and starts a fresh local Ray head on `10.235.26.199:6380`. - - Exports all Qwen3-4B DAPO-Math, ROCm TE debug, Ray, and Serve settings before invoking `amd/run-qwen3-4b-dapo-math-direct.sh`. - - Intended usage: edit variables inside the script, then run `bash amd/run_qwen3-4b.sh`. -- TE backend root cause from `qwen3-4b-dapo-math-te-debug-20260427-153023`: - - The launch used `--attention-backend flash`, and Megatron set `NVTE_FUSED_ATTN=0` plus `NVTE_UNFUSED_ATTN=0`. - - ROCm TE reported `flash_attn_version='not installed'`, so FlashAttention was unavailable. - - With fused/unfused forced off and flash missing, backend selection ended as `available_backends=[False, 0, 0]`. - - Changed the smoke runner to `--attention-backend auto` so TE can fall back to fused or unfused on ROCm. -- Reran with `--attention-backend auto`: - - TE selected `FusedAttention backend (sub-backend 1)`. - - Step 0 and step 1 logprob/training passed the previous attention blocker. - - New blocker at step 2: TorchDynamo fake tensor failure in Megatron `fused_cross_entropy.py` (`torch.split(..., SymInt)`). - - Added `TORCHDYNAMO_DISABLE=1`, `--disable-jit-fuser`, and `--train-env-vars '{"TORCHDYNAMO_DISABLE": "1"}'` to avoid the Dynamo compiled fused CE path for ROCm smoke. -- Reran with Dynamo/JIT fuser disabled: - - All 5 services registered successfully. - - TE selected `FusedAttention backend (sub-backend 1)`. - - Rollout, reference logprob, actor_fwd logprob, advantages, and actor training completed all 4 smoke steps. - - Checkpoint saved successfully at iteration 3 under `Qwen3-4B_mcore_4xgpu/`. - - Process exited with code 0. diff --git a/amd/qwen35-9b-dapo-math.md b/amd/qwen35-9b-dapo-math.md deleted file mode 100644 index 93c83564..00000000 --- a/amd/qwen35-9b-dapo-math.md +++ /dev/null @@ -1,195 +0,0 @@ -# Qwen3.5-9B DAPO-Math on AMD ROCm - -## Goal - -Validate that Relax can run Qwen3.5-9B DAPO-Math on AMD Instinct MI355X with the ROCm image. - -## Initial Plan - -1. Use the existing ROCm container; a single-node 8-GPU run does not need multiple Docker containers. -2. Activate the image environment from `/root/.bashrc`. -3. Verify ROCm, PyTorch, Ray, and Relax import paths. -4. Verify model and dataset assets. -5. Launch `scripts/training/text/run-qwen35-9B-8xgpu-async.sh`. -6. Record failures and fixes here as the source of truth for this validation. - -## Expected Assets - -The launch script expects `MODEL_DIR` to contain: - -```text -Qwen3.5-9B/ -dapo-math-17k/dapo-math-17k.jsonl -aime-2024/aime-2024.jsonl -Qwen3-9B_mcore_8xgpu/ # created or reused for checkpoints -``` - -## Environment Notes - -- ROCm GPU visibility should be checked with both `rocm-smi` and PyTorch. -- On this image, AMD devices are exposed through PyTorch's `torch.cuda` API. -- If non-interactive shells skip `/root/.bashrc` setup, run commands with `PS1` set or use `/opt/venv/bin/*` directly. - -## Experiment Log - -### 2026-04-26 - -- Created this AMD validation directory. -- Environment check passed: - - Python: `/opt/venv/bin/python` - - Ray: `/opt/venv/bin/ray`, version 2.55.1 - - PyTorch: 2.9.1 ROCm 7.2 - - GPUs: 8 x AMD Instinct MI355X visible through `torch.cuda` - - Relax import path: `/data/models/minimax/Relax/relax/__init__.py` -- Downloaded assets under `/data/models/minimax/Relax/amd/assets/exps`: - - `Qwen/Qwen3.5-9B` -> `Qwen3.5-9B/` - - `zhuzilin/dapo-math-17k` -> `dapo-math-17k/` - - `zhuzilin/aime-2024` -> `aime-2024/` -- Processed AIME in place with `scripts/tools/process_aime.py`. -- Asset check passed for model config/tokenizer and both dataset JSONL files. -- Launch run directory: `/data/models/minimax/Relax/amd/runs/qwen35-9b-dapo-math-20260426-153106` -- Launch command environment: - - `MODEL_DIR=/data/models/minimax/Relax/amd/assets/exps` - - `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7` - - `NUM_GPUS=8` - - `HOST_IP=127.0.0.1` - - `RAY_NO_WAIT=1` - - `RAY_TMPDIR=/ray_tmp` -- First launch failed before job submission: - - Failure: Ray plasma store Unix socket path exceeded the 107-byte AF_UNIX limit. - - Cause: `RAY_TMPDIR` under the deep `amd/runs/...` path made Ray's session socket path too long. - - Fix: keep logs in `amd/runs`, but use a short Ray temp path such as `/tmp/ray-q35-153106`. -- Second launch also failed before job submission: - - Failure: Ray GCS could not bind port `6379`. - - Evidence: Ray `gcs_server.err` reports `Address already in use` for `0.0.0.0:6379`. - - Container tools did not expose the owning PID, so this appears to be occupied outside this run. - - Fix: manually start Ray on a non-default GCS port (`6380`) and submit through `scripts/entrypoint/ray-job.sh`. -- Found an existing Ray cluster on `6379/8265`, but it only exposes 2 GPUs (`CUDA_VISIBLE_DEVICES=1,6`), so it cannot run this 8-GPU recipe. -- Started an independent Ray head: - - GCS: `10.235.26.199:6380` - - Dashboard/job server: `http://10.235.26.199:8266` - - Resources: 8 GPUs -- Patched the Qwen3.5-9B launch script so the Ray job server port can be overridden with `RAY_DASHBOARD_PORT`. -- First submission to the independent Ray dashboard failed because the dashboard agent tried to bind port `52365`, which was already used by the existing 2-GPU Ray cluster. -- Restarted the independent Ray head with explicit non-conflicting ports: - - GCS: `6380` - - dashboard: `8266` - - dashboard agent HTTP: `8267` - - dashboard agent gRPC: `8268` - - node manager: `6381` - - object manager: `6382` - - metrics: `6383` - - worker ports: `20000-20199` -- Verified `ray job list --address=http://10.235.26.199:8266` works and reports no jobs. -- Resubmitting through Ray Jobs still hung before a job appeared in `ray job list`. -- Stopped the stuck `ray job submit` process; no training job had been created. -- Added `amd/run-qwen35-9b-dapo-math-direct.sh` to bypass Ray Jobs and run the Relax driver directly against `RAY_ADDRESS=10.235.26.199:6380`. -- First direct driver run reached Relax/Megatron argument parsing, then failed because `CUDA_DEVICE_MAX_CONNECTIONS=1` was missing in the direct runner environment. -- Added `CUDA_DEVICE_MAX_CONNECTIONS=1` to the direct runner. -- Second direct driver run passed argument validation and connected to Ray, then failed during worker registration: - - Failure: `No available ports. Please specify a wider port range using --min-worker-port and --max-worker-port.` - - Cause: the independent Ray head was started with a too-narrow worker range (`20000-20199`). - - Fix: restart the independent Ray head with a wider worker range (`30000-65000`). -- Restarted Ray with the wider worker range and relaunched the direct driver. -- Third direct run progressed through: - - Ray initialization - - Ray Serve startup - - DCSCoordinator and MetricsService deployment - - 8-GPU resource validation - - placement group allocation for actor/rollout/reference/actor_fwd -- Third direct run failed during service deployment: - - Rollout/SGLang failure: `ModuleNotFoundError: No module named 'sgl_kernel'`. - - Actor/Megatron failure: `NotImplementedError: Experimental attention variant is not supported with local spec yet.` -- Megatron expert conclusion: - - Qwen3.5-9B uses GatedDeltaNet / experimental attention (`gated_delta_net`) plus attention output gate. - - Current ROCm image does not provide `transformer_engine`, `fla`, `causal_conv1d`, or `sgl_kernel`. - - Megatron local spec explicitly does not support `experimental_attention_variant`, so this is not a config-only failure. - - Removing `--use-gated-attention` or `--attention-output-gate` would invalidate the Qwen3.5 architecture and checkpoint shapes. - - Running Qwen3.5 on ROCm requires real backend enablement: ROCm-compatible TE/FLA/GDN support or local Megatron GDN implementation, plus SGLang ROCm kernel dependency handling. -- Cleanup: - - Stopped the independent 8-GPU Ray cluster on `10.235.26.199:6380`. - - The pre-existing 2-GPU Ray cluster on `127.0.0.1:6379` is still running. - - GPU utilization returned to 0%. - -## Current Conclusion - -Qwen3.5-9B DAPO-Math cannot currently be run to training on this ROCm image with only launch-script changes. The run reached Ray/Serve placement and model service initialization, then failed on missing Qwen3.5 backend support in Megatron/SGLang. - -For AMD validation now, use a supported Qwen3/Qwen3-MoE recipe first. For Qwen3.5 specifically, the next work item is model-backend integration rather than launch orchestration. - -## Follow-up Runner - -Added `amd/run_qwen35-9b.sh` as the one-command Qwen3.5-9B smoke runner: - -- Cleans stale Relax, Ray, SGLang, Megatron worker, and Ray dashboard/worker processes at startup. -- Recreates `/tmp/ray-qwen35-9b` and starts a fresh local Ray head on `10.235.26.199:6380` with 8 GPUs. -- Exports the ROCm TE debug settings validated in the Qwen3-4B run. -- Aligns the direct runner with the Qwen3-4B ROCm fixes: - - `--attention-backend auto` - - `TORCHDYNAMO_DISABLE=1` - - `--disable-jit-fuser` - - `--train-env-vars '{"TORCHDYNAMO_DISABLE": "1"}'` - -Intended usage: - -```bash -bash amd/run_qwen35-9b.sh -``` - -### 2026-04-27 Follow-up - -- Ran `bash amd/run_qwen35-9b.sh`. -- The wrapper correctly restarted Ray with 8 GPUs and launched the direct runner with: - - `attention_backend=auto` - - `disable_jit_fuser=True` - - `train_env_vars={'TORCHDYNAMO_DISABLE': '1'}` -- The run passed Ray/Serve startup and began deploying actor, rollout, reference, actor_fwd, and advantages services. -- New root blocker: - - `ImportError: FLA is not installed. Please install it with pip install flash-linear-attention.` - - Raised while instantiating Megatron `GatedDeltaNet`, then `TransformerLayer`. - - This confirms Qwen3.5-9B requires FLA/GatedDeltaNet backend support before training can proceed. -- Stopped the run after capturing the error to avoid Serve restart loops. -- Checked SGLang's ROCm Dockerfile: - - It installs AITER, SGLang ROCm extras, `sgl-kernel`, TileLang, and related ROCm serving dependencies. - - It does not explicitly install `flash-linear-attention`. - - The missing FLA error is from Megatron's `GatedDeltaNet`, not directly from SGLang. -- Installed `flash-linear-attention==0.5.0` in the current container and verified Megatron's required imports: - - `fla.modules.convolution.causal_conv1d` - - `fla.modules.l2norm.l2norm` - - `fla.ops.gated_delta_rule.chunk_gated_delta_rule` -- Updated `docker/Dockerfile.rocm` to install ROCm TransformerEngine 2.4.0 wheels and `flash-linear-attention==0.5.0` for future images. -- Reran after installing FLA: - - Megatron `GatedDeltaNet` initialization passed the previous FLA import blocker. - - Megatron actor/reference began loading the Qwen3.5 HF checkpoint. - - New rollout blocker: SGLang TP=2 tried to use device ordinal 1 while the Ray actor saw only `CUDA_VISIBLE_DEVICES='4'`. - - Changed the 9B smoke runner to `--rollout-num-gpus-per-engine 1` so the two rollout GPUs run as two 1-GPU SGLang engines for the next smoke attempt. -- Reran with two 1-GPU rollout engines: - - The invalid device ordinal issue did not recur. - - All 5 services registered successfully. - - Actor, rollout, reference, actor_fwd, and advantages entered step 0. - - TE selected `FusedAttention backend (sub-backend 1)` for Megatron logprob/training. - - New blocker: one SGLang engine hit HIP OOM in logits allocation while full token usage reached ~1.0. - - This was with the original full-style rollout settings (`num_rollout=1000`, `rollout_batch_size=32`, `response_len=8192`). - - Changed the default 9B runner to a smaller smoke profile: - - `NUM_ROLLOUT=4` - - `--num-iters-per-train-update 2` - - `--rollout-batch-size 2` - - `--rollout-max-response-len 2048` - - `--global-batch-size 16` -- Reran the smaller smoke profile: - - All 5 services registered successfully. - - Actor, rollout, reference, actor_fwd, and advantages entered step 0. - - Rollout data reached reference/actor_fwd/actor training paths. - - TE selected `FusedAttention backend (sub-backend 1)`. - - SGLang still hit HIP OOM in logits allocation during step-0 generation. -- Reduced the default smoke profile further for follow-up validation: - - `NUM_ROLLOUT=2` - - `--rollout-batch-size 1` - - `--n-samples-per-prompt 2` - - `--rollout-max-response-len 1024` - - `--global-batch-size 2` -- Reran the reduced smoke profile: - - Rollout generation succeeded and prepared a 2-sample rollout batch. - - New non-backend blocker: TransferQueue GRPO sampler rejected `batch_size=1` because it must be a multiple of `n_samples_per_prompt=2`. - - Changed `--global-batch-size` to `4` while keeping `n_samples_per_prompt=2`. -- Changed `--num-iters-per-train-update` to `1` so the actor training loop consumes a single rollout iteration in this tiny smoke profile. diff --git a/amd/run-qwen3-4b-dapo-math-direct.sh b/amd/run-qwen3-4b-dapo-math-direct.sh deleted file mode 100755 index 9388f872..00000000 --- a/amd/run-qwen3-4b-dapo-math-direct.sh +++ /dev/null @@ -1,141 +0,0 @@ -#!/usr/bin/env bash - -# Copyright (c) 2026 Relax Authors. All Rights Reserved. - -set -ex -set -o pipefail - -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." &>/dev/null && pwd)" -RUN_ID="${RUN_ID:-qwen3-4b-dapo-math-direct-$(date +%Y%m%d-%H%M%S)}" -RUN_DIR="${RUN_DIR:-${SCRIPT_DIR}/runs/${RUN_ID}}" - -mkdir -p "${RUN_DIR}" -cd "${RUN_DIR}" - -DEFAULT_MASTER_ADDR="$(hostname -I | awk '{print $1}')" - -export MODEL_DIR="${MODEL_DIR:-${SCRIPT_DIR}/assets/exps}" -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}" -export NUM_GPUS="${NUM_GPUS:-4}" -export MASTER_ADDR="${MASTER_ADDR:-${DEFAULT_MASTER_ADDR:-127.0.0.1}}" -export RAY_ADDRESS="${RAY_ADDRESS:-${MASTER_ADDR}:6380}" -export RELAX_SERVE_PORT="${RELAX_SERVE_PORT:-18081}" -export CUDA_DEVICE_MAX_CONNECTIONS="${CUDA_DEVICE_MAX_CONNECTIONS:-1}" -export NVTE_DEBUG="${NVTE_DEBUG:-1}" -export NVTE_DEBUG_LEVEL="${NVTE_DEBUG_LEVEL:-2}" -export RAY_DEDUP_LOGS="${RAY_DEDUP_LOGS:-0}" -export TORCHDYNAMO_DISABLE="${TORCHDYNAMO_DISABLE:-1}" -export MEGATRON="${MEGATRON:-/root/Megatron-LM/}" -export RELAX="${RELAX:-${REPO_ROOT}}" -export PYTHONPATH="${RELAX}:${MEGATRON}:${PYTHONPATH:-}" -export MODEL_CONFIG_DIR="${MODEL_CONFIG_DIR:-${REPO_ROOT}/scripts/models}" - -source "${MODEL_CONFIG_DIR}/qwen3-4B.sh" - -now=$(date "+%Y-%m-%d-%H:%M:%S") -PROJECT_NAME="${PROJECT_NAME:=Relax/dev/dapo-math}" -EXP_DIR="${MODEL_DIR}" -NUM_ROLLOUT="${NUM_ROLLOUT:=4}" - -CKPT_ARGS=( - --hf-checkpoint ${EXP_DIR}/Qwen3-4B/ - --ref-load ${EXP_DIR}/Qwen3-4B/ - --megatron-to-hf-mode bridge - --save ${EXP_DIR}/Qwen3-4B_mcore_4xgpu/ - --save-interval 100 -) - -PROMPT_SET=${EXP_DIR}/dapo-math-17k/dapo-math-17k.jsonl - -ROLLOUT_ARGS=( - --use-streaming-dataset - --streaming-buffer-size 10000 - --prompt-data ${PROMPT_SET} - --input-key prompt - --label-key label - --apply-chat-template - --rollout-shuffle - --rm-type dapo - --reward-key score - --num-rollout ${NUM_ROLLOUT} - --rollout-batch-size 2 - --n-samples-per-prompt 8 - --rollout-max-response-len 2048 - --rollout-temperature 0.8 - --global-batch-size 16 - --use-fault-tolerance -) - -PERF_ARGS=( - --tensor-model-parallel-size 1 - --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 1 - --expert-model-parallel-size 1 - --expert-tensor-parallel-size 1 - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - --micro-batch-size 1 - --max-tokens-per-gpu 9216 -) - -GRPO_ARGS=( - --advantage-estimator grpo - --use-kl-loss - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --entropy-coef 0.00 - --eps-clip 0.2 - --eps-clip-high 0.28 - --use-tis -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 1 - --sglang-mem-fraction-static 0.8 -) - -TRACKING_ARGS=( - --tb-project-name ${PROJECT_NAME} - --tb-experiment-name qwen3-4b-4x-direct-${now} -) - -MISC_ARGS=( - --attention-dropout 0.0 - --hidden-dropout 0.0 - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - --qkv-format bshd - --attention-backend auto - --disable-jit-fuser - --train-env-vars '{"TORCHDYNAMO_DISABLE": "1"}' -) - -python3 -m relax.entrypoints.train \ - --resource '{"actor": [1, 1], "rollout": [1, 1], "reference": [1, 1], "actor_fwd": [1, 1], "advantages": [1, 0]}' \ - --max-staleness 1 \ - --num-data-storage-units 1 \ - --num-iters-per-train-update 2 \ - --ref-actor-config '{"tensor_model_parallel_size": 1, "max_tokens_per_gpu": 9216, "sequence_parallel": false, "only_load_weight": true}' \ - --fully-async \ - --use-health-check \ - "${MODEL_ARGS[@]}" \ - "${CKPT_ARGS[@]}" \ - "${ROLLOUT_ARGS[@]}" \ - "${OPTIMIZER_ARGS[@]}" \ - "${GRPO_ARGS[@]}" \ - "${TRACKING_ARGS[@]}" \ - "${PERF_ARGS[@]}" \ - "${SGLANG_ARGS[@]}" \ - "${MISC_ARGS[@]}" 2>&1 | tee "direct-train-${now}.log" diff --git a/amd/run-qwen35-9b-dapo-math-direct.sh b/amd/run-qwen35-9b-dapo-math-direct.sh index bff89dd2..f8782be0 100755 --- a/amd/run-qwen35-9b-dapo-math-direct.sh +++ b/amd/run-qwen35-9b-dapo-math-direct.sh @@ -16,12 +16,28 @@ cd "${RUN_DIR}" DEFAULT_MASTER_ADDR="$(hostname -I | awk '{print $1}')" export MODEL_DIR="${MODEL_DIR:-${SCRIPT_DIR}/assets/exps}" -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}" +export HF_MODEL_PATH="${HF_MODEL_PATH:-Qwen/Qwen3.5-9B}" +export HF_MODEL_DIR="${HF_MODEL_DIR:-${SCRIPT_DIR}/assets/hf-models}" +export HF_TRAIN_DATASET_PATH="${HF_TRAIN_DATASET_PATH:-zhuzilin/dapo-math-17k/dapo-math-17k.jsonl}" +export HF_EVAL_DATASET_PATH="${HF_EVAL_DATASET_PATH:-zhuzilin/aime-2024/aime-2024.jsonl}" +export HF_DATASET_DIR="${HF_DATASET_DIR:-${SCRIPT_DIR}/assets/hf-datasets}" +if command -v rocm-smi >/dev/null 2>&1; then + export HIP_VISIBLE_DEVICES="${HIP_VISIBLE_DEVICES:-${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}}" + unset CUDA_VISIBLE_DEVICES + unset ROCR_VISIBLE_DEVICES +else + export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}" + unset HIP_VISIBLE_DEVICES + unset ROCR_VISIBLE_DEVICES +fi export NUM_GPUS="${NUM_GPUS:-8}" -export MASTER_ADDR="${MASTER_ADDR:-${DEFAULT_MASTER_ADDR:-127.0.0.1}}" -export RAY_ADDRESS="${RAY_ADDRESS:-${MASTER_ADDR}:6380}" -export RELAX_SERVE_PORT="${RELAX_SERVE_PORT:-18080}" +export MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}" +export QWEN35_SOCKET_IFNAME="${QWEN35_SOCKET_IFNAME:-eth0}" +export GLOO_SOCKET_IFNAME="${GLOO_SOCKET_IFNAME:-${QWEN35_SOCKET_IFNAME}}" +export RAY_ADDRESS="${RAY_ADDRESS:-${MASTER_ADDR}:6379}" +export RELAX_SERVE_PORT="${RELAX_SERVE_PORT:-8000}" export CUDA_DEVICE_MAX_CONNECTIONS="${CUDA_DEVICE_MAX_CONNECTIONS:-1}" +export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES="${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-1}" export NVTE_DEBUG="${NVTE_DEBUG:-1}" export NVTE_DEBUG_LEVEL="${NVTE_DEBUG_LEVEL:-2}" export RAY_DEDUP_LOGS="${RAY_DEDUP_LOGS:-0}" @@ -30,26 +46,87 @@ export MEGATRON="${MEGATRON:-/root/Megatron-LM/}" export RELAX="${RELAX:-${REPO_ROOT}}" export PYTHONPATH="${RELAX}:${MEGATRON}:${PYTHONPATH:-}" export MODEL_CONFIG_DIR="${MODEL_CONFIG_DIR:-${REPO_ROOT}/scripts/models}" +if [ -z "${RELAX_RESOURCE:-}" ]; then + RELAX_RESOURCE="$( + python3 - <<'PY' +import json + +print( + json.dumps( + { + "actor": [1, 4], + "rollout": [1, 2], + "reference": [1, 1], + "actor_fwd": [1, 1], + "advantages": [1, 0], + } + ) +) +PY + )" + export RELAX_RESOURCE +fi source "${MODEL_CONFIG_DIR}/qwen35-9B.sh" +resolve_hf_model_path() { + local repo_id="$1" + local repo_name="${repo_id##*/}" + local local_dir="${HF_MODEL_DIR}/${repo_name}" + + if [ ! -f "${local_dir}/config.json" ]; then + mkdir -p "${local_dir}" + hf download "${repo_id}" --local-dir "${local_dir}" >&2 + fi + + if [ ! -f "${local_dir}/config.json" ]; then + echo "HF model config was not found after download: ${repo_id} -> ${local_dir}" >&2 + exit 1 + fi + + printf "%s" "${local_dir}" +} + +resolve_hf_dataset_file() { + local repo_file="$1" + local repo_id="${repo_file%/*}" + local filename="${repo_file##*/}" + local repo_name="${repo_id##*/}" + local local_dir="${HF_DATASET_DIR}/${repo_name}" + local local_file="${local_dir}/${filename}" + + if [ ! -f "${local_file}" ]; then + mkdir -p "${local_dir}" + hf download --repo-type dataset "${repo_id}" --include "${filename}" --local-dir "${local_dir}" >&2 + fi + + if [ ! -f "${local_file}" ]; then + echo "HF dataset file was not found after download: ${repo_file} -> ${local_file}" >&2 + exit 1 + fi + + printf "%s" "${local_file}" +} + now=$(date "+%Y-%m-%d-%H:%M:%S") PROJECT_NAME="${PROJECT_NAME:=Relax/dev/dapo-math}" EXP_DIR="${MODEL_DIR}" -NUM_ROLLOUT="${NUM_ROLLOUT:=2}" +NUM_ROLLOUT="${NUM_ROLLOUT:=1000}" +MODEL_PATH="$(resolve_hf_model_path "${HF_MODEL_PATH}")" +PROMPT_SET="$(resolve_hf_dataset_file "${HF_TRAIN_DATASET_PATH}")" +EVAL_PROMPT_SET="$(resolve_hf_dataset_file "${HF_EVAL_DATASET_PATH}")" +LOCAL_CKPT_DIR="${LOCAL_CKPT_DIR:-${EXP_DIR}/Qwen3-9B_mcore_8xgpu}" CKPT_ARGS=( - --hf-checkpoint ${EXP_DIR}/Qwen3.5-9B - --ref-load ${EXP_DIR}/Qwen3.5-9B + --hf-checkpoint ${MODEL_PATH} + --ref-load ${MODEL_PATH} --megatron-to-hf-mode bridge - --load ${EXP_DIR}/Qwen3-9B_mcore_8xgpu/ - --save ${EXP_DIR}/Qwen3-9B_mcore_8xgpu/ + --load ${LOCAL_CKPT_DIR}/ + --save ${LOCAL_CKPT_DIR}/ --save-interval 50 --max-actor-ckpt-to-keep 1 ) -PROMPT_SET=${EXP_DIR}/dapo-math-17k/dapo-math-17k.jsonl - ROLLOUT_ARGS=( --prompt-data ${PROMPT_SET} --input-key prompt @@ -65,20 +142,22 @@ ROLLOUT_ARGS=( --rollout-temperature 1 --global-batch-size 4 --use-fault-tolerance + --partial-rollout + --partial-rollout-max-aborted-count 3 ) EVAL_ARGS=( --log-passrate --skip-eval-before-train --eval-interval 20 - --eval-prompt-data aime ${EXP_DIR}/aime-2024/aime-2024.jsonl + --eval-prompt-data aime ${EVAL_PROMPT_SET} --n-samples-per-eval-prompt 8 --eval-max-response-len 8192 --eval-top-p 0.7 ) PERF_ARGS=( - --tensor-model-parallel-size 4 + --tensor-model-parallel-size ${ACTOR_TP:-4} --sequence-parallel --pipeline-model-parallel-size 1 --context-parallel-size 1 @@ -116,7 +195,8 @@ OPTIMIZER_ARGS=( ) SGLANG_ARGS=( - --rollout-num-gpus-per-engine 1 + --rollout-num-gpus-per-engine ${ROLLOUT_NUM_GPUS_PER_ENGINE:-${ROLLOUT_GPUS:-2}} + --sglang-router-policy round_robin --sglang-mem-fraction-static 0.8 --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256) ) @@ -124,6 +204,7 @@ SGLANG_ARGS=( WANDB_ARGS=( --use-clearml --use-metrics-service + --timeline-dump-dir ${RUN_DIR}/timeline --tb-project-name ${PROJECT_NAME} --tb-experiment-name qwen35-9B-8x-direct-${now} ) @@ -134,12 +215,21 @@ MISC_ARGS=( --accumulate-allreduce-grads-in-fp32 --attention-softmax-in-fp32 --attention-backend auto - --disable-jit-fuser - --train-env-vars '{"TORCHDYNAMO_DISABLE": "1"}' ) -python3 -m relax.entrypoints.train \ - --resource '{"actor": [1, 4], "rollout": [1, 2], "reference": [1, 1], "actor_fwd": [1, 1], "advantages": [1, 0]}' \ +mkdir -p log +RAY_JOB_ARGS=() +if [ -n "${WORKING_DIR:-}" ]; then + RAY_JOB_ARGS+=(--working-dir "${WORKING_DIR}") +fi +if [ -n "${RUNTIME_ENV_JSON:-}" ]; then + RAY_JOB_ARGS+=(--runtime-env-json="${RUNTIME_ENV_JSON}") +fi + +ray job submit ${RAY_NO_WAIT:+--no-wait} --address="http://${HOST_IP:-${MASTER_ADDR}}:${RAY_DASHBOARD_PORT:-8265}" \ + "${RAY_JOB_ARGS[@]}" \ + -- python3 -m relax.entrypoints.train \ + --resource "${RELAX_RESOURCE}" \ --max-staleness 2 \ --num-data-storage-units 1 \ --num-iters-per-train-update 1 \ @@ -155,4 +245,4 @@ python3 -m relax.entrypoints.train \ "${PERF_ARGS[@]}" \ "${EVAL_ARGS[@]}" \ "${SGLANG_ARGS[@]}" \ - "${MISC_ARGS[@]}" 2>&1 | tee "direct-train-${now}.log" + "${MISC_ARGS[@]}" 2>&1 | tee "log/qwen35-9B-GRPO-gpu16-async-${now}.log" diff --git a/amd/run_qwen3-4b.sh b/amd/run_qwen3-4b.sh deleted file mode 100755 index 0883d0b2..00000000 --- a/amd/run_qwen3-4b.sh +++ /dev/null @@ -1,104 +0,0 @@ -#!/usr/bin/env bash - -# Copyright (c) 2026 Relax Authors. All Rights Reserved. - -set -Eeuo pipefail - -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." &>/dev/null && pwd)" -DEFAULT_MASTER_ADDR="$(hostname -I | awk '{print $1}')" - -# Edit this block directly when changing the smoke configuration. -export MODEL_DIR="${SCRIPT_DIR}/assets/exps" -export CUDA_VISIBLE_DEVICES="0,1,2,3" -export NUM_GPUS="4" -export MASTER_ADDR="${MASTER_ADDR:-${DEFAULT_MASTER_ADDR:-127.0.0.1}}" -export RAY_PORT="6380" -export RAY_DASHBOARD_PORT="8266" -export RAY_MIN_WORKER_PORT="30000" -export RAY_MAX_WORKER_PORT="65000" -export RAY_NODE_MANAGER_PORT="6381" -export RAY_OBJECT_MANAGER_PORT="6382" -export RAY_RUNTIME_ENV_AGENT_PORT="6383" -export RAY_DASHBOARD_AGENT_LISTEN_PORT="6384" -export RAY_DASHBOARD_AGENT_GRPC_PORT="6385" -export RAY_TMPDIR="/tmp/ray-qwen3-4b" -export RELAX_SERVE_PORT="18081" -export MEGATRON="/root/Megatron-LM/" -export RELAX="${REPO_ROOT}" -export RUN_ID="qwen3-4b-dapo-math-te-debug-$(date +%Y%m%d-%H%M%S)" - -# Runtime diagnostics for the current ROCm TransformerEngine backend issue. -export CUDA_DEVICE_MAX_CONNECTIONS="1" -export NVTE_DEBUG="1" -export NVTE_DEBUG_LEVEL="2" -export RAY_DEDUP_LOGS="0" -export TORCHDYNAMO_DISABLE="1" -export PYTHONUNBUFFERED="1" - -export RAY_ADDRESS="${MASTER_ADDR}:${RAY_PORT}" -export PYTHONPATH="${RELAX}:${MEGATRON}:${PYTHONPATH:-}" - -cleanup_stale_processes() { - echo "=== Cleaning stale Relax/Ray/SGLang processes ===" - timeout 30 ray stop --force 2>/dev/null || true - - pkill -9 -f "relax.entrypoints.train" 2>/dev/null || true - pkill -9 -f "sglang" 2>/dev/null || true - pkill -9 -f "MegatronTrainRayActor" 2>/dev/null || true - pkill -9 -f "SGLang" 2>/dev/null || true - pkill -9 -f "ray::" 2>/dev/null || true - pkill -9 -f "raylet" 2>/dev/null || true - pkill -9 -f "gcs_server" 2>/dev/null || true - pkill -9 -f "default_worker.py" 2>/dev/null || true - pkill -9 -f "runtime_env_agent" 2>/dev/null || true - pkill -9 -f "dashboard_agent" 2>/dev/null || true - pkill -9 -f "dashboard.py" 2>/dev/null || true - pkill -9 -f "log_monitor.py" 2>/dev/null || true - pkill -9 -f "monitor.py" 2>/dev/null || true - - sleep 3 - - if [[ "${RAY_TMPDIR}" == /tmp/ray-qwen3-4b* ]]; then - rm -rf "${RAY_TMPDIR}" - fi - mkdir -p "${RAY_TMPDIR}" -} - -start_ray_head() { - echo "=== Starting Ray head ${RAY_ADDRESS} with ${NUM_GPUS} GPUs ===" - ray start --head \ - --node-ip-address="${MASTER_ADDR}" \ - --port="${RAY_PORT}" \ - --num-gpus="${NUM_GPUS}" \ - --temp-dir="${RAY_TMPDIR}" \ - --disable-usage-stats \ - --include-dashboard=true \ - --dashboard-host=0.0.0.0 \ - --dashboard-port="${RAY_DASHBOARD_PORT}" \ - --node-manager-port="${RAY_NODE_MANAGER_PORT}" \ - --object-manager-port="${RAY_OBJECT_MANAGER_PORT}" \ - --runtime-env-agent-port="${RAY_RUNTIME_ENV_AGENT_PORT}" \ - --dashboard-agent-listen-port="${RAY_DASHBOARD_AGENT_LISTEN_PORT}" \ - --dashboard-agent-grpc-port="${RAY_DASHBOARD_AGENT_GRPC_PORT}" \ - --min-worker-port="${RAY_MIN_WORKER_PORT}" \ - --max-worker-port="${RAY_MAX_WORKER_PORT}" - - for _ in $(seq 1 30); do - if ray status --address="${RAY_ADDRESS}" >/dev/null 2>&1; then - ray status --address="${RAY_ADDRESS}" - return 0 - fi - sleep 1 - done - - echo "Ray did not become ready at ${RAY_ADDRESS}" >&2 - return 1 -} - -cleanup_stale_processes -start_ray_head - -echo "=== Launching Qwen3-4B DAPO-Math direct runner ===" -cd "${REPO_ROOT}" -exec bash "${SCRIPT_DIR}/run-qwen3-4b-dapo-math-direct.sh" diff --git a/amd/run_qwen35-9b.sh b/amd/run_qwen35-9b.sh index 4204a630..ee07abb6 100755 --- a/amd/run_qwen35-9b.sh +++ b/amd/run_qwen35-9b.sh @@ -1,20 +1,123 @@ #!/usr/bin/env bash - +set -x +pkill -9 gcs_server +pkill -9 python +pkill -9 python3 +pkill -9 ray # Copyright (c) 2026 Relax Authors. All Rights Reserved. +sleep 5 set -Eeuo pipefail +ulimit -n 1048576 +pkill -9 python 2>/dev/null || true +pkill -9 python3 2>/dev/null || true + SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." &>/dev/null && pwd)" -DEFAULT_MASTER_ADDR="$(hostname -I | awk '{print $1}')" +source "${REPO_ROOT}/scripts/entrypoint/device_env.sh" + +select_available_gpus() { + local min_free_vram_gb="$1" + local requested_gpus="${2:-}" + + python3 - "${min_free_vram_gb}" "${requested_gpus}" <<'PY' +import os +import sys + +threshold_gb = float(sys.argv[1]) +threshold = threshold_gb * 1024**3 +requested = int(sys.argv[2]) if sys.argv[2] else None +supported_counts = sorted( + int(item) for item in os.environ.get("QWEN35_SUPPORTED_GPU_COUNTS", "4,6,8").split(",") if item +) +if requested is not None and requested not in supported_counts: + print( + f"[qwen35-gpu-select] Requested {requested} GPUs, but supported profiles are {supported_counts}", + file=sys.stderr, + ) + sys.exit(2) + +for key in ("CUDA_VISIBLE_DEVICES", "HIP_VISIBLE_DEVICES", "ROCR_VISIBLE_DEVICES"): + os.environ.pop(key, None) + +import torch + +if not torch.cuda.is_available(): + print("[qwen35-gpu-select] torch.cuda is not available", file=sys.stderr) + sys.exit(2) + +selected = [] +for index in range(torch.cuda.device_count()): + free_bytes, total_bytes = torch.cuda.mem_get_info(index) + free_gb = free_bytes / 1024**3 + total_gb = total_bytes / 1024**3 + print( + f"[qwen35-gpu-select] GPU {index}: free={free_gb:.1f}GB total={total_gb:.1f}GB", + file=sys.stderr, + ) + if free_bytes >= threshold: + selected.append(str(index)) + +if requested is None: + target = max((count for count in supported_counts if count <= len(selected)), default=0) +else: + target = requested + +if target <= 0: + print( + f"[qwen35-gpu-select] Need at least {supported_counts[0]} GPUs with >= {threshold_gb:.1f}GB free, " + f"but only found {len(selected)}: {','.join(selected) or ''}", + file=sys.stderr, + ) + sys.exit(2) + +if len(selected) < target: + print( + f"[qwen35-gpu-select] Need {target} GPUs with >= {threshold_gb:.1f}GB free, " + f"but only found {len(selected)}: {','.join(selected) or ''}", + file=sys.stderr, + ) + sys.exit(2) + +print( + f"[qwen35-gpu-select] Selected {target} GPUs from {len(selected)} eligible GPUs " + f"(supported={supported_counts})", + file=sys.stderr, +) +print(",".join(selected[:target])) +PY +} # Edit this block directly when changing the smoke configuration. export MODEL_DIR="${SCRIPT_DIR}/assets/exps" -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" -export NUM_GPUS="8" -export MASTER_ADDR="${MASTER_ADDR:-${DEFAULT_MASTER_ADDR:-127.0.0.1}}" -export RAY_PORT="6380" -export RAY_DASHBOARD_PORT="8266" +export HF_MODEL_PATH="${HF_MODEL_PATH:-Qwen/Qwen3.5-9B}" +export HF_MODEL_DIR="${HF_MODEL_DIR:-${SCRIPT_DIR}/assets/hf-models}" +export HF_TRAIN_DATASET_PATH="${HF_TRAIN_DATASET_PATH:-zhuzilin/dapo-math-17k/dapo-math-17k.jsonl}" +export HF_EVAL_DATASET_PATH="${HF_EVAL_DATASET_PATH:-zhuzilin/aime-2024/aime-2024.jsonl}" +export QWEN35_SUPPORTED_GPU_COUNTS="${QWEN35_SUPPORTED_GPU_COUNTS:-4,8}" +export QWEN35_MIN_FREE_VRAM_GB="${QWEN35_MIN_FREE_VRAM_GB:-150}" +if [ -n "${QWEN35_VISIBLE_DEVICES:-}" ]; then + SELECTED_GPUS="${QWEN35_VISIBLE_DEVICES}" +else + if ! SELECTED_GPUS="$(select_available_gpus "${QWEN35_MIN_FREE_VRAM_GB}" "${QWEN35_NUM_GPUS:-}")"; then + echo "=== Not enough free GPUs for Qwen3.5 smoke; exiting without starting Ray ===" + exit 0 + fi +fi +export HIP_VISIBLE_DEVICES="${SELECTED_GPUS}" +unset CUDA_VISIBLE_DEVICES +unset ROCR_VISIBLE_DEVICES +export NUM_GPUS="$(python3 - <<'PY' +import os +print(len([x for x in os.environ["HIP_VISIBLE_DEVICES"].split(",") if x])) +PY +)" +export MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}" +export QWEN35_SOCKET_IFNAME="${QWEN35_SOCKET_IFNAME:-eth0}" +export GLOO_SOCKET_IFNAME="${GLOO_SOCKET_IFNAME:-${QWEN35_SOCKET_IFNAME}}" +export RAY_PORT="6379" +export RAY_DASHBOARD_PORT="8265" export RAY_MIN_WORKER_PORT="30000" export RAY_MAX_WORKER_PORT="65000" export RAY_NODE_MANAGER_PORT="6381" @@ -23,13 +126,52 @@ export RAY_RUNTIME_ENV_AGENT_PORT="6383" export RAY_DASHBOARD_AGENT_LISTEN_PORT="6384" export RAY_DASHBOARD_AGENT_GRPC_PORT="6385" export RAY_TMPDIR="/tmp/ray-qwen35-9b" -export RELAX_SERVE_PORT="18080" +export RELAX_SERVE_PORT="8000" +export TENSORBOARD_DIR="${SCRIPT_DIR}/tensorboard/qwen35-9b" export MEGATRON="/root/Megatron-LM/" export RELAX="${REPO_ROOT}" export RUN_ID="qwen35-9b-dapo-math-te-debug-$(date +%Y%m%d-%H%M%S)" +if [ "${NUM_GPUS}" -ge 8 ]; then + export ACTOR_GPUS="${ACTOR_GPUS:-4}" + export ACTOR_TP="${ACTOR_TP:-4}" + export ROLLOUT_GPUS="${ROLLOUT_GPUS:-2}" + export ROLLOUT_NUM_GPUS_PER_ENGINE="${ROLLOUT_NUM_GPUS_PER_ENGINE:-1}" +else + export ACTOR_GPUS="${ACTOR_GPUS:-1}" + export ACTOR_TP="${ACTOR_TP:-1}" + export ROLLOUT_GPUS="${ROLLOUT_GPUS:-1}" + export ROLLOUT_NUM_GPUS_PER_ENGINE="${ROLLOUT_NUM_GPUS_PER_ENGINE:-1}" +fi +export REFERENCE_GPUS="${REFERENCE_GPUS:-1}" +export ACTOR_FWD_GPUS="${ACTOR_FWD_GPUS:-1}" +if [ -n "${QWEN35_RESOURCE:-}" ]; then + export RELAX_RESOURCE="${QWEN35_RESOURCE}" +else + RELAX_RESOURCE="$( + python3 - <<'PY' +import json +import os + +resource = { + "actor": [1, int(os.environ["ACTOR_GPUS"])], + "rollout": [1, int(os.environ["ROLLOUT_GPUS"])], + "reference": [1, int(os.environ["REFERENCE_GPUS"])], + "actor_fwd": [1, int(os.environ["ACTOR_FWD_GPUS"])], + "advantages": [1, 0], +} +print(json.dumps(resource)) +PY + )" + export RELAX_RESOURCE +fi + +echo "=== Selected GPUs: HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES} ===" +echo "=== Resource plan: NUM_GPUS=${NUM_GPUS}, ACTOR_GPUS=${ACTOR_GPUS}, ACTOR_TP=${ACTOR_TP}, ROLLOUT_GPUS=${ROLLOUT_GPUS}, ROLLOUT_NUM_GPUS_PER_ENGINE=${ROLLOUT_NUM_GPUS_PER_ENGINE}, REFERENCE_GPUS=${REFERENCE_GPUS}, ACTOR_FWD_GPUS=${ACTOR_FWD_GPUS} ===" + # Keep the 9B smoke aligned with the verified 4B ROCm settings. export CUDA_DEVICE_MAX_CONNECTIONS="1" +export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES="1" export NVTE_DEBUG="1" export NVTE_DEBUG_LEVEL="2" export RAY_DEDUP_LOGS="0" @@ -38,6 +180,37 @@ export PYTHONUNBUFFERED="1" export RAY_ADDRESS="${MASTER_ADDR}:${RAY_PORT}" export PYTHONPATH="${RELAX}:${MEGATRON}:${PYTHONPATH:-}" +export HOST_IP="${MASTER_ADDR}" +export HAS_NVLINK="$(relax_detect_fast_interconnect)" + +build_bindlog_preload() { + if [ "${BINDLOG_ENABLE:-1}" != "1" ]; then + return 0 + fi + + echo "=== Building bindlog LD_PRELOAD hook inside zty_relax ===" + bash -lc \ + "cd /B/Relax && gcc -shared -fPIC -O2 -Wall -Wextra -o tools/libbindlog.so tools/bindlog.c -ldl -pthread" +} + +build_bindlog_preload +export BINDLOG_PRELOAD_PATH="${BINDLOG_PRELOAD_PATH:-${REPO_ROOT}/tools/libbindlog.so}" +export BINDLOG_LD_PRELOAD="${BINDLOG_PRELOAD_PATH}${LD_PRELOAD:+:${LD_PRELOAD}}" +echo "=== bindlog enabled: LD_PRELOAD=${BINDLOG_LD_PRELOAD}===" + +export RUNTIME_ENV_JSON="{ +\"env_vars\": { + \"PYTHONUNBUFFERED\": \"1\", + \"PYTHONPATH\": \"${PYTHONPATH}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"RELAX_SERVE_PORT\": \"${RELAX_SERVE_PORT}\", + \"HIP_VISIBLE_DEVICES\": \"${HIP_VISIBLE_DEVICES}\", + \"RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES\": \"${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES}\", + \"RAY_OVERRIDE_JOB_RUNTIME_ENV\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", + \"LD_PRELOAD\": \"${BINDLOG_LD_PRELOAD}\" +} +}" cleanup_stale_processes() { echo "=== Cleaning stale Relax/Ray/SGLang processes ===" diff --git a/docker/Dockerfile b/docker/Dockerfile index c5d4d9f0..247b436a 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -57,18 +57,11 @@ RUN MAX_JOBS=64 \ cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py && \ rm -rf /opt/flash-attention/ -ARG MEGATRON_COMMIT=3714d81d418c9f1bca4594fc35f9e8289f652862 - -RUN pip -v install --no-cache-dir --no-build-isolation "transformer_engine[pytorch]==2.10.0" && \ - cd /root && git clone https://github.com/NVIDIA/Megatron-LM.git --recursive && \ - cd Megatron-LM && \ - git checkout ${MEGATRON_COMMIT} && \ - pip install -e . --no-deps && \ +RUN pip -v install --no-cache-dir --no-build-isolation "transformer_engine[pytorch]==2.14.1" && \ pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@dc6876905830430b5054325fa4211ff302169c6b --no-cache-dir --force-reinstall && \ pip install nvidia-modelopt[torch]>=0.37.0 --no-build-isolation --no-cache-dir && \ - pip install "numpy<2" nvidia-cudnn-cu12==9.16.0.29 --no-cache-dir - -RUN NVCC_APPEND_FLAGS="--threads 4" \ + pip install "numpy<2" nvidia-cudnn-cu12==9.16.0.29 --no-cache-dir && \ + NVCC_APPEND_FLAGS="--threads 32" \ pip -v install --disable-pip-version-check --no-cache-dir \ --no-build-isolation \ --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 8" \ @@ -82,19 +75,27 @@ ARG ENABLE_SGLANG_PATCH=1 WORKDIR /root +ARG MEGATRON_BRIDGE_COMMIT=2faedbf6fe3c422835a44b2b360cadcb2a116a54 +ENV MEGATRON_BRIDGE_COMMIT=${MEGATRON_BRIDGE_COMMIT} \ + PYTHONPATH=/root/Megatron-LM/ + +RUN rm -rf /root/Megatron-LM && git clone https://github.com/NVIDIA-NeMo/Megatron-Bridge.git && \ + cd /root/Megatron-Bridge/ && git checkout ${MEGATRON_BRIDGE_COMMIT} && \ + git submodule update --init --recursive && ./scripts/switch_mcore.sh dev && \ + mkdir /root/Megatron-LM &&\ + cp -r /root/Megatron-Bridge/src/megatron /root/Megatron-LM/ && \ + rsync -avP /root/Megatron-Bridge/3rdparty/Megatron-LM/megatron/ /root/Megatron-LM/megatron/ && \ + rm -rf /root/Megatron-Bridge + COPY requirements.txt /tmp/requirements.txt RUN pip install -r /tmp/requirements.txt --no-cache-dir && \ pip install --no-cache-dir tensordict==0.10.0 pyvers==0.1.0 --no-deps && \ - apt-get install -y jq - -RUN pip install git+https://github.com/redai-infra/megatron-bridge.git@f13bec09 --no-build-isolation --no-deps --force-reinstall --no-cache-dir && \ pip install "transferqueue @ git+https://github.com/redai-infra/TransferQueue.git" --no-deps COPY docker/patch/${PATCH_VERSION}/megatron.patch /root/Megatron-LM/ RUN cd Megatron-LM && \ - git update-index --refresh && \ - git apply megatron.patch --3way && \ + patch -p1 < /root/Megatron-LM/megatron.patch && \ if grep -R -n '^<<<<<<< ' .; then \ echo "Patch failed to apply cleanly. Please resolve conflicts." && \ exit 1; \ diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index b76b9611..2edde3e8 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -2,12 +2,17 @@ ARG HTTP_PROXY ARG HTTPS_PROXY ARG NO_PROXY ARG BASE_IMAGE=rocm/pytorch:rocm7.2_ubuntu22.04_py3.10_pytorch_release_2.9.1 -ARG GPU_ARCH=gfx950 +ARG GPU_ARCH=gfx942 +ARG PYTORCH_ROCM_ARCH="gfx942;gfx950" +ARG AITER_REPO=https://github.com/ROCm/aiter.git +ARG AITER_COMMIT=v0.1.11.post1 +ARG AITER_GPU_ARCHS="gfx942;gfx950" ARG TRAIN_IMAGE=train FROM ${BASE_IMAGE} AS base ARG GPU_ARCH +ARG PYTORCH_ROCM_ARCH ENV PYVER=3.12 \ PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python \ @@ -17,7 +22,7 @@ ENV PYVER=3.12 \ DEBIAN_FRONTEND=noninteractive \ PIP_NO_CACHE_DIR=1 \ MAX_JOBS=64 \ - PYTORCH_ROCM_ARCH=${GPU_ARCH} \ + PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} \ SGLANG_DISABLE_CUDNN_CHECK=1 \ HIP_FORCE_DEV_KERNARG=1 \ HSA_NO_SCRATCH_RECLAIM=1 \ @@ -45,6 +50,9 @@ RUN mkdir -p /sgl-workspace && \ FROM base AS sglang ARG GPU_ARCH +ARG AITER_REPO +ARG AITER_COMMIT +ARG AITER_GPU_ARCHS WORKDIR /sgl-workspace/sglang @@ -57,6 +65,19 @@ RUN python -m pip install --upgrade pip setuptools wheel setuptools_scm scikit-b SGL_KERNEL_ARCH="${GPU_ARCH%%-*}" && \ AMDGPU_TARGET="${SGL_KERNEL_ARCH}" python setup_rocm.py install +# AITER kernels are needed by SGLang's ROCm runtime. Build both MI300 and +# MI350 targets by default so one image can run on gfx942 and gfx950 hosts. +RUN python -m pip uninstall -y aiter amd-aiter && \ + python -m pip install flydsl==0.0.1.dev95158637 psutil pybind11 && \ + cd /sgl-workspace && \ + git clone "${AITER_REPO}" aiter && \ + cd aiter && \ + git checkout "${AITER_COMMIT}" && \ + git submodule update --init --recursive && \ + sed -i '459 s/if.*:/if False:/' aiter/ops/triton/attention/pa_mqa_logits.py && \ + sed -i '/c1 = torch.empty((M, D, S1 + S3), dtype=dtype, device=x.device)/i\ config = dict(config)' aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_split_cat.py && \ + PREBUILD_KERNELS=1 GPU_ARCHS="${AITER_GPU_ARCHS}" python setup.py develop + # SGLang keeps ROCm-specific dependency metadata in pyproject_other.toml. # Installing the CUDA-default pyproject would pull cu12 wheels into the image. RUN cp python/pyproject_other.toml python/pyproject.toml && \ @@ -130,6 +151,7 @@ RUN if [ "$ENABLE_SGLANG_PATCH" = "1" ]; then \ rm sglang.patch; \ fi -WORKDIR /root/Relax -COPY . . -RUN pip install -e . --no-deps +# TODO(zty): remove this comment after testing +# WORKDIR /root/ +# COPY . . +# RUN pip install -e . --no-deps diff --git a/docker/README.md b/docker/README.md index cb9bc410..729dd4a0 100644 --- a/docker/README.md +++ b/docker/README.md @@ -36,40 +36,28 @@ All models are tested in both modes to ensure stability and compatibility across ## Experimental ROCm Build -For AMD Instinct MI355/MI350 class systems, use `docker/Dockerfile.rocm`. +For AMD Instinct MI300X / MI355 / MI350 systems, use `docker/Dockerfile.rocm` +(base image `rocm/pytorch:rocm7.2_ubuntu22.04_py3.10_pytorch_release_2.9.1`). -This image is intentionally separate from the CUDA path because Relax's default image depends on NVIDIA-only packages such as Apex and `nvidia-modelopt`. The ROCm Dockerfile uses an AMD base image and keeps the same Relax patch flow for Megatron-LM and SGLang. - -Build the ROCm training image with: +Build with the helper script — defaults to `gfx942` (MI300X), set `GPU_ARCH` +for MI355/MI350: ```bash -DOCKER_BUILDKIT=1 docker build \ - -f docker/Dockerfile.rocm \ - --target relax \ - -t relax:rocm-relax-smoke \ - . +# MI300X (default) +docker/build-rocm.sh + +# MI355 / MI350 +GPU_ARCH=gfx950 docker/build-rocm.sh ``` -The current ROCm Dockerfile is validated on MI355/MI350 systems with -`rocm/pytorch:rocm7.2_ubuntu22.04_py3.10_pytorch_release_2.9.1` as the base -image. +The image is tagged `relax:rocm-${GPU_ARCH}`. AITER is prebuilt for both +archs, so a `gfx942` image still runs on `gfx950` hosts (just slower). -For day-to-day development, start a bind-mounted container so `/root/Relax` -inside the container points to the host checkout: +For day-to-day dev, start a bind-mounted container so `/root/Relax` inside +the container points to the host checkout: ```bash -chmod +x docker/run-rocm-bind.sh -CONTAINER_NAME=relax_rocm_bind docker/run-rocm-bind.sh +docker/run-rocm-bind.sh # uses relax:rocm-gfx942 +IMAGE=relax:rocm-gfx950 docker/run-rocm-bind.sh docker exec -it relax_rocm_bind bash ``` - -Inside the bind-mounted container, a real-model 2-GPU smoke can be launched -with: - -```bash -cd /root/Relax -CUDA_VISIBLE_DEVICES=6,7 \ -REAL_HF_MODEL_DIR=/mnt/dcgpuval/models/meta-llama/Meta-Llama-3-8B-Instruct \ -NUM_GPUS=2 \ -bash scripts/training/text/run-llama3-8b-2xgpu-debug.sh -``` diff --git a/docker/build-rocm.sh b/docker/build-rocm.sh new file mode 100755 index 00000000..46249c41 --- /dev/null +++ b/docker/build-rocm.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." &>/dev/null && pwd)" + +# Build target arch for sgl-kernel. Default gfx942 (MI300X); set +# GPU_ARCH=gfx950 for MI355/MI350. AITER is prebuilt for both archs so +# the resulting image can still run on either host. +GPU_ARCH="${GPU_ARCH:-gfx942}" +IMAGE_TAG="${IMAGE_TAG:-relax:rocm-${GPU_ARCH}}" + +if ! command -v docker >/dev/null 2>&1; then + echo "docker is required but not found in PATH" >&2 + exit 1 +fi + +exec env DOCKER_BUILDKIT="${DOCKER_BUILDKIT:-1}" docker build \ + -f "${SCRIPT_DIR}/Dockerfile.rocm" \ + --target relax \ + --build-arg GPU_ARCH="${GPU_ARCH}" \ + -t "${IMAGE_TAG}" \ + "$@" \ + "${REPO_ROOT}" diff --git a/docker/patch/latest/megatron.patch b/docker/patch/latest/megatron.patch deleted file mode 100644 index 5d4428a5..00000000 --- a/docker/patch/latest/megatron.patch +++ /dev/null @@ -1,1578 +0,0 @@ -diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py -index 41c21d93d..ef80f72d6 100644 ---- a/megatron/core/dist_checkpointing/strategies/common.py -+++ b/megatron/core/dist_checkpointing/strategies/common.py -@@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy): - msc = MultiStorageClientFeature.import_package() - return msc.torch.load(load_path, map_location='cpu') - else: -- return torch.load(load_path, map_location='cpu') -+ return torch.load(load_path, map_location='cpu', weights_only=False) - except FileNotFoundError as e: - err_msg = f'Common file {load_path} does not exist' - if MultiStorageClientFeature.is_enabled(): -diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py -index 5a1ea308d..aa701237f 100644 ---- a/megatron/core/dist_checkpointing/strategies/torch.py -+++ b/megatron/core/dist_checkpointing/strategies/torch.py -@@ -597,10 +597,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): - def _validate_global_shapes(self, metadata, sharded_tensors): - for sh_ten in sharded_tensors: - if sh_ten.key not in metadata.state_dict_metadata: -- raise KeyError( -- f"{sh_ten.key} from model not in state dict:" -- f" {sorted(metadata.state_dict_metadata.keys())}" -- ) -+ # raise KeyError( -+ # f"{sh_ten.key} from model not in state dict:" -+ # f" {sorted(metadata.state_dict_metadata.keys())}" -+ # ) -+ print(f"{sh_ten.key} from model not in state dict, will skip") -+ continue - loaded_shape = metadata.state_dict_metadata[sh_ten.key].size - expected_shape = self._expected_shape(sh_ten) - if loaded_shape != expected_shape: -@@ -630,7 +632,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): - tensor_metadata = self.metadata.state_dict_metadata - metadata_with_sizes = [ - (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) -- for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() -+ for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata - ] - try: - # Temporarily set sizes to expected shapes -@@ -959,6 +961,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): - planner=MCoreLoadPlanner( - shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, - allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, -+ allow_partial_load=True, - ), - ) - -diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py -index acb93ef78..d239db4ab 100644 ---- a/megatron/core/extensions/transformer_engine.py -+++ b/megatron/core/extensions/transformer_engine.py -@@ -408,6 +408,7 @@ class TELinear(te.pytorch.Linear): - ) - - for param in self.parameters(): -+ setattr(param, "parallel_mode", parallel_mode) - if is_expert: - # Reduce the gradient on the expert_data_parallel group for expert linear layers - setattr(param, "allreduce", not self.expert_parallel) -@@ -1161,6 +1162,61 @@ class TEDotProductAttention(te.pytorch.DotProductAttention): - - - if HAVE_TE and is_te_min_version("1.9.0.dev0"): -+ def ceil_div(x: int, y: int) -> int: -+ return (x + y - 1) // y -+ -+ class _FakeInt4QuantizationSTE(torch.autograd.Function): -+ @staticmethod -+ def forward(ctx, x, group_size): -+ m, n = x.shape -+ block_size_m, block_size_n = 1, group_size -+ -+ -+ m_padded = ceil_div(m, block_size_m) * block_size_m -+ n_padded = ceil_div(n, block_size_n) * block_size_n -+ -+ x_padded = torch.zeros( -+ (m_padded, n_padded), -+ dtype=x.dtype, device=x.device -+ ) -+ x_padded[:m, :n] = x -+ -+ x_view = x_padded.view( -+ m_padded // block_size_m, -+ block_size_m, -+ n_padded // block_size_n, -+ block_size_n -+ ) -+ -+ x_max = x_view.abs().float().amax(dim=(1, 3), keepdim=True) -+ q_max = 7 -+ x_scale = x_max / q_max -+ -+ x_scale = x_scale.clamp(min=1e-5) -+ -+ x_div = x_view / x_scale -+ x_round = torch.round(x_div) -+ -+ x_q_clamped = x_round.clamp(-q_max, q_max) -+ -+ x_dequant_view = x_q_clamped * x_scale -+ -+ x_dequant_full = x_dequant_view.view_as(x_padded) -+ x_out = x_dequant_full[:m, :n].contiguous().to(x.dtype) -+ -+ return x_out -+ -+ @staticmethod -+ def backward(ctx, grad_output): -+ return grad_output, None -+ -+ def fake_int4_quantization_ste(x, group_size): -+ x_out = _FakeInt4QuantizationSTE.apply(x, group_size) -+ -+ if hasattr(x, 'main_grad'): -+ x_out.main_grad = x.main_grad -+ -+ return x_out - - class TEGroupedLinear(te.pytorch.GroupedLinear): - """ -@@ -1351,6 +1407,7 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): - _is_first_microbatch = ( - None if self.disable_parameter_transpose_cache else self.is_first_microbatch - ) -+ - out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch) - self.is_first_microbatch = False - -@@ -1361,6 +1418,20 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): - return out - return out, None - -+ def _get_weight_tensors(self): -+ """Get the weight tensors of the module.""" -+ weight_tensors = super()._get_weight_tensors() -+ -+ if os.getenv("OPEN_TRAINING_INT4_FAKE_QAT_FLAG", "0") == "1": -+ group_size = int(os.getenv("OPEN_TRAINING_INT4_GROUP_SIZE", "128")) -+ -+ weight_tensors = [ -+ fake_int4_quantization_ste(w, group_size) -+ for w in weight_tensors -+ ] -+ -+ return weight_tensors -+ - def _encode_extra_state(self, state): - # TE 2.0 changed the format of extra_state to be a byte tensor - if is_te_min_version("2.0.0"): -diff --git a/megatron/core/fusions/fused_mla_yarn_rope_apply.py b/megatron/core/fusions/fused_mla_yarn_rope_apply.py -index 1fd5dcfae..c9aeef1f0 100644 ---- a/megatron/core/fusions/fused_mla_yarn_rope_apply.py -+++ b/megatron/core/fusions/fused_mla_yarn_rope_apply.py -@@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel( - SIN, - emb_dim: tl.constexpr, - k_dim: tl.constexpr, -+ k_dim_ceil: tl.constexpr, - v_dim: tl.constexpr, - head_num: tl.constexpr, - batch_size, -@@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel( - cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) - sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) - -- KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads -- kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads -- mask = kv_off < head_num * stride_kv_nheads -- k_in_off = kv_off + tl.arange(0, k_dim)[None, :] -- v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :] -- k = tl.load(KV_ptr + k_in_off, mask=mask) -- v = tl.load(KV_ptr + v_in_off, mask=mask) -+ KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads -+ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H -+ kj_range = tl.arange(0, k_dim_ceil)[None, :] -+ mask_k = (ki_range < head_num) & (kj_range < k_dim) -+ mask_v = ki_range < head_num -+ k_off = ki_range * stride_kv_nheads + kj_range -+ if v_dim > 0: -+ v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :] -+ v = tl.load(KV_ptr + v_off, mask=mask_v) -+ else: -+ v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty) -+ k = tl.load(KV_ptr + k_off, mask=mask_k) - -- K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads -- V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads -+ K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads -+ V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads - -- k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :] -- v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :] -- tl.store(K_ptr + k_out_off, k, mask=mask) -- tl.store(V_ptr + v_out_off, v, mask=mask) -+ k_out_off = ki_range * stride_k_nheads + kj_range -+ tl.store(K_ptr + k_out_off, k, mask=mask_k) -+ if v_dim > 0: -+ v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :] -+ tl.store(V_ptr + v_out_off, v, mask=mask_v) - - EMB = K_POS_EMB + pid_m * stride_emb_seq - # x1 = t[..., 0::2], x2 = t[..., 1::2] -@@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel( - x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) - x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) - -+ x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H -+ mask_x = x_range < head_num - x_left_off = ( -- tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads -+ x_range * stride_k_nheads - + k_dim - + tl.arange(0, emb_dim // 2)[None, :] - ) - x_right_off = x_left_off + emb_dim // 2 -- tl.store(K_ptr + x_left_off, x_left, mask=mask) -- tl.store(K_ptr + x_right_off, x_right, mask=mask) -+ tl.store(K_ptr + x_left_off, x_left, mask=mask_x) -+ tl.store(K_ptr + x_right_off, x_right, mask=mask_x) - - - @triton.autotune( -@@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel( - SIN, - emb_dim: tl.constexpr, - k_dim: tl.constexpr, -+ k_dim_ceil: tl.constexpr, - v_dim: tl.constexpr, - head_num: tl.constexpr, - batch_size, -@@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel( - else: - token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size) - -- dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads -- dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads -- mask = dkv_off < head_num * stride_dkv_nheads -- dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :] -- dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :] -- -- dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads -- dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads -- dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :] -- dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :] -- dk = tl.load(dK_ptr + dk_in_off, mask=mask) -- dv = tl.load(dV_ptr + dv_in_off, mask=mask) -- tl.store(dKV_ptr + dk_out_off, dk, mask=mask) -- tl.store(dKV_ptr + dv_out_off, dv, mask=mask) -+ dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads -+ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H -+ kj_range = tl.arange(0, k_dim_ceil)[None, :] -+ mask_k = (ki_range < head_num) & (kj_range < k_dim) -+ mask_v = ki_range < head_num -+ dk_out_off = ki_range * stride_dkv_nheads + kj_range -+ -+ dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads -+ dV_ptr = dV + pid_m * stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads -+ dk_in_off = ki_range * stride_dk_nheads + kj_range -+ -+ dk = tl.load(dK_ptr + dk_in_off, mask=mask_k) -+ tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k) -+ -+ if v_dim > 0: -+ dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :] -+ dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :] -+ dv = tl.load(dV_ptr + dv_in_off, mask=mask_v) -+ tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v) - - if pid_head == 0: - x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) - x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) - for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)): -- dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads -- x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim -+ dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads -+ x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads - mask = x_off < head_num * stride_dk_nheads - x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :] - x_right_off = x_left_off + emb_dim // 2 -@@ -632,6 +647,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): - - o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim) - o_value = kv.new_empty(total_seqlen, nheads, v_dim) -+ k_dim_ceil = triton.next_power_of_2(k_dim) - - grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) - rotary_fwd_kv_kernel[grid]( -@@ -643,6 +659,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): - sin, - emb_dim, - k_dim, -+ k_dim_ceil, - v_dim, - nheads, - batch_size, -@@ -700,6 +717,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): - - d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim) - d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim) -+ k_dim_ceil = triton.next_power_of_2(ctx.k_dim) - - grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) - rotary_bwd_kv_kernel[grid]( -@@ -711,6 +729,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): - sin, - ctx.emb_dim, - ctx.k_dim, -+ k_dim_ceil, - ctx.v_dim, - nheads, - batch_size, -diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py -index 5d7b69cd3..2e0a26815 100644 ---- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py -+++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py -@@ -348,6 +348,7 @@ class MultimodalRotaryEmbedding(nn.Module): - - # shape (seq_length, bs, 1, 2 * dim) - emb = emb[..., None, :].transpose(0, 1).contiguous() -+ packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' - if packed_seq_params is not None and packed_seq_params.local_cp_size is not None: - if packed_seq_params.local_cp_size > 1: - # Set CP group to dynamic CP group for CP slicing -@@ -357,7 +358,9 @@ class MultimodalRotaryEmbedding(nn.Module): - cp_group = None - else: - cp_group = self.cp_group -- if cp_group is not None and cp_group.size() > 1: -+ # For THD (packed sequence) format, skip CP slicing here — it is handled -+ # per-sequence inside _apply_rotary_pos_emb_thd instead (same as RotaryEmbedding). -+ if cp_group is not None and cp_group.size() > 1 and not packed_seq: - # slice rotary_pos_emb along sequence dimension and select the parition of the current - # CP rank - emb = get_pos_emb_on_this_cp_rank(emb, 0, cp_group) -diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py -index 13d74aa52..060898a7a 100644 ---- a/megatron/core/models/common/language_module/language_module.py -+++ b/megatron/core/models/common/language_module/language_module.py -@@ -184,7 +184,15 @@ class LanguageModule(MegatronModule): - assert ( - column_parallel_linear is not None - ), "column_parallel_linear cannot be None when not using fused linear cross entropy." -- logits, _ = column_parallel_linear(hidden, **col_linear_kwargs) -+ # output -+ output_layer_params = {k: v.detach() for k, v in column_parallel_linear.named_parameters()} -+ output_layer_buffers = dict(column_parallel_linear.named_buffers()) -+ logits, _ = torch.func.functional_call( -+ column_parallel_linear, -+ {**output_layer_params, **output_layer_buffers}, -+ (hidden,), -+ col_linear_kwargs, -+ ) - - return self.compute_language_model_loss(labels, logits) - -diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py -index e21127b87..712793853 100755 ---- a/megatron/core/models/gpt/gpt_layer_specs.py -+++ b/megatron/core/models/gpt/gpt_layer_specs.py -@@ -188,6 +188,8 @@ def get_gpt_layer_with_transformer_engine_spec( - use_kitchen: bool = False, - use_te_activation_func: bool = False, - fallback_to_eager_attn: bool = False, -+ post_self_attn_layernorm: bool = False, -+ post_mlp_layernorm: bool = False, - ) -> ModuleSpec: - """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). - -@@ -260,6 +262,8 @@ def get_gpt_layer_with_transformer_engine_spec( - mlp=mlp, - sharded_state_dict_keys_map=sharded_state_dict_keys_map, - normalization=normalization, -+ post_self_attn_layernorm=post_self_attn_layernorm, -+ post_mlp_layernorm=post_mlp_layernorm, - ) - - -@@ -349,6 +353,8 @@ def get_transformer_layer_spec_for_backend( - mlp: ModuleSpec, - sharded_state_dict_keys_map: Optional[dict] = None, - normalization: Optional[str] = None, -+ post_self_attn_layernorm: bool = False, -+ post_mlp_layernorm: bool = False, - ) -> ModuleSpec: - """Helper function to get module spec for TransformerLayer""" - -@@ -371,9 +377,11 @@ def get_transformer_layer_spec_for_backend( - input_layernorm=input_layernorm, - self_attention=attention, - self_attn_bda=get_bias_dropout_add, -+ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, - pre_mlp_layernorm=pre_mlp_layernorm, - mlp=mlp, - mlp_bda=get_bias_dropout_add, -+ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, - sharded_state_dict_keys_map=sharded_state_dict_keys_map, - ), - ) -diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py -index a1230568c..1fd52f65a 100644 ---- a/megatron/core/models/gpt/gpt_model.py -+++ b/megatron/core/models/gpt/gpt_model.py -@@ -446,6 +446,7 @@ class GPTModel(LanguageModule): - *, - inference_params: Optional[BaseInferenceContext] = None, - loss_mask: Optional[Tensor] = None, -+ mtp_kwargs: Optional[dict] = {}, - ) -> Tensor: - """Forward function of the GPT Model This function passes the input tensors - through the embedding layer, and then the decoder and finally into the post -@@ -508,6 +509,7 @@ class GPTModel(LanguageModule): - runtime_gather_output=runtime_gather_output, - extra_block_kwargs=extra_block_kwargs, - inference_context=inference_context, -+ mtp_kwargs=mtp_kwargs, - ) - - def _postprocess( -@@ -529,6 +531,7 @@ class GPTModel(LanguageModule): - runtime_gather_output=None, - extra_block_kwargs=None, - inference_context=None, -+ mtp_kwargs={}, - ): - """Postprocesses decoder hidden states to generate logits or compute loss. - -@@ -543,7 +546,8 @@ class GPTModel(LanguageModule): - output_weight = None - if self.share_embeddings_and_output_weights: - output_weight = self.shared_embedding_or_output_weight() -- if mtp_in_postprocess: -+ -+ if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: - hidden_states = self.mtp( - input_ids=input_ids, - position_ids=position_ids, -@@ -563,13 +567,18 @@ class GPTModel(LanguageModule): - return hidden_states - - # Skip when mtp_num_layers is None or 0 -- if self.config.mtp_num_layers: -- mtp_labels = labels.clone() -+ if self.config.mtp_num_layers and mtp_kwargs.get('mtp_labels', None) is not None: -+ mtp_labels = mtp_kwargs['mtp_labels'].clone() -+ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) -+ - hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) - hidden_states = hidden_states_list[0] - if loss_mask is None: - # if loss_mask is not provided, use all ones as loss_mask - loss_mask = torch.ones_like(mtp_labels) -+ else: -+ # Otherwise, roll the loss_mask to keep up with the mtp_labels -+ loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) - for mtp_layer_number in range(self.config.mtp_num_layers): - # Calc loss for the current Multi-Token Prediction (MTP) layers. - mtp_labels, _ = roll_tensor( -@@ -595,7 +604,7 @@ class GPTModel(LanguageModule): - sequence_parallel_enabled=self.output_layer.sequence_parallel, - column_parallel_linear=self.output_layer, - col_linear_kwargs={ -- 'weight': output_weight, -+ 'weight': output_weight.detach() if output_weight else None, - 'runtime_gather_output': runtime_gather_output, - }, - ) -diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py -index 6e093f96f..eac21a3ea 100644 ---- a/megatron/core/optimizer/distrib_optimizer.py -+++ b/megatron/core/optimizer/distrib_optimizer.py -@@ -677,6 +677,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): - # TE FusedAdam will not accumulate step for empty param groups, so we need to - # align the step across param groups. - param_group["step"] = int(step) -+ if "step" in param_group and param_group["step"] is None: -+ del param_group["step"] - - # Grad scaler state. - if self.grad_scaler: -@@ -1646,6 +1648,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): - if key == 'padding': - tensors[key] = LocalNonpersistentObject(tensors[key]) - continue -+ if key == 'step': -+ continue - assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), ( - tensors[key].shape, - gbuf_local_start, -diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py -index a273002b9..4f821cfd5 100644 ---- a/megatron/core/parallel_state.py -+++ b/megatron/core/parallel_state.py -@@ -11,6 +11,7 @@ from typing import Callable, List, Optional - - import numpy as np - import torch -+import torch.distributed as dist - - from .utils import GlobalMemoryBuffer, is_torch_min_version - -diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py -index ac839c21f..f18309217 100644 ---- a/megatron/core/pipeline_parallel/p2p_communication.py -+++ b/megatron/core/pipeline_parallel/p2p_communication.py -@@ -26,22 +26,22 @@ def _batched_p2p_ops( - ops = [] - if tensor_send_prev is not None: - send_prev_op = torch.distributed.P2POp( -- torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group -+ torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, - ) - ops.append(send_prev_op) - if tensor_recv_prev is not None: - recv_prev_op = torch.distributed.P2POp( -- torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group -+ torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, - ) - ops.append(recv_prev_op) - if tensor_send_next is not None: - send_next_op = torch.distributed.P2POp( -- torch.distributed.isend, tensor_send_next, next_pipeline_rank, group -+ torch.distributed.isend, tensor_send_next, next_pipeline_rank, - ) - ops.append(send_next_op) - if tensor_recv_next is not None: - recv_next_op = torch.distributed.P2POp( -- torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group -+ torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, - ) - ops.append(recv_next_op) - if len(ops) > 0: -diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py -index 28cff06f5..58dc4bb70 100644 ---- a/megatron/core/transformer/moe/moe_utils.py -+++ b/megatron/core/transformer/moe/moe_utils.py -@@ -587,6 +587,9 @@ def topk_routing_with_score_function( - else: - return torch.topk(scores, k=topk, dim=1) - -+ from relax.utils.training.routing_replay import get_routing_replay_compute_topk -+ compute_topk = get_routing_replay_compute_topk(compute_topk) -+ - if score_function == "softmax": - if use_pre_softmax: - scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) -diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py -index 16fc9d9af..517944f25 100644 ---- a/megatron/core/transformer/moe/router.py -+++ b/megatron/core/transformer/moe/router.py -@@ -201,6 +201,9 @@ class TopKRouter(Router): - self.global_tokens_per_expert = None - self.ga_steps = None - -+ from relax.utils.training.routing_replay import register_routing_replay -+ register_routing_replay(self) -+ - def _maintain_float32_expert_bias(self): - """ - Maintain the expert bias in float32. -diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py -index a8f4abfcd..f33f6f05e 100755 ---- a/megatron/core/transformer/multi_token_prediction.py -+++ b/megatron/core/transformer/multi_token_prediction.py -@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union - - import torch - from torch import Tensor -+import warnings - - from megatron.core import InferenceParams, parallel_state, tensor_parallel - from megatron.core.dist_checkpointing.mapping import ShardedStateDict -@@ -714,17 +715,19 @@ class MultiTokenPredictionLayer(MegatronModule): - cp_group=self.cp_group, - packed_seq_params=packed_seq_params, - ) -- position_ids, _ = roll_tensor( -- position_ids, -- shifts=-1, -- dims=-1, -- cp_group=self.cp_group, -- packed_seq_params=packed_seq_params, -- ) -+ if position_ids is not None: -+ position_ids, _ = roll_tensor( -+ position_ids, -+ shifts=-1, -+ dims=-1, -+ cp_group=self.cp_group, -+ packed_seq_params=packed_seq_params, -+ ) - # embedding - decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) -+ decoder_input = decoder_input.detach() - -- hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) -+ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) - - return input_ids, position_ids, decoder_input, hidden_states - -@@ -826,6 +829,51 @@ class MultiTokenPredictionLayer(MegatronModule): - return hidden_states - - def _checkpointed_forward(self, forward_func, *args, **kwargs): -+ """Wrap `forward_func` with activation checkpointing while only passing tensors. -+ -+ Non-tensor arguments (e.g., configuration objects, None) are captured via closure so -+ that checkpoint implementations never receive them directly, avoiding save_for_backward -+ issues with non-tensor inputs. -+ """ -+ -+ # TODO(jiajun): Is there any better implementation here? -+ positional_specs = [] -+ kw_specs = [] -+ tensor_args: List[torch.Tensor] = [] -+ -+ for arg in args: -+ if torch.is_tensor(arg): -+ positional_specs.append(('tensor', len(tensor_args))) -+ tensor_args.append(arg) -+ else: -+ positional_specs.append(('const', arg)) -+ -+ for key, value in kwargs.items(): -+ if torch.is_tensor(value): -+ kw_specs.append((key, ('tensor', len(tensor_args)))) -+ tensor_args.append(value) -+ else: -+ kw_specs.append((key, ('const', value))) -+ -+ def run(*flat_tensor_args): -+ rebuilt_args = [] -+ for spec_type, payload in positional_specs: -+ if spec_type == 'tensor': -+ rebuilt_args.append(flat_tensor_args[payload]) -+ else: -+ rebuilt_args.append(payload) -+ -+ rebuilt_kwargs = {} -+ for key, (spec_type, payload) in kw_specs: -+ if spec_type == 'tensor': -+ rebuilt_kwargs[key] = flat_tensor_args[payload] -+ else: -+ rebuilt_kwargs[key] = payload -+ -+ return forward_func(*rebuilt_args, **rebuilt_kwargs) -+ -+ tensor_args_tuple = tuple(tensor_args) -+ - def checkpoint_handler(): - """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" - if self.config.fp8: -@@ -836,12 +884,11 @@ class MultiTokenPredictionLayer(MegatronModule): - self.config.distribute_saved_activations, - tensor_parallel.random.get_cuda_rng_tracker, - parallel_state.get_tensor_model_parallel_group(), -- *args, -- **kwargs, -+ *tensor_args_tuple, - ) - else: - return tensor_parallel.checkpoint( -- forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() -+ run, self.config.distribute_saved_activations, *tensor_args_tuple - ) - - if self.config.recompute_method == 'uniform': -diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py -index e2705bd9f..a0aa109b5 100644 ---- a/megatron/core/transformer/transformer_config.py -+++ b/megatron/core/transformer/transformer_config.py -@@ -210,6 +210,9 @@ class TransformerConfig(ModelParallelConfig): - attention_output_gate: bool = False - """Whether to apply output gate to the attention layers.""" - -+ post_self_attn_layernorm: bool = False -+ post_mlp_layernorm: bool = False -+ - test_mode: bool = False - """Whether to run real-time tests.""" - -diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py -index 3ea405770..5a42001b9 100644 ---- a/megatron/core/transformer/transformer_layer.py -+++ b/megatron/core/transformer/transformer_layer.py -@@ -223,6 +223,7 @@ class TransformerLayerSubmodules: - input_layernorm: Union[ModuleSpec, type] = IdentityOp - self_attention: Union[ModuleSpec, type] = IdentityOp - self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp -+ post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp - - pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp - cross_attention: Union[ModuleSpec, type] = IdentityOp -@@ -231,6 +232,7 @@ class TransformerLayerSubmodules: - pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp - mlp: Union[ModuleSpec, type] = IdentityOp - mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp -+ post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp - - # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method - sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) -@@ -310,6 +312,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): - # [Module 3: BiasDropoutFusion] - self.self_attn_bda = build_module(submodules.self_attn_bda) - -+ self.post_self_attn_layernorm = build_module( -+ submodules.post_self_attn_layernorm, -+ config=self.config, -+ hidden_size=self.config.hidden_size, -+ eps=self.config.layernorm_epsilon, -+ ) -+ - # [Module 4: Post SelfAttention] Optional Layernorm after self-attn - self.pre_cross_attn_layernorm = build_module( - submodules.pre_cross_attn_layernorm, -@@ -375,6 +384,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): - - self.is_moe_layer = isinstance(self.mlp, MoELayer) - -+ self.post_mlp_layernorm = build_module( -+ submodules.post_mlp_layernorm, -+ config=self.config, -+ hidden_size=self.config.hidden_size, -+ eps=self.config.layernorm_epsilon -+ ) -+ - self.recompute_input_layernorm = False - self.recompute_pre_mlp_layernorm = False - self.recompute_mlp = False -@@ -551,6 +567,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): - attention_output_with_bias[0] - ) - -+ attention_output, attention_output_bias = attention_output_with_bias -+ attention_output = self.post_self_attn_layernorm(attention_output) -+ attention_output_with_bias = (attention_output, attention_output_bias) -+ - # TODO: could we move `bias_dropout_add_exec_handler` itself - # inside the module provided in the `bias_dropout_add_spec` module? - nvtx_range_push(suffix="self_attn_bda") -@@ -677,6 +697,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): - else: - mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) - -+ mlp_output, mlp_output_bias = mlp_output_with_bias -+ mlp_output = self.post_mlp_layernorm(mlp_output) -+ mlp_output_with_bias = (mlp_output, mlp_output_bias) -+ - if self.recompute_pre_mlp_layernorm: - # discard the output of the pre-mlp layernorm and register the recompute - # as a gradient hook of mlp_output_with_bias[0] -diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py -index b267c8a81..83736acdc 100644 ---- a/megatron/training/arguments.py -+++ b/megatron/training/arguments.py -@@ -1398,6 +1398,9 @@ def core_transformer_config_from_args(args, config_class=None): - - kw_args['inference_sampling_seed'] = args.seed - -+ kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm -+ kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm -+ - # handle quantization config - # NOTE: Kitchen arguments are only added to the namespace when - # Kitchen library is available. -@@ -1764,6 +1767,12 @@ def _add_network_size_args(parser): - action='store_true', - help='If set, use original BERT residula connection ' - 'ordering.') -+ group.add_argument('--post-self-attn-layernorm', action='store_true', -+ help='If set, use post self attention layernorm.') -+ group.add_argument('--post-mlp-layernorm', action='store_true', -+ help='If set, use post MLP layernorm.') -+ group.add_argument('--use-gated-attention', action='store_true', -+ help='If set, use gated attention as in Qwen3Next') - group.add_argument('--openai-gelu', action='store_true', - help='Use OpenAIs GeLU implementation. This option' - 'should not be used unless for backward compatibility' -diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py -index 13b7526ca..6c590f653 100644 ---- a/megatron/training/tokenizer/tokenizer.py -+++ b/megatron/training/tokenizer/tokenizer.py -@@ -136,7 +136,7 @@ class _HuggingFaceTokenizer(MegatronLegacyTokenizer): - # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there - self._tokenizer = transformers.AutoTokenizer.from_pretrained( - pretrained_model_name_or_path=pretrained_model_name_or_path, -- trust_remote_code=trust_remote_code, -+ trust_remote_code=True, - **kwargs, - ) - self._vocab = self._tokenizer.get_vocab() -diff --git a/megatron/core/ssm/gated_delta_net.py b/megatron/core/ssm/gated_delta_net.py -index dfa6e4c35..0b38f1135 100644 ---- a/megatron/core/ssm/gated_delta_net.py -+++ b/megatron/core/ssm/gated_delta_net.py -@@ -21,6 +21,12 @@ from megatron.core.inference.contexts import BaseInferenceContext - from megatron.core.jit import jit_fuser - from megatron.core.packed_seq_params import PackedSeqParams - from megatron.core.process_groups_config import ProcessGroupCollection -+from megatron.core.ssm.mamba_context_parallel import ( -+ _all_to_all_cp2hp, -+ _all_to_all_hp2cp, -+ _redo_attention_load_balancing, -+ _undo_attention_load_balancing, -+) - from megatron.core.tensor_parallel import get_cuda_rng_tracker - from megatron.core.transformer import TransformerConfig - from megatron.core.transformer.identity_op import IdentityOp -@@ -33,24 +39,19 @@ from megatron.core.transformer.utils import ( - ) - from megatron.core.utils import deprecate_inference_params, nvtx_range_pop, nvtx_range_push - --# TODO: Implement GatedDeltaNetContextParallel --# from .gated_delta_net_context_parallel import GatedDeltaNetContextParallel -- - try: -+ from fla.modules.convolution import causal_conv1d - from fla.modules.l2norm import l2norm - from fla.ops.gated_delta_rule import chunk_gated_delta_rule - - HAVE_FLA = True - except ImportError: -+ causal_conv1d = None -+ l2norm = None - chunk_gated_delta_rule = None - - HAVE_FLA = False - --try: -- from causal_conv1d import causal_conv1d_fn --except ImportError: -- causal_conv1d_fn = None -- - - logger = logging.getLogger(__name__) - -@@ -84,6 +85,7 @@ class GatedDeltaNet(MegatronModule): - use_qk_l2norm: bool = True, - A_init_range: Tuple[float, float] = (1, 16), - pg_collection: ProcessGroupCollection = None, -+ **kwargs, - ): - """ - Args: -@@ -98,9 +100,11 @@ class GatedDeltaNet(MegatronModule): - pg_collection: The required process groups to use for tensor model parallel and context - parallel. - """ -- -+ # print(f"new gdn", flush=True) - if not HAVE_FLA: -- raise ImportError("FLA is not installed. Please install it with `pip install fla`.") -+ raise ImportError( -+ "FLA is not installed. Please install it with `pip install flash-linear-attention`." -+ ) - - super().__init__(config) - -@@ -114,6 +118,7 @@ class GatedDeltaNet(MegatronModule): - self.use_qk_l2norm = use_qk_l2norm - assert pg_collection is not None, "pg_collection must be provided for GatedDeltaNet" - self.pg_collection = pg_collection -+ self.cp_size = self.pg_collection.cp.size() - self.tp_size = self.pg_collection.tp.size() - self.sp_size = self.tp_size if config.sequence_parallel else 1 - -@@ -129,6 +134,8 @@ class GatedDeltaNet(MegatronModule): - self.num_value_heads = config.linear_num_value_heads - self.qk_dim = self.key_head_dim * self.num_key_heads - self.v_dim = self.value_head_dim * self.num_value_heads -+ self.qk_dim_local_tp = self.qk_dim // self.tp_size -+ self.v_dim_local_tp = self.v_dim // self.tp_size - - # Input projection (hidden_states -> q, k, v, gate, beta, alpha) - # TODO: for now, output gate is forced for GDN. -@@ -171,8 +178,10 @@ class GatedDeltaNet(MegatronModule): - dtype=config.params_dtype, - ) - setattr(self.conv1d.weight, "tensor_model_parallel", True) -+ setattr(self.conv1d.weight, "partition_dim", 0) - if conv_bias: - setattr(self.conv1d.bias, "tensor_model_parallel", True) -+ setattr(self.conv1d.bias, "partition_dim", 0) - - # Time step projection (discretization) - self.num_v_heads_local_tp = self.num_value_heads // self.tp_size -@@ -185,6 +194,7 @@ class GatedDeltaNet(MegatronModule): - ) - ) - setattr(self.dt_bias, "tensor_model_parallel", True) -+ setattr(self.dt_bias, "partition_dim", 0) - # A_log parameter - self.A_log = nn.Parameter( - torch.empty( -@@ -194,6 +204,12 @@ class GatedDeltaNet(MegatronModule): - ) - ) - setattr(self.A_log, "tensor_model_parallel", True) -+ setattr(self.A_log, "partition_dim", 0) -+ -+ if self.config.deterministic_mode: -+ self.gated_delta_rule = torch_chunk_gated_delta_rule -+ else: -+ self.gated_delta_rule = chunk_gated_delta_rule - - # Output layernorm before projection - self.out_norm = build_module( -@@ -217,8 +233,6 @@ class GatedDeltaNet(MegatronModule): - tp_group=self.pg_collection.tp, - ) - -- # TODO: support CP -- - self.reset_parameters() - - def reset_parameters(self): -@@ -241,23 +255,18 @@ class GatedDeltaNet(MegatronModule): - dtype=self.config.params_dtype, - device=torch.cuda.current_device(), - ).uniform_(*self.A_init_range) -- self.A_log.data.copy_(A) -+ self.A_log.data.copy_(torch.log(A)) - - def forward( - self, - hidden_states: Tensor, - attention_mask: Tensor, -- key_value_states: Optional[Tensor] = None, - inference_context: Optional[BaseInferenceContext] = None, -- rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None, -- rotary_pos_cos: Optional[Tensor] = None, -- rotary_pos_sin: Optional[Tensor] = None, -- rotary_pos_cos_sin: Optional[Tensor] = None, -- attention_bias: Optional[Tensor] = None, - packed_seq_params: Optional[PackedSeqParams] = None, - sequence_len_offset: Optional[int] = None, - *, - inference_params: Optional[BaseInferenceContext] = None, -+ **kwargs, - ): - """ - Perform a forward pass through the GDN module. -@@ -265,15 +274,8 @@ class GatedDeltaNet(MegatronModule): - Args: - hidden_states (Tensor): Hidden states. - attention_mask (Tensor): Attention mask. -- key_value_states (Optional[Tensor]): Key/value states (for cross attention). - inference_context (Optional[BaseInferenceContext]): Inference context that manages - KV cache. -- rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]): Rotary -- embedding tensor(s). -- rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine. -- rotary_pos_sin (Optional[Tensor]): Rotary embedding sine. -- rotary_pos_cos_sin (Optional[Tensor]): Combined rotary embedding cosine and sine. -- attention_bias (Optional[Tensor]): Attention bias. - packed_seq_params (Optional[PackedSeqparams]): Parameters used for THD format. - sequence_len_offset (Optional[int]): Sequence length offset used for - inference CUDA graphs. -@@ -287,7 +289,7 @@ class GatedDeltaNet(MegatronModule): - inference_context = deprecate_inference_params(inference_context, inference_params) - - seq_len, batch, _ = hidden_states.shape -- seq_len = seq_len * self.sp_size -+ seq_len = seq_len * self.sp_size * self.cp_size - - if inference_context is not None: - assert ( -@@ -297,15 +299,80 @@ class GatedDeltaNet(MegatronModule): - # TODO: support inference - raise NotImplementedError("GDN does not support inference for now.") - -- if packed_seq_params is not None: -- # TODO: support packed sequence -- raise NotImplementedError("GDN does not support packed sequence for now.") -+ if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': -+ assert batch == 1, "Packed sequence expects batch dimension to be 1" -+ assert ( -+ not self.config.deterministic_mode -+ ), "Packed sequence does not support deterministic mode." -+ -+ # Resolve cu_seqlens with alignment padding handling. -+ cu_seqlens_q = self._resolve_cu_seqlens( -+ packed_seq_params.cu_seqlens_q_padded, -+ packed_seq_params.cu_seqlens_q, -+ seq_len, -+ "cu_seqlens_q", -+ ) -+ cu_seqlens_kv = self._resolve_cu_seqlens( -+ packed_seq_params.cu_seqlens_kv_padded, -+ packed_seq_params.cu_seqlens_kv, -+ seq_len, -+ "cu_seqlens_kv", -+ ) -+ assert torch.equal(cu_seqlens_q, cu_seqlens_kv), ( -+ "Currently only support cu_seqlens_q equals to cu_seqlens_kv, " -+ f"but got {cu_seqlens_q=} and {cu_seqlens_kv=}" -+ ) -+ num_packed_seqs = cu_seqlens_q.shape[0] - 1 -+ assert num_packed_seqs > 0, ( -+ "Number of packed sequences must be greater than 0, " -+ f"but got {cu_seqlens_q=} and {cu_seqlens_kv=}" -+ ) -+ else: -+ cu_seqlens_q = None -+ cu_seqlens_kv = None - - # Input projection - nvtx_range_push(suffix="in_proj") - qkvzba, _ = self.in_proj(hidden_states) - nvtx_range_pop(suffix="in_proj") - -+ # CP All to All: CP to HP -+ if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': -+ unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens_q // self.cp_size, dim=0) -+ outputs = [] -+ for qkvzba_i in unpacked_qkvzba: -+ qkvzba_i = tensor_a2a_cp2hp( -+ qkvzba_i, -+ seq_dim=0, -+ head_dim=-1, -+ cp_group=self.pg_collection.cp, -+ split_sections=[ -+ self.qk_dim_local_tp, -+ self.qk_dim_local_tp, -+ self.v_dim_local_tp, -+ self.v_dim_local_tp, -+ self.num_value_heads // self.tp_size, -+ self.num_value_heads // self.tp_size, -+ ], -+ ) -+ outputs.append(qkvzba_i) -+ qkvzba = torch.cat(outputs, dim=0) -+ else: -+ qkvzba = tensor_a2a_cp2hp( -+ qkvzba, -+ seq_dim=0, -+ head_dim=-1, -+ cp_group=self.pg_collection.cp, -+ split_sections=[ -+ self.qk_dim_local_tp, -+ self.qk_dim_local_tp, -+ self.v_dim_local_tp, -+ self.v_dim_local_tp, -+ self.num_value_heads // self.tp_size, -+ self.num_value_heads // self.tp_size, -+ ], -+ ) -+ - # Transpose: s b x --> b s x - # From sbhd to bshd format - qkvzba = qkvzba.transpose(0, 1) -@@ -314,10 +381,10 @@ class GatedDeltaNet(MegatronModule): - qkv, gate, beta, alpha = torch.split( - qkvzba, - [ -- (self.qk_dim * 2 + self.v_dim) // self.tp_size, -- self.v_dim // self.tp_size, -- self.num_value_heads // self.tp_size, -- self.num_value_heads // self.tp_size, -+ (self.qk_dim_local_tp * 2 + self.v_dim_local_tp) // self.cp_size, -+ self.v_dim_local_tp // self.cp_size, -+ self.num_value_heads // self.tp_size // self.cp_size, -+ self.num_value_heads // self.tp_size // self.cp_size, - ], - dim=-1, - ) -@@ -326,74 +393,83 @@ class GatedDeltaNet(MegatronModule): - alpha = alpha.reshape(batch, seq_len, -1) - - # Convolution on qkv -- qkv = qkv.transpose(1, 2).contiguous() # b, s, d -> b, d, s - nvtx_range_push(suffix="conv1d") -- if (causal_conv1d_fn is None) or self.config.deterministic_mode: -- qkv = self.act_fn(self.conv1d(qkv)[..., :seq_len]) -+ seq_len = qkv.shape[1] -+ qkv_channels_split_sections = [ -+ self.qk_dim_local_tp, -+ self.qk_dim_local_tp, -+ self.v_dim_local_tp, -+ ] -+ conv1d_weight = get_parameter_local_cp( -+ self.conv1d.weight, -+ dim=0, -+ cp_group=self.pg_collection.cp, -+ split_sections=qkv_channels_split_sections, -+ ) -+ conv1d_bias = ( -+ get_parameter_local_cp( -+ self.conv1d.bias, -+ dim=0, -+ cp_group=self.pg_collection.cp, -+ split_sections=qkv_channels_split_sections, -+ ) -+ if self.conv_bias -+ else None -+ ) -+ if self.config.deterministic_mode: -+ qkv = qkv.transpose(1, 2).contiguous() # b, s, d -> b, d, s -+ conv_out = F.conv1d( -+ input=qkv, # Torch-native only accept [b, d, s] format input -+ weight=conv1d_weight, -+ bias=conv1d_bias, -+ stride=self.conv1d.stride, -+ padding=self.conv1d.padding, -+ dilation=self.conv1d.dilation, -+ groups=self.conv_dim_local_tp // self.cp_size, -+ ) -+ qkv = self.act_fn(conv_out[..., :seq_len]) -+ qkv = qkv.transpose(1, 2) # b, d, s -> b, s, d - else: - assert self.activation in ["silu", "swish"] -- qkv = causal_conv1d_fn( -- x=qkv, -- weight=self.conv1d.weight.squeeze(1), # d, 1, w -> d, w -- bias=self.conv1d.bias, -+ qkv, _ = causal_conv1d( -+ x=qkv, # FLA conv1d accepts [b, s, d] format input -+ weight=conv1d_weight.squeeze(1), # d, 1, w -> d, w -+ bias=conv1d_bias, - activation=self.activation, -+ initial_state=None, -+ output_final_state=False, -+ cu_seqlens=cu_seqlens_q, - ) - nvtx_range_pop(suffix="conv1d") -- # Split qkv into query, key, and value -- qkv = qkv.transpose(1, 2) # b, d, s -> b, s, d -- query, key, value = torch.split( -- qkv, -- [self.qk_dim // self.tp_size, self.qk_dim // self.tp_size, self.v_dim // self.tp_size], -- dim=-1, -- ) -- query = query.reshape(batch, seq_len, -1, self.key_head_dim) -- key = key.reshape(batch, seq_len, -1, self.key_head_dim) -- value = value.reshape(batch, seq_len, -1, self.value_head_dim) -- # Apply L2 norm to query and key -- if self.use_qk_l2norm: -- query = l2norm(query.contiguous()) -- key = l2norm(key.contiguous()) -- if self.num_value_heads // self.num_key_heads > 1: -- query = query.repeat_interleave(self.num_value_heads // self.num_key_heads, dim=2) -- key = key.repeat_interleave(self.num_value_heads // self.num_key_heads, dim=2) - -- # Make contiguous -- query = query.contiguous() -- key = key.contiguous() -- value = value.contiguous() -- gate = gate.contiguous() -- beta = beta.contiguous() -- alpha = alpha.contiguous() -+ # Prepare QKV tensors (split, reshape, L2 norm, repeat_interleave, contiguous) -+ nvtx_range_push(suffix="prepare_qkv_for_gated_delta_rule") -+ query, key, value, gate, beta, alpha = self._prepare_qkv_for_gated_delta_rule( -+ qkv, gate, beta, alpha, batch, seq_len -+ ) -+ nvtx_range_pop(suffix="prepare_qkv_for_gated_delta_rule") - - # Calculate g and beta - nvtx_range_push(suffix="g_and_beta") -- g = -self.A_log.exp() * F.softplus(alpha.float() + self.dt_bias) # In fp32 -- beta = beta.sigmoid() -+ A_log_local_cp = get_parameter_local_cp(self.A_log, dim=0, cp_group=self.pg_collection.cp) -+ dt_bias_local_cp = get_parameter_local_cp( -+ self.dt_bias, dim=0, cp_group=self.pg_collection.cp -+ ) -+ g, beta = self._compute_g_and_beta(A_log_local_cp, dt_bias_local_cp, alpha, beta) - nvtx_range_pop(suffix="g_and_beta") - - nvtx_range_push(suffix="gated_delta_rule") -- if self.config.deterministic_mode: -- core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule( -- query, -- key, -- value, -- g=g, -- beta=beta, -- initial_state=None, -- output_final_state=False, -- use_qk_l2norm_in_kernel=False, -- ) -- else: -- core_attn_out, last_recurrent_state = chunk_gated_delta_rule( -- query, -- key, -- value, -- g=g, -- beta=beta, -- initial_state=None, -- output_final_state=False, -- use_qk_l2norm_in_kernel=False, -- ) -+ core_attn_out, last_recurrent_state = self.gated_delta_rule( -+ query, -+ key, -+ value, -+ g=g, -+ beta=beta, -+ initial_state=None, -+ output_final_state=False, -+ use_qk_l2norm_in_kernel=False, -+ cu_seqlens=cu_seqlens_q, -+ ) - nvtx_range_pop(suffix="gated_delta_rule") - - # RMSNorm -@@ -406,6 +482,21 @@ class GatedDeltaNet(MegatronModule): - norm_out = norm_out.reshape(batch, seq_len, -1) - norm_out = norm_out.transpose(0, 1).contiguous() - -+ # CP all to all: HP to CP -+ if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': -+ unpacked_norm_out = _unpack_sequence(norm_out, cu_seqlens_q, dim=0) -+ outputs = [] -+ for norm_out_i in unpacked_norm_out: -+ norm_out_i = tensor_a2a_hp2cp( -+ norm_out_i, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp -+ ) -+ outputs.append(norm_out_i) -+ norm_out = torch.cat(outputs, dim=0) -+ else: -+ norm_out = tensor_a2a_hp2cp( -+ norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp -+ ) -+ - # Output projection - nvtx_range_push(suffix="out_proj") - out, out_bias = self.out_proj(norm_out) -@@ -425,6 +516,74 @@ class GatedDeltaNet(MegatronModule): - y = y.to(x_dtype) - return y - -+ @jit_fuser -+ def _prepare_qkv_for_gated_delta_rule(self, qkv, gate, beta, alpha, batch, seq_len): -+ """ -+ Prepare query, key, value, gate, beta, alpha tensors for gated delta rule. -+ Fuses split, reshape, L2 norm, repeat_interleave, and contiguous operations. -+ """ -+ # Split qkv into query_key and value -+ query_key, value = torch.split( -+ qkv, -+ [2 * self.qk_dim_local_tp // self.cp_size, self.v_dim_local_tp // self.cp_size], -+ dim=-1, -+ ) -+ -+ # Reshape query_key and value -+ query_key = query_key.reshape(batch, seq_len, -1, self.key_head_dim) -+ value = value.reshape(batch, seq_len, -1, self.value_head_dim) -+ -+ # Apply L2 norm to query and key -+ if self.use_qk_l2norm: -+ query_key = l2norm(query_key.contiguous()) -+ -+ # Split query and key -+ split_size = self.qk_dim_local_tp // self.key_head_dim // self.cp_size -+ query, key = torch.split(query_key, [split_size, split_size], dim=2) -+ -+ # Expand query and key if needed (grouped query attention) -+ if self.num_value_heads // self.num_key_heads > 1: -+ repeat_factor = self.num_value_heads // self.num_key_heads -+ query = query.repeat_interleave(repeat_factor, dim=2) -+ key = key.repeat_interleave(repeat_factor, dim=2) -+ -+ # Make all tensors contiguous -+ query = query.contiguous() -+ key = key.contiguous() -+ value = value.contiguous() -+ gate = gate.contiguous() -+ beta = beta.contiguous() -+ alpha = alpha.contiguous() -+ -+ return query, key, value, gate, beta, alpha -+ -+ @jit_fuser -+ def _compute_g_and_beta(self, A_log_local_cp, dt_bias_local_cp, alpha, beta): -+ """ -+ Compute g (decay) and beta (sigmoid) for gated delta rule. -+ Fuses exp, softplus, mul, neg, and sigmoid operations. -+ """ -+ g = -A_log_local_cp.exp() * F.softplus(alpha.float() + dt_bias_local_cp) # In fp32 -+ beta = beta.sigmoid() -+ return g, beta -+ -+ def _resolve_cu_seqlens(self, cu_seqlens_padded, cu_seqlens_actual, total_seq_len, name): -+ """Resolve cu_seqlens for packed sequence all-to-all, handling alignment padding.""" -+ if cu_seqlens_padded is not None: -+ cu_seqlens = cu_seqlens_padded -+ else: -+ cu_seqlens = cu_seqlens_actual -+ -+ total_cu = cu_seqlens[-1].item() -+ if total_cu != total_seq_len: -+ raise ValueError( -+ f"GDN: {name}[-1]={total_cu} does not match " -+ f"total_sequence_length={total_seq_len}. " -+ f"({cu_seqlens_padded=}, {cu_seqlens_actual=})." -+ ) -+ -+ return cu_seqlens -+ - def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None, tp_group=None): - """Provide a sharded state dictionary for distributed checkpointing.""" - # Guard for cases metadata is not provided -@@ -479,10 +638,10 @@ class GatedDeltaNet(MegatronModule): - sharded_state_dict[f"{prefix}in_proj.weight"] = _split_tensor_factory( - sharded_state_dict[f"{prefix}in_proj.weight"], - [ -- self.qk_dim // self.tp_size, -- self.qk_dim // self.tp_size, -- self.v_dim // self.tp_size, -- self.v_dim // self.tp_size, -+ self.qk_dim_local_tp, -+ self.qk_dim_local_tp, -+ self.v_dim_local_tp, -+ self.v_dim_local_tp, - self.num_value_heads // self.tp_size, - self.num_value_heads // self.tp_size, - ], -@@ -502,18 +661,41 @@ class GatedDeltaNet(MegatronModule): - for conv_layer_name in conv_layer_name_list: - sharded_state_dict[f"{prefix}{conv_layer_name}"] = _split_tensor_factory( - sharded_state_dict[f"{prefix}{conv_layer_name}"], -- [ -- self.qk_dim // self.tp_size, -- self.qk_dim // self.tp_size, -- self.v_dim // self.tp_size, -- ], -+ [self.qk_dim_local_tp, self.qk_dim_local_tp, self.v_dim_local_tp], - ["query", "key", "value"], - 0, - ) - - return sharded_state_dict - -+ def backward_dw(self): -+ """Execute weight gradient computation for all linear layers.""" -+ self._backward_in_proj() -+ self._backward_out_proj() -+ -+ def _backward_in_proj(self): -+ """Computes weight gradients of input projection layer.""" -+ self.in_proj.backward_dw() -+ -+ def _backward_out_proj(self): -+ """Computes weight gradients of output projection layer.""" -+ self.out_proj.backward_dw() -+ -+ -+def _unpack_sequence(x, cu_seqlens, dim=1): -+ unpacked_x = [] -+ num_seqs = cu_seqlens.shape[0] - 1 -+ for i in range(num_seqs): -+ idx_start = cu_seqlens[i].item() -+ idx_end = cu_seqlens[i + 1].item() -+ chunked_index = [slice(None)] * dim + [slice(idx_start, idx_end)] -+ unpacked_x.append(x[tuple(chunked_index)]) -+ return unpacked_x -+ - -+#################### -+# Sharded state dict utilities -+#################### - def _split_tensor_factory( - orig_sh_ten: ShardedTensor, split_sections: List[int], split_names: List[str], split_dim: int - ) -> ShardedTensorFactory: -@@ -574,6 +756,184 @@ def _split_tensor_factory( - ) - - -+#################### -+# Context parallel utilities -+#################### -+def get_parameter_local_cp( -+ param: torch.Tensor, -+ dim: int, -+ cp_group: torch.distributed.ProcessGroup, -+ split_sections: Optional[List[int]] = None, -+) -> torch.Tensor: -+ """Get the local parameter for the current context parallel rank. -+ -+ Args: -+ param (torch.Tensor): The entire parameter to get the local parameter for. -+ dim (int): The dimension to split the parameter along. Usually the dimension of head. -+ cp_group (torch.distributed.ProcessGroup): The context parallel group. -+ split_sections (Optional[List[int]]): If not None, -+ first split the parameter along the dimension dim into sections, -+ then get the local hidden parallel weights separately, -+ finally concatenate the local hidden parallel weights along the dimension dim. -+ -+ Returns: -+ torch.Tensor: The local parameter for the current context parallel rank. -+ """ -+ -+ cp_size = cp_group.size() -+ cp_rank = cp_group.rank() -+ -+ # No need to split if CP size is 1. -+ if cp_size == 1: -+ return param -+ -+ # Split first if needed. -+ if split_sections is not None: -+ inputs = torch.split(param, split_sections, dim=dim) -+ outputs = [] -+ for p in inputs: -+ p = get_parameter_local_cp(p, dim, cp_group) -+ outputs.append(p) -+ return torch.cat(outputs, dim=dim) -+ -+ # Slice the parameter. -+ slices = [slice(None)] * param.dim() -+ dim_size = param.size(dim=dim) -+ slices[dim] = slice(cp_rank * dim_size // cp_size, (cp_rank + 1) * dim_size // cp_size) -+ param = param[slices] -+ return param -+ -+ -+def tensor_a2a_cp2hp( -+ tensor: torch.Tensor, -+ seq_dim: int, -+ head_dim: int, -+ cp_group: torch.distributed.ProcessGroup, -+ split_sections: Optional[List[int]] = None, -+ undo_attention_load_balancing: bool = True, -+): -+ """All-to-all context parallel to hidden parallel. -+ -+ Args: -+ tensor (torch.Tensor): The tensor to all-to-all. -+ Currently only support (seq_len, batch, head_dim) shaped tensor. -+ seq_dim (int): The dimension of sequence length. Currently only supports seq_dim == 0. -+ head_dim (int): The dimension of head. Currently only supports head_dim == -1 or 2. -+ cp_group (torch.distributed.ProcessGroup): The context parallel group. -+ split_sections (Optional[List[int]]): If not None, split the tensor along the dimension -+ head_dim into sections first, then do all-to-all for each section separately, -+ finally concatenate the separated tensors along the dimension head_dim. -+ undo_attention_load_balancing (bool): Whether to undo the attention load balancing of CP. -+ -+ Returns: -+ torch.Tensor: The all-to-all tensor. -+ """ -+ -+ cp_size = cp_group.size() -+ -+ # No need to all-to-all if CP size is 1. -+ if cp_size == 1: -+ return tensor -+ -+ # Limitations of mamba_context_parallel._all_to_all_cp2hp. -+ assert seq_dim == 0, f"tensor_a2a_cp2hp only supports seq_dim == 0 for now, but got {seq_dim=}" -+ assert ( -+ head_dim == -1 or head_dim == 2 -+ ), f"tensor_a2a_cp2hp only supports head_dim == -1 or 2 for now, but got {head_dim=}" -+ assert ( -+ tensor.dim() == 3 -+ ), f"tensor_a2a_cp2hp only supports 3-d input tensor for now, but got {tensor.dim()=}" -+ -+ # Split first if needed. -+ if split_sections is not None: -+ inputs = torch.split(tensor, split_sections, dim=head_dim) -+ outputs = [] -+ for x in inputs: -+ x = tensor_a2a_cp2hp( -+ x, -+ seq_dim=seq_dim, -+ head_dim=head_dim, -+ cp_group=cp_group, -+ undo_attention_load_balancing=False, -+ ) -+ outputs.append(x) -+ tensor = torch.cat(outputs, dim=head_dim) -+ else: -+ tensor = _all_to_all_cp2hp(tensor, cp_group) -+ -+ # Undo attention load balancing last if needed. -+ if undo_attention_load_balancing: -+ tensor = _undo_attention_load_balancing(tensor, cp_size) -+ return tensor -+ -+ -+def tensor_a2a_hp2cp( -+ tensor: torch.Tensor, -+ seq_dim: int, -+ head_dim: int, -+ cp_group: torch.distributed.ProcessGroup, -+ split_sections: Optional[List[int]] = None, -+ redo_attention_load_balancing: bool = True, -+): -+ """All-to-all hidden parallel to context parallel. -+ -+ Args: -+ tensor (torch.Tensor): The tensor to all-to-all. -+ Currently only support (seq_len, batch, head_dim) shaped tensor. -+ seq_dim (int): The dimension of sequence length. Currently only supports seq_dim == 0. -+ head_dim (int): The dimension of head. Currently only supports head_dim == -1 or 2. -+ cp_group (torch.distributed.ProcessGroup): The context parallel group. -+ split_sections (Optional[List[int]]): If not None, first split the tensor along the -+ dimension head_dim into sections, then do all-to-all for each section separately, -+ finally concatenate the separated tensors along the dimension head_dim. -+ redo_attention_load_balancing (bool): Whether to redo the attention load balancing of HP. -+ -+ Returns: -+ torch.Tensor: The all-to-all tensor. -+ """ -+ -+ cp_size = cp_group.size() -+ -+ # No need to all-to-all if CP size is 1. -+ if cp_size == 1: -+ return tensor -+ -+ # Limitations of mamba_context_parallel._all_to_all_hp2cp. -+ assert seq_dim == 0, f"tensor_a2a_cp2hp only supports seq_dim == 0 for now, but got {seq_dim=}" -+ assert ( -+ head_dim == -1 or head_dim == 2 -+ ), f"tensor_a2a_cp2hp only supports head_dim == -1 or 2 for now, but got {head_dim=}" -+ assert ( -+ tensor.dim() == 3 -+ ), f"tensor_a2a_cp2hp only supports 3-d input tensor for now, but got {tensor.dim()=}" -+ -+ # Redo attention load balancing first if needed. -+ if redo_attention_load_balancing: -+ tensor = _redo_attention_load_balancing(tensor, cp_size) -+ -+ # Split first if needed. -+ if split_sections is not None: -+ inputs = torch.split(tensor, split_sections, dim=head_dim) -+ outputs = [] -+ for x in inputs: -+ x = tensor_a2a_hp2cp( -+ x, -+ seq_dim=seq_dim, -+ head_dim=head_dim, -+ cp_group=cp_group, -+ redo_attention_load_balancing=False, -+ ) -+ outputs.append(x) -+ tensor = torch.cat(outputs, dim=head_dim) -+ else: -+ tensor = _all_to_all_hp2cp(tensor, cp_group) -+ -+ return tensor -+ -+ -+#################### -+# Torch native gated delta rule -+#################### - def torch_chunk_gated_delta_rule( - query, - key, -@@ -584,6 +944,7 @@ def torch_chunk_gated_delta_rule( - initial_state=None, - output_final_state=False, - use_qk_l2norm_in_kernel=False, -+ cu_seqlens=None, - ): - # pylint: disable=line-too-long - ''' -@@ -593,6 +954,10 @@ def torch_chunk_gated_delta_rule( - Reference: https://github.com/huggingface/transformers/blob/144c8ce2809a2e21914017652700e1ecb450501e/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L470-L547 - ''' - -+ assert ( -+ cu_seqlens is None -+ ), "cu_seqlens is not supported for torch_chunk_gated_delta_rule for now." -+ - initial_dtype = query.dtype - if use_qk_l2norm_in_kernel: - query = l2norm(query, dim=-1, eps=1e-6) -@@ -666,4 +1031,4 @@ def torch_chunk_gated_delta_rule( - ) - core_attn_out = core_attn_out[:, :, :sequence_length] - core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) -- return core_attn_out, last_recurrent_state -+ return core_attn_out, last_recurrent_state diff --git a/docker/patch/latest/megatron.patch b/docker/patch/latest/megatron.patch new file mode 120000 index 00000000..ec9557dc --- /dev/null +++ b/docker/patch/latest/megatron.patch @@ -0,0 +1 @@ +../megatron/20260506-85bced0ae.patch \ No newline at end of file diff --git a/docker/patch/megatron/20251218-3714d81d.patch b/docker/patch/megatron/20251218-3714d81d.patch new file mode 100644 index 00000000..5d4428a5 --- /dev/null +++ b/docker/patch/megatron/20251218-3714d81d.patch @@ -0,0 +1,1578 @@ +diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py +index 41c21d93d..ef80f72d6 100644 +--- a/megatron/core/dist_checkpointing/strategies/common.py ++++ b/megatron/core/dist_checkpointing/strategies/common.py +@@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy): + msc = MultiStorageClientFeature.import_package() + return msc.torch.load(load_path, map_location='cpu') + else: +- return torch.load(load_path, map_location='cpu') ++ return torch.load(load_path, map_location='cpu', weights_only=False) + except FileNotFoundError as e: + err_msg = f'Common file {load_path} does not exist' + if MultiStorageClientFeature.is_enabled(): +diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py +index 5a1ea308d..aa701237f 100644 +--- a/megatron/core/dist_checkpointing/strategies/torch.py ++++ b/megatron/core/dist_checkpointing/strategies/torch.py +@@ -597,10 +597,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + def _validate_global_shapes(self, metadata, sharded_tensors): + for sh_ten in sharded_tensors: + if sh_ten.key not in metadata.state_dict_metadata: +- raise KeyError( +- f"{sh_ten.key} from model not in state dict:" +- f" {sorted(metadata.state_dict_metadata.keys())}" +- ) ++ # raise KeyError( ++ # f"{sh_ten.key} from model not in state dict:" ++ # f" {sorted(metadata.state_dict_metadata.keys())}" ++ # ) ++ print(f"{sh_ten.key} from model not in state dict, will skip") ++ continue + loaded_shape = metadata.state_dict_metadata[sh_ten.key].size + expected_shape = self._expected_shape(sh_ten) + if loaded_shape != expected_shape: +@@ -630,7 +632,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + tensor_metadata = self.metadata.state_dict_metadata + metadata_with_sizes = [ + (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) +- for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() ++ for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata + ] + try: + # Temporarily set sizes to expected shapes +@@ -959,6 +961,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): + planner=MCoreLoadPlanner( + shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, + allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, ++ allow_partial_load=True, + ), + ) + +diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py +index acb93ef78..d239db4ab 100644 +--- a/megatron/core/extensions/transformer_engine.py ++++ b/megatron/core/extensions/transformer_engine.py +@@ -408,6 +408,7 @@ class TELinear(te.pytorch.Linear): + ) + + for param in self.parameters(): ++ setattr(param, "parallel_mode", parallel_mode) + if is_expert: + # Reduce the gradient on the expert_data_parallel group for expert linear layers + setattr(param, "allreduce", not self.expert_parallel) +@@ -1161,6 +1162,61 @@ class TEDotProductAttention(te.pytorch.DotProductAttention): + + + if HAVE_TE and is_te_min_version("1.9.0.dev0"): ++ def ceil_div(x: int, y: int) -> int: ++ return (x + y - 1) // y ++ ++ class _FakeInt4QuantizationSTE(torch.autograd.Function): ++ @staticmethod ++ def forward(ctx, x, group_size): ++ m, n = x.shape ++ block_size_m, block_size_n = 1, group_size ++ ++ ++ m_padded = ceil_div(m, block_size_m) * block_size_m ++ n_padded = ceil_div(n, block_size_n) * block_size_n ++ ++ x_padded = torch.zeros( ++ (m_padded, n_padded), ++ dtype=x.dtype, device=x.device ++ ) ++ x_padded[:m, :n] = x ++ ++ x_view = x_padded.view( ++ m_padded // block_size_m, ++ block_size_m, ++ n_padded // block_size_n, ++ block_size_n ++ ) ++ ++ x_max = x_view.abs().float().amax(dim=(1, 3), keepdim=True) ++ q_max = 7 ++ x_scale = x_max / q_max ++ ++ x_scale = x_scale.clamp(min=1e-5) ++ ++ x_div = x_view / x_scale ++ x_round = torch.round(x_div) ++ ++ x_q_clamped = x_round.clamp(-q_max, q_max) ++ ++ x_dequant_view = x_q_clamped * x_scale ++ ++ x_dequant_full = x_dequant_view.view_as(x_padded) ++ x_out = x_dequant_full[:m, :n].contiguous().to(x.dtype) ++ ++ return x_out ++ ++ @staticmethod ++ def backward(ctx, grad_output): ++ return grad_output, None ++ ++ def fake_int4_quantization_ste(x, group_size): ++ x_out = _FakeInt4QuantizationSTE.apply(x, group_size) ++ ++ if hasattr(x, 'main_grad'): ++ x_out.main_grad = x.main_grad ++ ++ return x_out + + class TEGroupedLinear(te.pytorch.GroupedLinear): + """ +@@ -1351,6 +1407,7 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): + _is_first_microbatch = ( + None if self.disable_parameter_transpose_cache else self.is_first_microbatch + ) ++ + out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch) + self.is_first_microbatch = False + +@@ -1361,6 +1418,20 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): + return out + return out, None + ++ def _get_weight_tensors(self): ++ """Get the weight tensors of the module.""" ++ weight_tensors = super()._get_weight_tensors() ++ ++ if os.getenv("OPEN_TRAINING_INT4_FAKE_QAT_FLAG", "0") == "1": ++ group_size = int(os.getenv("OPEN_TRAINING_INT4_GROUP_SIZE", "128")) ++ ++ weight_tensors = [ ++ fake_int4_quantization_ste(w, group_size) ++ for w in weight_tensors ++ ] ++ ++ return weight_tensors ++ + def _encode_extra_state(self, state): + # TE 2.0 changed the format of extra_state to be a byte tensor + if is_te_min_version("2.0.0"): +diff --git a/megatron/core/fusions/fused_mla_yarn_rope_apply.py b/megatron/core/fusions/fused_mla_yarn_rope_apply.py +index 1fd5dcfae..c9aeef1f0 100644 +--- a/megatron/core/fusions/fused_mla_yarn_rope_apply.py ++++ b/megatron/core/fusions/fused_mla_yarn_rope_apply.py +@@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel( + SIN, + emb_dim: tl.constexpr, + k_dim: tl.constexpr, ++ k_dim_ceil: tl.constexpr, + v_dim: tl.constexpr, + head_num: tl.constexpr, + batch_size, +@@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel( + cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) + sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) + +- KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads +- kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads +- mask = kv_off < head_num * stride_kv_nheads +- k_in_off = kv_off + tl.arange(0, k_dim)[None, :] +- v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :] +- k = tl.load(KV_ptr + k_in_off, mask=mask) +- v = tl.load(KV_ptr + v_in_off, mask=mask) ++ KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads ++ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ kj_range = tl.arange(0, k_dim_ceil)[None, :] ++ mask_k = (ki_range < head_num) & (kj_range < k_dim) ++ mask_v = ki_range < head_num ++ k_off = ki_range * stride_kv_nheads + kj_range ++ if v_dim > 0: ++ v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :] ++ v = tl.load(KV_ptr + v_off, mask=mask_v) ++ else: ++ v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty) ++ k = tl.load(KV_ptr + k_off, mask=mask_k) + +- K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads +- V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads ++ K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads ++ V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads + +- k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :] +- v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :] +- tl.store(K_ptr + k_out_off, k, mask=mask) +- tl.store(V_ptr + v_out_off, v, mask=mask) ++ k_out_off = ki_range * stride_k_nheads + kj_range ++ tl.store(K_ptr + k_out_off, k, mask=mask_k) ++ if v_dim > 0: ++ v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :] ++ tl.store(V_ptr + v_out_off, v, mask=mask_v) + + EMB = K_POS_EMB + pid_m * stride_emb_seq + # x1 = t[..., 0::2], x2 = t[..., 1::2] +@@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel( + x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + ++ x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ mask_x = x_range < head_num + x_left_off = ( +- tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads ++ x_range * stride_k_nheads + + k_dim + + tl.arange(0, emb_dim // 2)[None, :] + ) + x_right_off = x_left_off + emb_dim // 2 +- tl.store(K_ptr + x_left_off, x_left, mask=mask) +- tl.store(K_ptr + x_right_off, x_right, mask=mask) ++ tl.store(K_ptr + x_left_off, x_left, mask=mask_x) ++ tl.store(K_ptr + x_right_off, x_right, mask=mask_x) + + + @triton.autotune( +@@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel( + SIN, + emb_dim: tl.constexpr, + k_dim: tl.constexpr, ++ k_dim_ceil: tl.constexpr, + v_dim: tl.constexpr, + head_num: tl.constexpr, + batch_size, +@@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel( + else: + token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size) + +- dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads +- dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads +- mask = dkv_off < head_num * stride_dkv_nheads +- dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :] +- dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :] +- +- dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads +- dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads +- dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :] +- dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :] +- dk = tl.load(dK_ptr + dk_in_off, mask=mask) +- dv = tl.load(dV_ptr + dv_in_off, mask=mask) +- tl.store(dKV_ptr + dk_out_off, dk, mask=mask) +- tl.store(dKV_ptr + dv_out_off, dv, mask=mask) ++ dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads ++ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ kj_range = tl.arange(0, k_dim_ceil)[None, :] ++ mask_k = (ki_range < head_num) & (kj_range < k_dim) ++ mask_v = ki_range < head_num ++ dk_out_off = ki_range * stride_dkv_nheads + kj_range ++ ++ dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads ++ dV_ptr = dV + pid_m * stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads ++ dk_in_off = ki_range * stride_dk_nheads + kj_range ++ ++ dk = tl.load(dK_ptr + dk_in_off, mask=mask_k) ++ tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k) ++ ++ if v_dim > 0: ++ dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :] ++ dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :] ++ dv = tl.load(dV_ptr + dv_in_off, mask=mask_v) ++ tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v) + + if pid_head == 0: + x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) + x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) + for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)): +- dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads +- x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim ++ dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads ++ x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads + mask = x_off < head_num * stride_dk_nheads + x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :] + x_right_off = x_left_off + emb_dim // 2 +@@ -632,6 +647,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + + o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim) + o_value = kv.new_empty(total_seqlen, nheads, v_dim) ++ k_dim_ceil = triton.next_power_of_2(k_dim) + + grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) + rotary_fwd_kv_kernel[grid]( +@@ -643,6 +659,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + sin, + emb_dim, + k_dim, ++ k_dim_ceil, + v_dim, + nheads, + batch_size, +@@ -700,6 +717,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + + d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim) + d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim) ++ k_dim_ceil = triton.next_power_of_2(ctx.k_dim) + + grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) + rotary_bwd_kv_kernel[grid]( +@@ -711,6 +729,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + sin, + ctx.emb_dim, + ctx.k_dim, ++ k_dim_ceil, + ctx.v_dim, + nheads, + batch_size, +diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py +index 5d7b69cd3..2e0a26815 100644 +--- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py ++++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py +@@ -348,6 +348,7 @@ class MultimodalRotaryEmbedding(nn.Module): + + # shape (seq_length, bs, 1, 2 * dim) + emb = emb[..., None, :].transpose(0, 1).contiguous() ++ packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + if packed_seq_params is not None and packed_seq_params.local_cp_size is not None: + if packed_seq_params.local_cp_size > 1: + # Set CP group to dynamic CP group for CP slicing +@@ -357,7 +358,9 @@ class MultimodalRotaryEmbedding(nn.Module): + cp_group = None + else: + cp_group = self.cp_group +- if cp_group is not None and cp_group.size() > 1: ++ # For THD (packed sequence) format, skip CP slicing here — it is handled ++ # per-sequence inside _apply_rotary_pos_emb_thd instead (same as RotaryEmbedding). ++ if cp_group is not None and cp_group.size() > 1 and not packed_seq: + # slice rotary_pos_emb along sequence dimension and select the parition of the current + # CP rank + emb = get_pos_emb_on_this_cp_rank(emb, 0, cp_group) +diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py +index 13d74aa52..060898a7a 100644 +--- a/megatron/core/models/common/language_module/language_module.py ++++ b/megatron/core/models/common/language_module/language_module.py +@@ -184,7 +184,15 @@ class LanguageModule(MegatronModule): + assert ( + column_parallel_linear is not None + ), "column_parallel_linear cannot be None when not using fused linear cross entropy." +- logits, _ = column_parallel_linear(hidden, **col_linear_kwargs) ++ # output ++ output_layer_params = {k: v.detach() for k, v in column_parallel_linear.named_parameters()} ++ output_layer_buffers = dict(column_parallel_linear.named_buffers()) ++ logits, _ = torch.func.functional_call( ++ column_parallel_linear, ++ {**output_layer_params, **output_layer_buffers}, ++ (hidden,), ++ col_linear_kwargs, ++ ) + + return self.compute_language_model_loss(labels, logits) + +diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py +index e21127b87..712793853 100755 +--- a/megatron/core/models/gpt/gpt_layer_specs.py ++++ b/megatron/core/models/gpt/gpt_layer_specs.py +@@ -188,6 +188,8 @@ def get_gpt_layer_with_transformer_engine_spec( + use_kitchen: bool = False, + use_te_activation_func: bool = False, + fallback_to_eager_attn: bool = False, ++ post_self_attn_layernorm: bool = False, ++ post_mlp_layernorm: bool = False, + ) -> ModuleSpec: + """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). + +@@ -260,6 +262,8 @@ def get_gpt_layer_with_transformer_engine_spec( + mlp=mlp, + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + normalization=normalization, ++ post_self_attn_layernorm=post_self_attn_layernorm, ++ post_mlp_layernorm=post_mlp_layernorm, + ) + + +@@ -349,6 +353,8 @@ def get_transformer_layer_spec_for_backend( + mlp: ModuleSpec, + sharded_state_dict_keys_map: Optional[dict] = None, + normalization: Optional[str] = None, ++ post_self_attn_layernorm: bool = False, ++ post_mlp_layernorm: bool = False, + ) -> ModuleSpec: + """Helper function to get module spec for TransformerLayer""" + +@@ -371,9 +377,11 @@ def get_transformer_layer_spec_for_backend( + input_layernorm=input_layernorm, + self_attention=attention, + self_attn_bda=get_bias_dropout_add, ++ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, + pre_mlp_layernorm=pre_mlp_layernorm, + mlp=mlp, + mlp_bda=get_bias_dropout_add, ++ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + ), + ) +diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py +index a1230568c..1fd52f65a 100644 +--- a/megatron/core/models/gpt/gpt_model.py ++++ b/megatron/core/models/gpt/gpt_model.py +@@ -446,6 +446,7 @@ class GPTModel(LanguageModule): + *, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[Tensor] = None, ++ mtp_kwargs: Optional[dict] = {}, + ) -> Tensor: + """Forward function of the GPT Model This function passes the input tensors + through the embedding layer, and then the decoder and finally into the post +@@ -508,6 +509,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=runtime_gather_output, + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, ++ mtp_kwargs=mtp_kwargs, + ) + + def _postprocess( +@@ -529,6 +531,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=None, + extra_block_kwargs=None, + inference_context=None, ++ mtp_kwargs={}, + ): + """Postprocesses decoder hidden states to generate logits or compute loss. + +@@ -543,7 +546,8 @@ class GPTModel(LanguageModule): + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() +- if mtp_in_postprocess: ++ ++ if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: + hidden_states = self.mtp( + input_ids=input_ids, + position_ids=position_ids, +@@ -563,13 +567,18 @@ class GPTModel(LanguageModule): + return hidden_states + + # Skip when mtp_num_layers is None or 0 +- if self.config.mtp_num_layers: +- mtp_labels = labels.clone() ++ if self.config.mtp_num_layers and mtp_kwargs.get('mtp_labels', None) is not None: ++ mtp_labels = mtp_kwargs['mtp_labels'].clone() ++ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) ++ + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) + hidden_states = hidden_states_list[0] + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(mtp_labels) ++ else: ++ # Otherwise, roll the loss_mask to keep up with the mtp_labels ++ loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) + for mtp_layer_number in range(self.config.mtp_num_layers): + # Calc loss for the current Multi-Token Prediction (MTP) layers. + mtp_labels, _ = roll_tensor( +@@ -595,7 +604,7 @@ class GPTModel(LanguageModule): + sequence_parallel_enabled=self.output_layer.sequence_parallel, + column_parallel_linear=self.output_layer, + col_linear_kwargs={ +- 'weight': output_weight, ++ 'weight': output_weight.detach() if output_weight else None, + 'runtime_gather_output': runtime_gather_output, + }, + ) +diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py +index 6e093f96f..eac21a3ea 100644 +--- a/megatron/core/optimizer/distrib_optimizer.py ++++ b/megatron/core/optimizer/distrib_optimizer.py +@@ -677,6 +677,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + # TE FusedAdam will not accumulate step for empty param groups, so we need to + # align the step across param groups. + param_group["step"] = int(step) ++ if "step" in param_group and param_group["step"] is None: ++ del param_group["step"] + + # Grad scaler state. + if self.grad_scaler: +@@ -1646,6 +1648,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + if key == 'padding': + tensors[key] = LocalNonpersistentObject(tensors[key]) + continue ++ if key == 'step': ++ continue + assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), ( + tensors[key].shape, + gbuf_local_start, +diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py +index a273002b9..4f821cfd5 100644 +--- a/megatron/core/parallel_state.py ++++ b/megatron/core/parallel_state.py +@@ -11,6 +11,7 @@ from typing import Callable, List, Optional + + import numpy as np + import torch ++import torch.distributed as dist + + from .utils import GlobalMemoryBuffer, is_torch_min_version + +diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py +index ac839c21f..f18309217 100644 +--- a/megatron/core/pipeline_parallel/p2p_communication.py ++++ b/megatron/core/pipeline_parallel/p2p_communication.py +@@ -26,22 +26,22 @@ def _batched_p2p_ops( + ops = [] + if tensor_send_prev is not None: + send_prev_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group ++ torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, + ) + ops.append(send_prev_op) + if tensor_recv_prev is not None: + recv_prev_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, + ) + ops.append(recv_prev_op) + if tensor_send_next is not None: + send_next_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_next, next_pipeline_rank, group ++ torch.distributed.isend, tensor_send_next, next_pipeline_rank, + ) + ops.append(send_next_op) + if tensor_recv_next is not None: + recv_next_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, + ) + ops.append(recv_next_op) + if len(ops) > 0: +diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py +index 28cff06f5..58dc4bb70 100644 +--- a/megatron/core/transformer/moe/moe_utils.py ++++ b/megatron/core/transformer/moe/moe_utils.py +@@ -587,6 +587,9 @@ def topk_routing_with_score_function( + else: + return torch.topk(scores, k=topk, dim=1) + ++ from relax.utils.training.routing_replay import get_routing_replay_compute_topk ++ compute_topk = get_routing_replay_compute_topk(compute_topk) ++ + if score_function == "softmax": + if use_pre_softmax: + scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) +diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py +index 16fc9d9af..517944f25 100644 +--- a/megatron/core/transformer/moe/router.py ++++ b/megatron/core/transformer/moe/router.py +@@ -201,6 +201,9 @@ class TopKRouter(Router): + self.global_tokens_per_expert = None + self.ga_steps = None + ++ from relax.utils.training.routing_replay import register_routing_replay ++ register_routing_replay(self) ++ + def _maintain_float32_expert_bias(self): + """ + Maintain the expert bias in float32. +diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py +index a8f4abfcd..f33f6f05e 100755 +--- a/megatron/core/transformer/multi_token_prediction.py ++++ b/megatron/core/transformer/multi_token_prediction.py +@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union + + import torch + from torch import Tensor ++import warnings + + from megatron.core import InferenceParams, parallel_state, tensor_parallel + from megatron.core.dist_checkpointing.mapping import ShardedStateDict +@@ -714,17 +715,19 @@ class MultiTokenPredictionLayer(MegatronModule): + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) +- position_ids, _ = roll_tensor( +- position_ids, +- shifts=-1, +- dims=-1, +- cp_group=self.cp_group, +- packed_seq_params=packed_seq_params, +- ) ++ if position_ids is not None: ++ position_ids, _ = roll_tensor( ++ position_ids, ++ shifts=-1, ++ dims=-1, ++ cp_group=self.cp_group, ++ packed_seq_params=packed_seq_params, ++ ) + # embedding + decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) ++ decoder_input = decoder_input.detach() + +- hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) ++ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) + + return input_ids, position_ids, decoder_input, hidden_states + +@@ -826,6 +829,51 @@ class MultiTokenPredictionLayer(MegatronModule): + return hidden_states + + def _checkpointed_forward(self, forward_func, *args, **kwargs): ++ """Wrap `forward_func` with activation checkpointing while only passing tensors. ++ ++ Non-tensor arguments (e.g., configuration objects, None) are captured via closure so ++ that checkpoint implementations never receive them directly, avoiding save_for_backward ++ issues with non-tensor inputs. ++ """ ++ ++ # TODO(jiajun): Is there any better implementation here? ++ positional_specs = [] ++ kw_specs = [] ++ tensor_args: List[torch.Tensor] = [] ++ ++ for arg in args: ++ if torch.is_tensor(arg): ++ positional_specs.append(('tensor', len(tensor_args))) ++ tensor_args.append(arg) ++ else: ++ positional_specs.append(('const', arg)) ++ ++ for key, value in kwargs.items(): ++ if torch.is_tensor(value): ++ kw_specs.append((key, ('tensor', len(tensor_args)))) ++ tensor_args.append(value) ++ else: ++ kw_specs.append((key, ('const', value))) ++ ++ def run(*flat_tensor_args): ++ rebuilt_args = [] ++ for spec_type, payload in positional_specs: ++ if spec_type == 'tensor': ++ rebuilt_args.append(flat_tensor_args[payload]) ++ else: ++ rebuilt_args.append(payload) ++ ++ rebuilt_kwargs = {} ++ for key, (spec_type, payload) in kw_specs: ++ if spec_type == 'tensor': ++ rebuilt_kwargs[key] = flat_tensor_args[payload] ++ else: ++ rebuilt_kwargs[key] = payload ++ ++ return forward_func(*rebuilt_args, **rebuilt_kwargs) ++ ++ tensor_args_tuple = tuple(tensor_args) ++ + def checkpoint_handler(): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + if self.config.fp8: +@@ -836,12 +884,11 @@ class MultiTokenPredictionLayer(MegatronModule): + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), +- *args, +- **kwargs, ++ *tensor_args_tuple, + ) + else: + return tensor_parallel.checkpoint( +- forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() ++ run, self.config.distribute_saved_activations, *tensor_args_tuple + ) + + if self.config.recompute_method == 'uniform': +diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py +index e2705bd9f..a0aa109b5 100644 +--- a/megatron/core/transformer/transformer_config.py ++++ b/megatron/core/transformer/transformer_config.py +@@ -210,6 +210,9 @@ class TransformerConfig(ModelParallelConfig): + attention_output_gate: bool = False + """Whether to apply output gate to the attention layers.""" + ++ post_self_attn_layernorm: bool = False ++ post_mlp_layernorm: bool = False ++ + test_mode: bool = False + """Whether to run real-time tests.""" + +diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py +index 3ea405770..5a42001b9 100644 +--- a/megatron/core/transformer/transformer_layer.py ++++ b/megatron/core/transformer/transformer_layer.py +@@ -223,6 +223,7 @@ class TransformerLayerSubmodules: + input_layernorm: Union[ModuleSpec, type] = IdentityOp + self_attention: Union[ModuleSpec, type] = IdentityOp + self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + + pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + cross_attention: Union[ModuleSpec, type] = IdentityOp +@@ -231,6 +232,7 @@ class TransformerLayerSubmodules: + pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + mlp: Union[ModuleSpec, type] = IdentityOp + mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + + # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method + sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) +@@ -310,6 +312,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + # [Module 3: BiasDropoutFusion] + self.self_attn_bda = build_module(submodules.self_attn_bda) + ++ self.post_self_attn_layernorm = build_module( ++ submodules.post_self_attn_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon, ++ ) ++ + # [Module 4: Post SelfAttention] Optional Layernorm after self-attn + self.pre_cross_attn_layernorm = build_module( + submodules.pre_cross_attn_layernorm, +@@ -375,6 +384,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + + self.is_moe_layer = isinstance(self.mlp, MoELayer) + ++ self.post_mlp_layernorm = build_module( ++ submodules.post_mlp_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon ++ ) ++ + self.recompute_input_layernorm = False + self.recompute_pre_mlp_layernorm = False + self.recompute_mlp = False +@@ -551,6 +567,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + attention_output_with_bias[0] + ) + ++ attention_output, attention_output_bias = attention_output_with_bias ++ attention_output = self.post_self_attn_layernorm(attention_output) ++ attention_output_with_bias = (attention_output, attention_output_bias) ++ + # TODO: could we move `bias_dropout_add_exec_handler` itself + # inside the module provided in the `bias_dropout_add_spec` module? + nvtx_range_push(suffix="self_attn_bda") +@@ -677,6 +697,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + else: + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) + ++ mlp_output, mlp_output_bias = mlp_output_with_bias ++ mlp_output = self.post_mlp_layernorm(mlp_output) ++ mlp_output_with_bias = (mlp_output, mlp_output_bias) ++ + if self.recompute_pre_mlp_layernorm: + # discard the output of the pre-mlp layernorm and register the recompute + # as a gradient hook of mlp_output_with_bias[0] +diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py +index b267c8a81..83736acdc 100644 +--- a/megatron/training/arguments.py ++++ b/megatron/training/arguments.py +@@ -1398,6 +1398,9 @@ def core_transformer_config_from_args(args, config_class=None): + + kw_args['inference_sampling_seed'] = args.seed + ++ kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm ++ kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm ++ + # handle quantization config + # NOTE: Kitchen arguments are only added to the namespace when + # Kitchen library is available. +@@ -1764,6 +1767,12 @@ def _add_network_size_args(parser): + action='store_true', + help='If set, use original BERT residula connection ' + 'ordering.') ++ group.add_argument('--post-self-attn-layernorm', action='store_true', ++ help='If set, use post self attention layernorm.') ++ group.add_argument('--post-mlp-layernorm', action='store_true', ++ help='If set, use post MLP layernorm.') ++ group.add_argument('--use-gated-attention', action='store_true', ++ help='If set, use gated attention as in Qwen3Next') + group.add_argument('--openai-gelu', action='store_true', + help='Use OpenAIs GeLU implementation. This option' + 'should not be used unless for backward compatibility' +diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py +index 13b7526ca..6c590f653 100644 +--- a/megatron/training/tokenizer/tokenizer.py ++++ b/megatron/training/tokenizer/tokenizer.py +@@ -136,7 +136,7 @@ class _HuggingFaceTokenizer(MegatronLegacyTokenizer): + # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there + self._tokenizer = transformers.AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, +- trust_remote_code=trust_remote_code, ++ trust_remote_code=True, + **kwargs, + ) + self._vocab = self._tokenizer.get_vocab() +diff --git a/megatron/core/ssm/gated_delta_net.py b/megatron/core/ssm/gated_delta_net.py +index dfa6e4c35..0b38f1135 100644 +--- a/megatron/core/ssm/gated_delta_net.py ++++ b/megatron/core/ssm/gated_delta_net.py +@@ -21,6 +21,12 @@ from megatron.core.inference.contexts import BaseInferenceContext + from megatron.core.jit import jit_fuser + from megatron.core.packed_seq_params import PackedSeqParams + from megatron.core.process_groups_config import ProcessGroupCollection ++from megatron.core.ssm.mamba_context_parallel import ( ++ _all_to_all_cp2hp, ++ _all_to_all_hp2cp, ++ _redo_attention_load_balancing, ++ _undo_attention_load_balancing, ++) + from megatron.core.tensor_parallel import get_cuda_rng_tracker + from megatron.core.transformer import TransformerConfig + from megatron.core.transformer.identity_op import IdentityOp +@@ -33,24 +39,19 @@ from megatron.core.transformer.utils import ( + ) + from megatron.core.utils import deprecate_inference_params, nvtx_range_pop, nvtx_range_push + +-# TODO: Implement GatedDeltaNetContextParallel +-# from .gated_delta_net_context_parallel import GatedDeltaNetContextParallel +- + try: ++ from fla.modules.convolution import causal_conv1d + from fla.modules.l2norm import l2norm + from fla.ops.gated_delta_rule import chunk_gated_delta_rule + + HAVE_FLA = True + except ImportError: ++ causal_conv1d = None ++ l2norm = None + chunk_gated_delta_rule = None + + HAVE_FLA = False + +-try: +- from causal_conv1d import causal_conv1d_fn +-except ImportError: +- causal_conv1d_fn = None +- + + logger = logging.getLogger(__name__) + +@@ -84,6 +85,7 @@ class GatedDeltaNet(MegatronModule): + use_qk_l2norm: bool = True, + A_init_range: Tuple[float, float] = (1, 16), + pg_collection: ProcessGroupCollection = None, ++ **kwargs, + ): + """ + Args: +@@ -98,9 +100,11 @@ class GatedDeltaNet(MegatronModule): + pg_collection: The required process groups to use for tensor model parallel and context + parallel. + """ +- ++ # print(f"new gdn", flush=True) + if not HAVE_FLA: +- raise ImportError("FLA is not installed. Please install it with `pip install fla`.") ++ raise ImportError( ++ "FLA is not installed. Please install it with `pip install flash-linear-attention`." ++ ) + + super().__init__(config) + +@@ -114,6 +118,7 @@ class GatedDeltaNet(MegatronModule): + self.use_qk_l2norm = use_qk_l2norm + assert pg_collection is not None, "pg_collection must be provided for GatedDeltaNet" + self.pg_collection = pg_collection ++ self.cp_size = self.pg_collection.cp.size() + self.tp_size = self.pg_collection.tp.size() + self.sp_size = self.tp_size if config.sequence_parallel else 1 + +@@ -129,6 +134,8 @@ class GatedDeltaNet(MegatronModule): + self.num_value_heads = config.linear_num_value_heads + self.qk_dim = self.key_head_dim * self.num_key_heads + self.v_dim = self.value_head_dim * self.num_value_heads ++ self.qk_dim_local_tp = self.qk_dim // self.tp_size ++ self.v_dim_local_tp = self.v_dim // self.tp_size + + # Input projection (hidden_states -> q, k, v, gate, beta, alpha) + # TODO: for now, output gate is forced for GDN. +@@ -171,8 +178,10 @@ class GatedDeltaNet(MegatronModule): + dtype=config.params_dtype, + ) + setattr(self.conv1d.weight, "tensor_model_parallel", True) ++ setattr(self.conv1d.weight, "partition_dim", 0) + if conv_bias: + setattr(self.conv1d.bias, "tensor_model_parallel", True) ++ setattr(self.conv1d.bias, "partition_dim", 0) + + # Time step projection (discretization) + self.num_v_heads_local_tp = self.num_value_heads // self.tp_size +@@ -185,6 +194,7 @@ class GatedDeltaNet(MegatronModule): + ) + ) + setattr(self.dt_bias, "tensor_model_parallel", True) ++ setattr(self.dt_bias, "partition_dim", 0) + # A_log parameter + self.A_log = nn.Parameter( + torch.empty( +@@ -194,6 +204,12 @@ class GatedDeltaNet(MegatronModule): + ) + ) + setattr(self.A_log, "tensor_model_parallel", True) ++ setattr(self.A_log, "partition_dim", 0) ++ ++ if self.config.deterministic_mode: ++ self.gated_delta_rule = torch_chunk_gated_delta_rule ++ else: ++ self.gated_delta_rule = chunk_gated_delta_rule + + # Output layernorm before projection + self.out_norm = build_module( +@@ -217,8 +233,6 @@ class GatedDeltaNet(MegatronModule): + tp_group=self.pg_collection.tp, + ) + +- # TODO: support CP +- + self.reset_parameters() + + def reset_parameters(self): +@@ -241,23 +255,18 @@ class GatedDeltaNet(MegatronModule): + dtype=self.config.params_dtype, + device=torch.cuda.current_device(), + ).uniform_(*self.A_init_range) +- self.A_log.data.copy_(A) ++ self.A_log.data.copy_(torch.log(A)) + + def forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, +- key_value_states: Optional[Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, +- rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None, +- rotary_pos_cos: Optional[Tensor] = None, +- rotary_pos_sin: Optional[Tensor] = None, +- rotary_pos_cos_sin: Optional[Tensor] = None, +- attention_bias: Optional[Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[int] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, ++ **kwargs, + ): + """ + Perform a forward pass through the GDN module. +@@ -265,15 +274,8 @@ class GatedDeltaNet(MegatronModule): + Args: + hidden_states (Tensor): Hidden states. + attention_mask (Tensor): Attention mask. +- key_value_states (Optional[Tensor]): Key/value states (for cross attention). + inference_context (Optional[BaseInferenceContext]): Inference context that manages + KV cache. +- rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]): Rotary +- embedding tensor(s). +- rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine. +- rotary_pos_sin (Optional[Tensor]): Rotary embedding sine. +- rotary_pos_cos_sin (Optional[Tensor]): Combined rotary embedding cosine and sine. +- attention_bias (Optional[Tensor]): Attention bias. + packed_seq_params (Optional[PackedSeqparams]): Parameters used for THD format. + sequence_len_offset (Optional[int]): Sequence length offset used for + inference CUDA graphs. +@@ -287,7 +289,7 @@ class GatedDeltaNet(MegatronModule): + inference_context = deprecate_inference_params(inference_context, inference_params) + + seq_len, batch, _ = hidden_states.shape +- seq_len = seq_len * self.sp_size ++ seq_len = seq_len * self.sp_size * self.cp_size + + if inference_context is not None: + assert ( +@@ -297,15 +299,80 @@ class GatedDeltaNet(MegatronModule): + # TODO: support inference + raise NotImplementedError("GDN does not support inference for now.") + +- if packed_seq_params is not None: +- # TODO: support packed sequence +- raise NotImplementedError("GDN does not support packed sequence for now.") ++ if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': ++ assert batch == 1, "Packed sequence expects batch dimension to be 1" ++ assert ( ++ not self.config.deterministic_mode ++ ), "Packed sequence does not support deterministic mode." ++ ++ # Resolve cu_seqlens with alignment padding handling. ++ cu_seqlens_q = self._resolve_cu_seqlens( ++ packed_seq_params.cu_seqlens_q_padded, ++ packed_seq_params.cu_seqlens_q, ++ seq_len, ++ "cu_seqlens_q", ++ ) ++ cu_seqlens_kv = self._resolve_cu_seqlens( ++ packed_seq_params.cu_seqlens_kv_padded, ++ packed_seq_params.cu_seqlens_kv, ++ seq_len, ++ "cu_seqlens_kv", ++ ) ++ assert torch.equal(cu_seqlens_q, cu_seqlens_kv), ( ++ "Currently only support cu_seqlens_q equals to cu_seqlens_kv, " ++ f"but got {cu_seqlens_q=} and {cu_seqlens_kv=}" ++ ) ++ num_packed_seqs = cu_seqlens_q.shape[0] - 1 ++ assert num_packed_seqs > 0, ( ++ "Number of packed sequences must be greater than 0, " ++ f"but got {cu_seqlens_q=} and {cu_seqlens_kv=}" ++ ) ++ else: ++ cu_seqlens_q = None ++ cu_seqlens_kv = None + + # Input projection + nvtx_range_push(suffix="in_proj") + qkvzba, _ = self.in_proj(hidden_states) + nvtx_range_pop(suffix="in_proj") + ++ # CP All to All: CP to HP ++ if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': ++ unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens_q // self.cp_size, dim=0) ++ outputs = [] ++ for qkvzba_i in unpacked_qkvzba: ++ qkvzba_i = tensor_a2a_cp2hp( ++ qkvzba_i, ++ seq_dim=0, ++ head_dim=-1, ++ cp_group=self.pg_collection.cp, ++ split_sections=[ ++ self.qk_dim_local_tp, ++ self.qk_dim_local_tp, ++ self.v_dim_local_tp, ++ self.v_dim_local_tp, ++ self.num_value_heads // self.tp_size, ++ self.num_value_heads // self.tp_size, ++ ], ++ ) ++ outputs.append(qkvzba_i) ++ qkvzba = torch.cat(outputs, dim=0) ++ else: ++ qkvzba = tensor_a2a_cp2hp( ++ qkvzba, ++ seq_dim=0, ++ head_dim=-1, ++ cp_group=self.pg_collection.cp, ++ split_sections=[ ++ self.qk_dim_local_tp, ++ self.qk_dim_local_tp, ++ self.v_dim_local_tp, ++ self.v_dim_local_tp, ++ self.num_value_heads // self.tp_size, ++ self.num_value_heads // self.tp_size, ++ ], ++ ) ++ + # Transpose: s b x --> b s x + # From sbhd to bshd format + qkvzba = qkvzba.transpose(0, 1) +@@ -314,10 +381,10 @@ class GatedDeltaNet(MegatronModule): + qkv, gate, beta, alpha = torch.split( + qkvzba, + [ +- (self.qk_dim * 2 + self.v_dim) // self.tp_size, +- self.v_dim // self.tp_size, +- self.num_value_heads // self.tp_size, +- self.num_value_heads // self.tp_size, ++ (self.qk_dim_local_tp * 2 + self.v_dim_local_tp) // self.cp_size, ++ self.v_dim_local_tp // self.cp_size, ++ self.num_value_heads // self.tp_size // self.cp_size, ++ self.num_value_heads // self.tp_size // self.cp_size, + ], + dim=-1, + ) +@@ -326,74 +393,83 @@ class GatedDeltaNet(MegatronModule): + alpha = alpha.reshape(batch, seq_len, -1) + + # Convolution on qkv +- qkv = qkv.transpose(1, 2).contiguous() # b, s, d -> b, d, s + nvtx_range_push(suffix="conv1d") +- if (causal_conv1d_fn is None) or self.config.deterministic_mode: +- qkv = self.act_fn(self.conv1d(qkv)[..., :seq_len]) ++ seq_len = qkv.shape[1] ++ qkv_channels_split_sections = [ ++ self.qk_dim_local_tp, ++ self.qk_dim_local_tp, ++ self.v_dim_local_tp, ++ ] ++ conv1d_weight = get_parameter_local_cp( ++ self.conv1d.weight, ++ dim=0, ++ cp_group=self.pg_collection.cp, ++ split_sections=qkv_channels_split_sections, ++ ) ++ conv1d_bias = ( ++ get_parameter_local_cp( ++ self.conv1d.bias, ++ dim=0, ++ cp_group=self.pg_collection.cp, ++ split_sections=qkv_channels_split_sections, ++ ) ++ if self.conv_bias ++ else None ++ ) ++ if self.config.deterministic_mode: ++ qkv = qkv.transpose(1, 2).contiguous() # b, s, d -> b, d, s ++ conv_out = F.conv1d( ++ input=qkv, # Torch-native only accept [b, d, s] format input ++ weight=conv1d_weight, ++ bias=conv1d_bias, ++ stride=self.conv1d.stride, ++ padding=self.conv1d.padding, ++ dilation=self.conv1d.dilation, ++ groups=self.conv_dim_local_tp // self.cp_size, ++ ) ++ qkv = self.act_fn(conv_out[..., :seq_len]) ++ qkv = qkv.transpose(1, 2) # b, d, s -> b, s, d + else: + assert self.activation in ["silu", "swish"] +- qkv = causal_conv1d_fn( +- x=qkv, +- weight=self.conv1d.weight.squeeze(1), # d, 1, w -> d, w +- bias=self.conv1d.bias, ++ qkv, _ = causal_conv1d( ++ x=qkv, # FLA conv1d accepts [b, s, d] format input ++ weight=conv1d_weight.squeeze(1), # d, 1, w -> d, w ++ bias=conv1d_bias, + activation=self.activation, ++ initial_state=None, ++ output_final_state=False, ++ cu_seqlens=cu_seqlens_q, + ) + nvtx_range_pop(suffix="conv1d") +- # Split qkv into query, key, and value +- qkv = qkv.transpose(1, 2) # b, d, s -> b, s, d +- query, key, value = torch.split( +- qkv, +- [self.qk_dim // self.tp_size, self.qk_dim // self.tp_size, self.v_dim // self.tp_size], +- dim=-1, +- ) +- query = query.reshape(batch, seq_len, -1, self.key_head_dim) +- key = key.reshape(batch, seq_len, -1, self.key_head_dim) +- value = value.reshape(batch, seq_len, -1, self.value_head_dim) +- # Apply L2 norm to query and key +- if self.use_qk_l2norm: +- query = l2norm(query.contiguous()) +- key = l2norm(key.contiguous()) +- if self.num_value_heads // self.num_key_heads > 1: +- query = query.repeat_interleave(self.num_value_heads // self.num_key_heads, dim=2) +- key = key.repeat_interleave(self.num_value_heads // self.num_key_heads, dim=2) + +- # Make contiguous +- query = query.contiguous() +- key = key.contiguous() +- value = value.contiguous() +- gate = gate.contiguous() +- beta = beta.contiguous() +- alpha = alpha.contiguous() ++ # Prepare QKV tensors (split, reshape, L2 norm, repeat_interleave, contiguous) ++ nvtx_range_push(suffix="prepare_qkv_for_gated_delta_rule") ++ query, key, value, gate, beta, alpha = self._prepare_qkv_for_gated_delta_rule( ++ qkv, gate, beta, alpha, batch, seq_len ++ ) ++ nvtx_range_pop(suffix="prepare_qkv_for_gated_delta_rule") + + # Calculate g and beta + nvtx_range_push(suffix="g_and_beta") +- g = -self.A_log.exp() * F.softplus(alpha.float() + self.dt_bias) # In fp32 +- beta = beta.sigmoid() ++ A_log_local_cp = get_parameter_local_cp(self.A_log, dim=0, cp_group=self.pg_collection.cp) ++ dt_bias_local_cp = get_parameter_local_cp( ++ self.dt_bias, dim=0, cp_group=self.pg_collection.cp ++ ) ++ g, beta = self._compute_g_and_beta(A_log_local_cp, dt_bias_local_cp, alpha, beta) + nvtx_range_pop(suffix="g_and_beta") + + nvtx_range_push(suffix="gated_delta_rule") +- if self.config.deterministic_mode: +- core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule( +- query, +- key, +- value, +- g=g, +- beta=beta, +- initial_state=None, +- output_final_state=False, +- use_qk_l2norm_in_kernel=False, +- ) +- else: +- core_attn_out, last_recurrent_state = chunk_gated_delta_rule( +- query, +- key, +- value, +- g=g, +- beta=beta, +- initial_state=None, +- output_final_state=False, +- use_qk_l2norm_in_kernel=False, +- ) ++ core_attn_out, last_recurrent_state = self.gated_delta_rule( ++ query, ++ key, ++ value, ++ g=g, ++ beta=beta, ++ initial_state=None, ++ output_final_state=False, ++ use_qk_l2norm_in_kernel=False, ++ cu_seqlens=cu_seqlens_q, ++ ) + nvtx_range_pop(suffix="gated_delta_rule") + + # RMSNorm +@@ -406,6 +482,21 @@ class GatedDeltaNet(MegatronModule): + norm_out = norm_out.reshape(batch, seq_len, -1) + norm_out = norm_out.transpose(0, 1).contiguous() + ++ # CP all to all: HP to CP ++ if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': ++ unpacked_norm_out = _unpack_sequence(norm_out, cu_seqlens_q, dim=0) ++ outputs = [] ++ for norm_out_i in unpacked_norm_out: ++ norm_out_i = tensor_a2a_hp2cp( ++ norm_out_i, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp ++ ) ++ outputs.append(norm_out_i) ++ norm_out = torch.cat(outputs, dim=0) ++ else: ++ norm_out = tensor_a2a_hp2cp( ++ norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp ++ ) ++ + # Output projection + nvtx_range_push(suffix="out_proj") + out, out_bias = self.out_proj(norm_out) +@@ -425,6 +516,74 @@ class GatedDeltaNet(MegatronModule): + y = y.to(x_dtype) + return y + ++ @jit_fuser ++ def _prepare_qkv_for_gated_delta_rule(self, qkv, gate, beta, alpha, batch, seq_len): ++ """ ++ Prepare query, key, value, gate, beta, alpha tensors for gated delta rule. ++ Fuses split, reshape, L2 norm, repeat_interleave, and contiguous operations. ++ """ ++ # Split qkv into query_key and value ++ query_key, value = torch.split( ++ qkv, ++ [2 * self.qk_dim_local_tp // self.cp_size, self.v_dim_local_tp // self.cp_size], ++ dim=-1, ++ ) ++ ++ # Reshape query_key and value ++ query_key = query_key.reshape(batch, seq_len, -1, self.key_head_dim) ++ value = value.reshape(batch, seq_len, -1, self.value_head_dim) ++ ++ # Apply L2 norm to query and key ++ if self.use_qk_l2norm: ++ query_key = l2norm(query_key.contiguous()) ++ ++ # Split query and key ++ split_size = self.qk_dim_local_tp // self.key_head_dim // self.cp_size ++ query, key = torch.split(query_key, [split_size, split_size], dim=2) ++ ++ # Expand query and key if needed (grouped query attention) ++ if self.num_value_heads // self.num_key_heads > 1: ++ repeat_factor = self.num_value_heads // self.num_key_heads ++ query = query.repeat_interleave(repeat_factor, dim=2) ++ key = key.repeat_interleave(repeat_factor, dim=2) ++ ++ # Make all tensors contiguous ++ query = query.contiguous() ++ key = key.contiguous() ++ value = value.contiguous() ++ gate = gate.contiguous() ++ beta = beta.contiguous() ++ alpha = alpha.contiguous() ++ ++ return query, key, value, gate, beta, alpha ++ ++ @jit_fuser ++ def _compute_g_and_beta(self, A_log_local_cp, dt_bias_local_cp, alpha, beta): ++ """ ++ Compute g (decay) and beta (sigmoid) for gated delta rule. ++ Fuses exp, softplus, mul, neg, and sigmoid operations. ++ """ ++ g = -A_log_local_cp.exp() * F.softplus(alpha.float() + dt_bias_local_cp) # In fp32 ++ beta = beta.sigmoid() ++ return g, beta ++ ++ def _resolve_cu_seqlens(self, cu_seqlens_padded, cu_seqlens_actual, total_seq_len, name): ++ """Resolve cu_seqlens for packed sequence all-to-all, handling alignment padding.""" ++ if cu_seqlens_padded is not None: ++ cu_seqlens = cu_seqlens_padded ++ else: ++ cu_seqlens = cu_seqlens_actual ++ ++ total_cu = cu_seqlens[-1].item() ++ if total_cu != total_seq_len: ++ raise ValueError( ++ f"GDN: {name}[-1]={total_cu} does not match " ++ f"total_sequence_length={total_seq_len}. " ++ f"({cu_seqlens_padded=}, {cu_seqlens_actual=})." ++ ) ++ ++ return cu_seqlens ++ + def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None, tp_group=None): + """Provide a sharded state dictionary for distributed checkpointing.""" + # Guard for cases metadata is not provided +@@ -479,10 +638,10 @@ class GatedDeltaNet(MegatronModule): + sharded_state_dict[f"{prefix}in_proj.weight"] = _split_tensor_factory( + sharded_state_dict[f"{prefix}in_proj.weight"], + [ +- self.qk_dim // self.tp_size, +- self.qk_dim // self.tp_size, +- self.v_dim // self.tp_size, +- self.v_dim // self.tp_size, ++ self.qk_dim_local_tp, ++ self.qk_dim_local_tp, ++ self.v_dim_local_tp, ++ self.v_dim_local_tp, + self.num_value_heads // self.tp_size, + self.num_value_heads // self.tp_size, + ], +@@ -502,18 +661,41 @@ class GatedDeltaNet(MegatronModule): + for conv_layer_name in conv_layer_name_list: + sharded_state_dict[f"{prefix}{conv_layer_name}"] = _split_tensor_factory( + sharded_state_dict[f"{prefix}{conv_layer_name}"], +- [ +- self.qk_dim // self.tp_size, +- self.qk_dim // self.tp_size, +- self.v_dim // self.tp_size, +- ], ++ [self.qk_dim_local_tp, self.qk_dim_local_tp, self.v_dim_local_tp], + ["query", "key", "value"], + 0, + ) + + return sharded_state_dict + ++ def backward_dw(self): ++ """Execute weight gradient computation for all linear layers.""" ++ self._backward_in_proj() ++ self._backward_out_proj() ++ ++ def _backward_in_proj(self): ++ """Computes weight gradients of input projection layer.""" ++ self.in_proj.backward_dw() ++ ++ def _backward_out_proj(self): ++ """Computes weight gradients of output projection layer.""" ++ self.out_proj.backward_dw() ++ ++ ++def _unpack_sequence(x, cu_seqlens, dim=1): ++ unpacked_x = [] ++ num_seqs = cu_seqlens.shape[0] - 1 ++ for i in range(num_seqs): ++ idx_start = cu_seqlens[i].item() ++ idx_end = cu_seqlens[i + 1].item() ++ chunked_index = [slice(None)] * dim + [slice(idx_start, idx_end)] ++ unpacked_x.append(x[tuple(chunked_index)]) ++ return unpacked_x ++ + ++#################### ++# Sharded state dict utilities ++#################### + def _split_tensor_factory( + orig_sh_ten: ShardedTensor, split_sections: List[int], split_names: List[str], split_dim: int + ) -> ShardedTensorFactory: +@@ -574,6 +756,184 @@ def _split_tensor_factory( + ) + + ++#################### ++# Context parallel utilities ++#################### ++def get_parameter_local_cp( ++ param: torch.Tensor, ++ dim: int, ++ cp_group: torch.distributed.ProcessGroup, ++ split_sections: Optional[List[int]] = None, ++) -> torch.Tensor: ++ """Get the local parameter for the current context parallel rank. ++ ++ Args: ++ param (torch.Tensor): The entire parameter to get the local parameter for. ++ dim (int): The dimension to split the parameter along. Usually the dimension of head. ++ cp_group (torch.distributed.ProcessGroup): The context parallel group. ++ split_sections (Optional[List[int]]): If not None, ++ first split the parameter along the dimension dim into sections, ++ then get the local hidden parallel weights separately, ++ finally concatenate the local hidden parallel weights along the dimension dim. ++ ++ Returns: ++ torch.Tensor: The local parameter for the current context parallel rank. ++ """ ++ ++ cp_size = cp_group.size() ++ cp_rank = cp_group.rank() ++ ++ # No need to split if CP size is 1. ++ if cp_size == 1: ++ return param ++ ++ # Split first if needed. ++ if split_sections is not None: ++ inputs = torch.split(param, split_sections, dim=dim) ++ outputs = [] ++ for p in inputs: ++ p = get_parameter_local_cp(p, dim, cp_group) ++ outputs.append(p) ++ return torch.cat(outputs, dim=dim) ++ ++ # Slice the parameter. ++ slices = [slice(None)] * param.dim() ++ dim_size = param.size(dim=dim) ++ slices[dim] = slice(cp_rank * dim_size // cp_size, (cp_rank + 1) * dim_size // cp_size) ++ param = param[slices] ++ return param ++ ++ ++def tensor_a2a_cp2hp( ++ tensor: torch.Tensor, ++ seq_dim: int, ++ head_dim: int, ++ cp_group: torch.distributed.ProcessGroup, ++ split_sections: Optional[List[int]] = None, ++ undo_attention_load_balancing: bool = True, ++): ++ """All-to-all context parallel to hidden parallel. ++ ++ Args: ++ tensor (torch.Tensor): The tensor to all-to-all. ++ Currently only support (seq_len, batch, head_dim) shaped tensor. ++ seq_dim (int): The dimension of sequence length. Currently only supports seq_dim == 0. ++ head_dim (int): The dimension of head. Currently only supports head_dim == -1 or 2. ++ cp_group (torch.distributed.ProcessGroup): The context parallel group. ++ split_sections (Optional[List[int]]): If not None, split the tensor along the dimension ++ head_dim into sections first, then do all-to-all for each section separately, ++ finally concatenate the separated tensors along the dimension head_dim. ++ undo_attention_load_balancing (bool): Whether to undo the attention load balancing of CP. ++ ++ Returns: ++ torch.Tensor: The all-to-all tensor. ++ """ ++ ++ cp_size = cp_group.size() ++ ++ # No need to all-to-all if CP size is 1. ++ if cp_size == 1: ++ return tensor ++ ++ # Limitations of mamba_context_parallel._all_to_all_cp2hp. ++ assert seq_dim == 0, f"tensor_a2a_cp2hp only supports seq_dim == 0 for now, but got {seq_dim=}" ++ assert ( ++ head_dim == -1 or head_dim == 2 ++ ), f"tensor_a2a_cp2hp only supports head_dim == -1 or 2 for now, but got {head_dim=}" ++ assert ( ++ tensor.dim() == 3 ++ ), f"tensor_a2a_cp2hp only supports 3-d input tensor for now, but got {tensor.dim()=}" ++ ++ # Split first if needed. ++ if split_sections is not None: ++ inputs = torch.split(tensor, split_sections, dim=head_dim) ++ outputs = [] ++ for x in inputs: ++ x = tensor_a2a_cp2hp( ++ x, ++ seq_dim=seq_dim, ++ head_dim=head_dim, ++ cp_group=cp_group, ++ undo_attention_load_balancing=False, ++ ) ++ outputs.append(x) ++ tensor = torch.cat(outputs, dim=head_dim) ++ else: ++ tensor = _all_to_all_cp2hp(tensor, cp_group) ++ ++ # Undo attention load balancing last if needed. ++ if undo_attention_load_balancing: ++ tensor = _undo_attention_load_balancing(tensor, cp_size) ++ return tensor ++ ++ ++def tensor_a2a_hp2cp( ++ tensor: torch.Tensor, ++ seq_dim: int, ++ head_dim: int, ++ cp_group: torch.distributed.ProcessGroup, ++ split_sections: Optional[List[int]] = None, ++ redo_attention_load_balancing: bool = True, ++): ++ """All-to-all hidden parallel to context parallel. ++ ++ Args: ++ tensor (torch.Tensor): The tensor to all-to-all. ++ Currently only support (seq_len, batch, head_dim) shaped tensor. ++ seq_dim (int): The dimension of sequence length. Currently only supports seq_dim == 0. ++ head_dim (int): The dimension of head. Currently only supports head_dim == -1 or 2. ++ cp_group (torch.distributed.ProcessGroup): The context parallel group. ++ split_sections (Optional[List[int]]): If not None, first split the tensor along the ++ dimension head_dim into sections, then do all-to-all for each section separately, ++ finally concatenate the separated tensors along the dimension head_dim. ++ redo_attention_load_balancing (bool): Whether to redo the attention load balancing of HP. ++ ++ Returns: ++ torch.Tensor: The all-to-all tensor. ++ """ ++ ++ cp_size = cp_group.size() ++ ++ # No need to all-to-all if CP size is 1. ++ if cp_size == 1: ++ return tensor ++ ++ # Limitations of mamba_context_parallel._all_to_all_hp2cp. ++ assert seq_dim == 0, f"tensor_a2a_cp2hp only supports seq_dim == 0 for now, but got {seq_dim=}" ++ assert ( ++ head_dim == -1 or head_dim == 2 ++ ), f"tensor_a2a_cp2hp only supports head_dim == -1 or 2 for now, but got {head_dim=}" ++ assert ( ++ tensor.dim() == 3 ++ ), f"tensor_a2a_cp2hp only supports 3-d input tensor for now, but got {tensor.dim()=}" ++ ++ # Redo attention load balancing first if needed. ++ if redo_attention_load_balancing: ++ tensor = _redo_attention_load_balancing(tensor, cp_size) ++ ++ # Split first if needed. ++ if split_sections is not None: ++ inputs = torch.split(tensor, split_sections, dim=head_dim) ++ outputs = [] ++ for x in inputs: ++ x = tensor_a2a_hp2cp( ++ x, ++ seq_dim=seq_dim, ++ head_dim=head_dim, ++ cp_group=cp_group, ++ redo_attention_load_balancing=False, ++ ) ++ outputs.append(x) ++ tensor = torch.cat(outputs, dim=head_dim) ++ else: ++ tensor = _all_to_all_hp2cp(tensor, cp_group) ++ ++ return tensor ++ ++ ++#################### ++# Torch native gated delta rule ++#################### + def torch_chunk_gated_delta_rule( + query, + key, +@@ -584,6 +944,7 @@ def torch_chunk_gated_delta_rule( + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, ++ cu_seqlens=None, + ): + # pylint: disable=line-too-long + ''' +@@ -593,6 +954,10 @@ def torch_chunk_gated_delta_rule( + Reference: https://github.com/huggingface/transformers/blob/144c8ce2809a2e21914017652700e1ecb450501e/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L470-L547 + ''' + ++ assert ( ++ cu_seqlens is None ++ ), "cu_seqlens is not supported for torch_chunk_gated_delta_rule for now." ++ + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1, eps=1e-6) +@@ -666,4 +1031,4 @@ def torch_chunk_gated_delta_rule( + ) + core_attn_out = core_attn_out[:, :, :sequence_length] + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) +- return core_attn_out, last_recurrent_state ++ return core_attn_out, last_recurrent_state diff --git a/docker/patch/megatron/20260506-85bced0ae.patch b/docker/patch/megatron/20260506-85bced0ae.patch new file mode 100644 index 00000000..4b600b57 --- /dev/null +++ b/docker/patch/megatron/20260506-85bced0ae.patch @@ -0,0 +1,791 @@ +diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py +index 58e1e563b..abe561d83 100644 +--- a/megatron/core/dist_checkpointing/strategies/torch.py ++++ b/megatron/core/dist_checkpointing/strategies/torch.py +@@ -501,10 +501,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + def _validate_global_shapes(self, metadata, sharded_tensors): + for sh_ten in sharded_tensors: + if sh_ten.key not in metadata.state_dict_metadata: +- raise KeyError( +- f"{sh_ten.key} from model not in state dict:" +- f" {sorted(metadata.state_dict_metadata.keys())}" +- ) ++ # raise KeyError( ++ # f"{sh_ten.key} from model not in state dict:" ++ # f" {sorted(metadata.state_dict_metadata.keys())}" ++ # ) ++ print(f"{sh_ten.key} from model not in state dict, will skip") ++ continue + loaded_shape = metadata.state_dict_metadata[sh_ten.key].size + expected_shape = sh_ten.global_shape + if loaded_shape != expected_shape: +@@ -528,7 +530,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + tensor_metadata = self.metadata.state_dict_metadata + metadata_with_sizes = [ + (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) +- for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() ++ for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata + ] + try: + # Temporarily set sizes to expected shapes +@@ -865,6 +867,7 @@ class TorchDistLoadShardedStrategy: + planner=MCoreLoadPlanner( + shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, + allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, ++ allow_partial_load=True, + flatten_state_dict=False, + flatten_sharded_tensors=False, + ), +diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py +index 2a82a1e1c..5441f9335 100644 +--- a/megatron/core/extensions/transformer_engine.py ++++ b/megatron/core/extensions/transformer_engine.py +@@ -836,6 +836,7 @@ class TELinear(te.pytorch.Linear): + self.te_quant_params: Optional[TEQuantizationParams] = None + + for param in self.parameters(): ++ setattr(param, "parallel_mode", parallel_mode) + if is_expert: + # Reduce the gradient on the expert_data_parallel group for expert linear layers + setattr(param, "allreduce", not self.expert_parallel) +@@ -1671,6 +1672,61 @@ class TEDotProductAttention(te.pytorch.DotProductAttention): + + + if HAVE_TE and is_te_min_version("1.9.0.dev0"): ++ def ceil_div(x: int, y: int) -> int: ++ return (x + y - 1) // y ++ ++ class _FakeInt4QuantizationSTE(torch.autograd.Function): ++ @staticmethod ++ def forward(ctx, x, group_size): ++ m, n = x.shape ++ block_size_m, block_size_n = 1, group_size ++ ++ ++ m_padded = ceil_div(m, block_size_m) * block_size_m ++ n_padded = ceil_div(n, block_size_n) * block_size_n ++ ++ x_padded = torch.zeros( ++ (m_padded, n_padded), ++ dtype=x.dtype, device=x.device ++ ) ++ x_padded[:m, :n] = x ++ ++ x_view = x_padded.view( ++ m_padded // block_size_m, ++ block_size_m, ++ n_padded // block_size_n, ++ block_size_n ++ ) ++ ++ x_max = x_view.abs().float().amax(dim=(1, 3), keepdim=True) ++ q_max = 7 ++ x_scale = x_max / q_max ++ ++ x_scale = x_scale.clamp(min=1e-5) ++ ++ x_div = x_view / x_scale ++ x_round = torch.round(x_div) ++ ++ x_q_clamped = x_round.clamp(-q_max, q_max) ++ ++ x_dequant_view = x_q_clamped * x_scale ++ ++ x_dequant_full = x_dequant_view.view_as(x_padded) ++ x_out = x_dequant_full[:m, :n].contiguous().to(x.dtype) ++ ++ return x_out ++ ++ @staticmethod ++ def backward(ctx, grad_output): ++ return grad_output, None ++ ++ def fake_int4_quantization_ste(x, group_size): ++ x_out = _FakeInt4QuantizationSTE.apply(x, group_size) ++ ++ if hasattr(x, 'main_grad'): ++ x_out.main_grad = x.main_grad ++ ++ return x_out + + class TEGroupedLinear(te.pytorch.GroupedLinear): + """ +@@ -1913,6 +1969,7 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): + "amax_history_bwd": torch.cat( + [state["amax_history_bwd"].view(-1, 1) for state in state_list], + dim=1, ++ + ).view(self.fp8_meta["recipe"].amax_history_len, -1), + } + ) +@@ -1990,6 +2047,20 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): + return out + return out, None + ++ def _get_weight_tensors(self): ++ """Get the weight tensors of the module.""" ++ weight_tensors = super()._get_weight_tensors() ++ ++ if os.getenv("OPEN_TRAINING_INT4_FAKE_QAT_FLAG", "0") == "1": ++ group_size = int(os.getenv("OPEN_TRAINING_INT4_GROUP_SIZE", "128")) ++ ++ weight_tensors = [ ++ fake_int4_quantization_ste(w, group_size) ++ for w in weight_tensors ++ ] ++ ++ return weight_tensors ++ + def _encode_extra_state(self, state): + # TE 2.0 changed the format of extra_state to be a byte tensor + if is_te_min_version("2.0.0"): +diff --git a/megatron/core/fusions/fused_mla_yarn_rope_apply.py b/megatron/core/fusions/fused_mla_yarn_rope_apply.py +index 1fd5dcfae..c9aeef1f0 100644 +--- a/megatron/core/fusions/fused_mla_yarn_rope_apply.py ++++ b/megatron/core/fusions/fused_mla_yarn_rope_apply.py +@@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel( + SIN, + emb_dim: tl.constexpr, + k_dim: tl.constexpr, ++ k_dim_ceil: tl.constexpr, + v_dim: tl.constexpr, + head_num: tl.constexpr, + batch_size, +@@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel( + cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) + sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) + +- KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads +- kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads +- mask = kv_off < head_num * stride_kv_nheads +- k_in_off = kv_off + tl.arange(0, k_dim)[None, :] +- v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :] +- k = tl.load(KV_ptr + k_in_off, mask=mask) +- v = tl.load(KV_ptr + v_in_off, mask=mask) ++ KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads ++ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ kj_range = tl.arange(0, k_dim_ceil)[None, :] ++ mask_k = (ki_range < head_num) & (kj_range < k_dim) ++ mask_v = ki_range < head_num ++ k_off = ki_range * stride_kv_nheads + kj_range ++ if v_dim > 0: ++ v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :] ++ v = tl.load(KV_ptr + v_off, mask=mask_v) ++ else: ++ v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty) ++ k = tl.load(KV_ptr + k_off, mask=mask_k) + +- K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads +- V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads ++ K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads ++ V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads + +- k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :] +- v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :] +- tl.store(K_ptr + k_out_off, k, mask=mask) +- tl.store(V_ptr + v_out_off, v, mask=mask) ++ k_out_off = ki_range * stride_k_nheads + kj_range ++ tl.store(K_ptr + k_out_off, k, mask=mask_k) ++ if v_dim > 0: ++ v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :] ++ tl.store(V_ptr + v_out_off, v, mask=mask_v) + + EMB = K_POS_EMB + pid_m * stride_emb_seq + # x1 = t[..., 0::2], x2 = t[..., 1::2] +@@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel( + x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + ++ x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ mask_x = x_range < head_num + x_left_off = ( +- tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads ++ x_range * stride_k_nheads + + k_dim + + tl.arange(0, emb_dim // 2)[None, :] + ) + x_right_off = x_left_off + emb_dim // 2 +- tl.store(K_ptr + x_left_off, x_left, mask=mask) +- tl.store(K_ptr + x_right_off, x_right, mask=mask) ++ tl.store(K_ptr + x_left_off, x_left, mask=mask_x) ++ tl.store(K_ptr + x_right_off, x_right, mask=mask_x) + + + @triton.autotune( +@@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel( + SIN, + emb_dim: tl.constexpr, + k_dim: tl.constexpr, ++ k_dim_ceil: tl.constexpr, + v_dim: tl.constexpr, + head_num: tl.constexpr, + batch_size, +@@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel( + else: + token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size) + +- dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads +- dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads +- mask = dkv_off < head_num * stride_dkv_nheads +- dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :] +- dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :] +- +- dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads +- dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads +- dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :] +- dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :] +- dk = tl.load(dK_ptr + dk_in_off, mask=mask) +- dv = tl.load(dV_ptr + dv_in_off, mask=mask) +- tl.store(dKV_ptr + dk_out_off, dk, mask=mask) +- tl.store(dKV_ptr + dv_out_off, dv, mask=mask) ++ dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads ++ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ kj_range = tl.arange(0, k_dim_ceil)[None, :] ++ mask_k = (ki_range < head_num) & (kj_range < k_dim) ++ mask_v = ki_range < head_num ++ dk_out_off = ki_range * stride_dkv_nheads + kj_range ++ ++ dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads ++ dV_ptr = dV + pid_m * stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads ++ dk_in_off = ki_range * stride_dk_nheads + kj_range ++ ++ dk = tl.load(dK_ptr + dk_in_off, mask=mask_k) ++ tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k) ++ ++ if v_dim > 0: ++ dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :] ++ dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :] ++ dv = tl.load(dV_ptr + dv_in_off, mask=mask_v) ++ tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v) + + if pid_head == 0: + x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) + x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) + for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)): +- dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads +- x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim ++ dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads ++ x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads + mask = x_off < head_num * stride_dk_nheads + x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :] + x_right_off = x_left_off + emb_dim // 2 +@@ -632,6 +647,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + + o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim) + o_value = kv.new_empty(total_seqlen, nheads, v_dim) ++ k_dim_ceil = triton.next_power_of_2(k_dim) + + grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) + rotary_fwd_kv_kernel[grid]( +@@ -643,6 +659,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + sin, + emb_dim, + k_dim, ++ k_dim_ceil, + v_dim, + nheads, + batch_size, +@@ -700,6 +717,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + + d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim) + d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim) ++ k_dim_ceil = triton.next_power_of_2(ctx.k_dim) + + grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) + rotary_bwd_kv_kernel[grid]( +@@ -711,6 +729,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + sin, + ctx.emb_dim, + ctx.k_dim, ++ k_dim_ceil, + ctx.v_dim, + nheads, + batch_size, +diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py +index 3ff370f74..21858ea6a 100644 +--- a/megatron/core/inference/contexts/dynamic_context.py ++++ b/megatron/core/inference/contexts/dynamic_context.py +@@ -57,7 +57,8 @@ except ImportError: + try: + from torch_memory_saver import torch_memory_saver + +- torch_memory_saver.hook_mode = "torch" ++ # Commented out: breaks SGLang CUDA graph (requires hook_mode="preload") ++ # torch_memory_saver.hook_mode = "torch" + HAVE_TORCH_MEMORY_SAVER = True + except ImportError: + HAVE_TORCH_MEMORY_SAVER = False +diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py +index 92d561c34..d4f62cb75 100755 +--- a/megatron/core/models/gpt/gpt_layer_specs.py ++++ b/megatron/core/models/gpt/gpt_layer_specs.py +@@ -189,6 +189,8 @@ def get_gpt_layer_with_transformer_engine_submodules( + enable_hyper_connection: bool = False, + mla_down_proj_fusion: bool = False, + dense_grouped_gemm: bool = False, ++ post_self_attn_layernorm: bool = False, ++ post_mlp_layernorm: bool = False, + ) -> TransformerLayerSubmodules: + """Use these submodules to use lower-level Transformer Engine modules (required for fp8 + training). +@@ -282,9 +284,11 @@ def get_gpt_layer_with_transformer_engine_submodules( + ), + ), + self_attn_bda=get_bias_dropout_add, ++ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, + pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, + mlp=mlp, + mlp_bda=get_bias_dropout_add, ++ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, + sharded_state_dict_keys_map=( + { + "self_attention.linear_q_down_proj.layer_norm_": "input_layernorm.", +@@ -314,10 +318,12 @@ def get_gpt_layer_with_transformer_engine_submodules( + ), + self_attn_bda=get_bias_dropout_add, + self_attention_hyper_connection=hc_module, ++ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, + pre_mlp_layernorm=backend.layer_norm(has_residual=True) if num_experts else IdentityOp, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + mlp_hyper_connection=hc_module, ++ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, + ) + else: + qk_norm = backend.layer_norm(for_qk=True) +@@ -339,10 +345,12 @@ def get_gpt_layer_with_transformer_engine_submodules( + ), + self_attn_bda=get_bias_dropout_add, + self_attention_hyper_connection=hc_module, ++ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, + pre_mlp_layernorm=backend.layer_norm(has_residual=True) if num_experts else IdentityOp, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + mlp_hyper_connection=hc_module, ++ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, + sharded_state_dict_keys_map={ + "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight", + "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias", +diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py +index 19de0ed52..f2899e542 100644 +--- a/megatron/core/models/gpt/gpt_model.py ++++ b/megatron/core/models/gpt/gpt_model.py +@@ -506,6 +506,7 @@ class GPTModel(LanguageModule): + loss_mask: Optional[Tensor] = None, + padding_mask: Optional[Tensor] = None, + is_spec_decode: Optional[bool] = None, ++ mtp_kwargs: Optional[dict] = {}, + ) -> Tensor: + """Forward function of the GPT Model This function passes the input tensors + through the embedding layer, and then the decoder and finally into the post +@@ -585,6 +586,7 @@ class GPTModel(LanguageModule): + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, + is_spec_decode=is_spec_decode, ++ mtp_kwargs=mtp_kwargs, + ) + + def _postprocess( +@@ -607,6 +609,7 @@ class GPTModel(LanguageModule): + extra_block_kwargs=None, + inference_context=None, + is_spec_decode=None, ++ mtp_kwargs={}, + ): + """Postprocesses decoder hidden states to generate logits or compute loss. + +@@ -630,7 +633,7 @@ class GPTModel(LanguageModule): + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: +- output_weight = self.shared_embedding_or_output_weight() ++ output_weight = self.shared_embedding_or_output_weight().detach() + if mtp_in_postprocess and not (in_inference_mode or is_spec_decode): + hidden_states = self.mtp( + input_ids=input_ids, +diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py +index 4430a8c84..084389b27 100644 +--- a/megatron/core/optimizer/distrib_optimizer.py ++++ b/megatron/core/optimizer/distrib_optimizer.py +@@ -706,6 +706,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + # TE FusedAdam will not accumulate step for empty param groups, so we need to + # align the step across param groups. + param_group["step"] = int(step) ++ if "step" in param_group and param_group["step"] is None: ++ del param_group["step"] + + # Grad scaler state. + if self.grad_scaler: +@@ -1771,6 +1773,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + # separately via param_groups, not as part of the gradient buffer. + tensors[key] = LocalNonpersistentObject(tensors[key]) + continue ++ if key == 'step': ++ continue + assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), ( + tensors[key].shape, + gbuf_local_start, +diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py +index 863b5d55d..6c81e6e5f 100644 +--- a/megatron/core/parallel_state.py ++++ b/megatron/core/parallel_state.py +@@ -11,6 +11,7 @@ from typing import Callable, List, Optional + + import numpy as np + import torch ++import torch.distributed as dist + + from megatron.core.inference.symmetric_memory import SymmetricMemoryManager + +diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py +index 465e83f28..707162f47 100644 +--- a/megatron/core/pipeline_parallel/p2p_communication.py ++++ b/megatron/core/pipeline_parallel/p2p_communication.py +@@ -27,22 +27,22 @@ def _batched_p2p_ops( + ops = [] + if tensor_send_prev is not None: + send_prev_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group ++ torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, + ) + ops.append(send_prev_op) + if tensor_recv_prev is not None: + recv_prev_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, + ) + ops.append(recv_prev_op) + if tensor_send_next is not None: + send_next_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_next, next_pipeline_rank, group ++ torch.distributed.isend, tensor_send_next, next_pipeline_rank, + ) + ops.append(send_next_op) + if tensor_recv_next is not None: + recv_next_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, + ) + ops.append(recv_next_op) + if len(ops) > 0: +diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py +index d316d23de..9bb6d2bd6 100644 +--- a/megatron/core/transformer/moe/moe_utils.py ++++ b/megatron/core/transformer/moe/moe_utils.py +@@ -787,6 +787,9 @@ def topk_routing_with_score_function( + scores, topk, num_groups, group_topk, _compute_topk + ) + ++ from relax.utils.training.routing_replay import get_routing_replay_compute_topk ++ compute_topk = get_routing_replay_compute_topk(compute_topk) ++ + # Precision notes: + # - Logits are converted to fp32 for score functions. + # - All the intermediate calculations are in fp32. +diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py +index b675d33cd..0cf3e006a 100644 +--- a/megatron/core/transformer/moe/router.py ++++ b/megatron/core/transformer/moe/router.py +@@ -216,6 +216,9 @@ class TopKRouter(Router): + if self.config.moe_enable_routing_replay: + self.router_replay = RouterReplay() + ++ from relax.utils.training.routing_replay import register_routing_replay ++ register_routing_replay(self) ++ + def _maintain_float32_expert_bias(self): + """ + Maintain the expert bias in float32. +diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py +index ba5018a94..79ed327de 100755 +--- a/megatron/core/transformer/multi_token_prediction.py ++++ b/megatron/core/transformer/multi_token_prediction.py +@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union + + import torch + from torch import Tensor ++import warnings + + from megatron.core import InferenceParams, parallel_state, tensor_parallel + from megatron.core.dist_checkpointing.mapping import ShardedStateDict +@@ -891,17 +892,19 @@ class MultiTokenPredictionLayer(MegatronModule): + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) +- position_ids, _ = roll_tensor( +- position_ids, +- shifts=-1, +- dims=-1, +- cp_group=self.cp_group, +- packed_seq_params=packed_seq_params, +- ) ++ if position_ids is not None: ++ position_ids, _ = roll_tensor( ++ position_ids, ++ shifts=-1, ++ dims=-1, ++ cp_group=self.cp_group, ++ packed_seq_params=packed_seq_params, ++ ) + # embedding + decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) ++ decoder_input = decoder_input.detach() + +- hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) ++ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) + + return input_ids, position_ids, decoder_input, hidden_states + +@@ -1059,6 +1062,51 @@ class MultiTokenPredictionLayer(MegatronModule): + return hidden_states + + def _checkpointed_forward(self, forward_func, *args, **kwargs): ++ """Wrap `forward_func` with activation checkpointing while only passing tensors. ++ ++ Non-tensor arguments (e.g., configuration objects, None) are captured via closure so ++ that checkpoint implementations never receive them directly, avoiding save_for_backward ++ issues with non-tensor inputs. ++ """ ++ ++ # TODO(jiajun): Is there any better implementation here? ++ positional_specs = [] ++ kw_specs = [] ++ tensor_args: List[torch.Tensor] = [] ++ ++ for arg in args: ++ if torch.is_tensor(arg): ++ positional_specs.append(('tensor', len(tensor_args))) ++ tensor_args.append(arg) ++ else: ++ positional_specs.append(('const', arg)) ++ ++ for key, value in kwargs.items(): ++ if torch.is_tensor(value): ++ kw_specs.append((key, ('tensor', len(tensor_args)))) ++ tensor_args.append(value) ++ else: ++ kw_specs.append((key, ('const', value))) ++ ++ def run(*flat_tensor_args): ++ rebuilt_args = [] ++ for spec_type, payload in positional_specs: ++ if spec_type == 'tensor': ++ rebuilt_args.append(flat_tensor_args[payload]) ++ else: ++ rebuilt_args.append(payload) ++ ++ rebuilt_kwargs = {} ++ for key, (spec_type, payload) in kw_specs: ++ if spec_type == 'tensor': ++ rebuilt_kwargs[key] = flat_tensor_args[payload] ++ else: ++ rebuilt_kwargs[key] = payload ++ ++ return forward_func(*rebuilt_args, **rebuilt_kwargs) ++ ++ tensor_args_tuple = tuple(tensor_args) ++ + def checkpoint_handler(): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + if self.config.fp8: +@@ -1069,12 +1117,11 @@ class MultiTokenPredictionLayer(MegatronModule): + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), +- *args, +- **kwargs, ++ *tensor_args_tuple, + ) + else: + return tensor_parallel.checkpoint( +- forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() ++ run, self.config.distribute_saved_activations, *tensor_args_tuple + ) + + if self.config.recompute_method == 'uniform': +diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py +index cac634ff9..e5e63197a 100644 +--- a/megatron/core/transformer/transformer_config.py ++++ b/megatron/core/transformer/transformer_config.py +@@ -244,6 +244,9 @@ class TransformerConfig(ModelParallelConfig): + attention_output_gate: bool = False + """Whether to apply output gate to the attention layers.""" + ++ post_self_attn_layernorm: bool = False ++ post_mlp_layernorm: bool = False ++ + test_mode: bool = False + """Whether to run real-time tests.""" + +diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py +index ee2054511..2ba4d0664 100644 +--- a/megatron/core/transformer/transformer_layer.py ++++ b/megatron/core/transformer/transformer_layer.py +@@ -245,6 +245,7 @@ class TransformerLayerSubmodules: + self_attention_hyper_connection: Union[ModuleSpec, type] = IdentityOp + self_attention: Union[ModuleSpec, type] = IdentityOp + self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + + pre_cross_attn_layernorm: LayerNormBuilder = IdentityOp + cross_attention_hyper_connection: Union[ModuleSpec, type] = IdentityOp +@@ -255,6 +256,7 @@ class TransformerLayerSubmodules: + mlp_hyper_connection: Union[ModuleSpec, type] = IdentityOp + mlp: Union[ModuleSpec, type] = IdentityOp + mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + + # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method + sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) +@@ -352,6 +354,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + # [Module 3: BiasDropoutFusion] + self.self_attn_bda = build_module(submodules.self_attn_bda) + ++ self.post_self_attn_layernorm = build_module( ++ submodules.post_self_attn_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon, ++ ) ++ + # [Module 4: Post SelfAttention] Optional Layernorm after self-attn + self.pre_cross_attn_layernorm = submodules.pre_cross_attn_layernorm( + config=self.config, +@@ -418,6 +427,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + + self.is_moe_layer = isinstance(self.mlp, MoELayer) + ++ self.post_mlp_layernorm = build_module( ++ submodules.post_mlp_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon ++ ) ++ + self.recompute_input_layernorm = False + self.recompute_pre_mlp_layernorm = False + self.recompute_mlp = False +@@ -638,6 +654,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + attention_output_with_bias[0] + ) + ++ attention_output, attention_output_bias = attention_output_with_bias ++ attention_output = self.post_self_attn_layernorm(attention_output) ++ attention_output_with_bias = (attention_output, attention_output_bias) ++ + # TODO: could we move `bias_dropout_add_exec_handler` itself + # inside the module provided in the `bias_dropout_add_spec` module? + nvtx_range_push(suffix="self_attn_bda") +@@ -823,6 +843,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + self._set_fc2_residual(residual) + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output, padding_mask=padding_mask) + ++ mlp_output, mlp_output_bias = mlp_output_with_bias ++ mlp_output = self.post_mlp_layernorm(mlp_output) ++ mlp_output_with_bias = (mlp_output, mlp_output_bias) ++ + nvtx_range_pop(suffix="mlp") + + if ( +diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py +index 62f6e4426..c25969b3b 100644 +--- a/megatron/training/arguments.py ++++ b/megatron/training/arguments.py +@@ -1992,6 +1992,9 @@ def core_transformer_config_from_args(args, config_class=None): + + kw_args['inference_sampling_seed'] = args.seed + ++ kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm ++ kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm ++ + # handle quantization config + # NOTE: Kitchen arguments are only added to the namespace when + # Kitchen library is available. +@@ -2475,7 +2478,7 @@ def _add_network_size_args(parser): + '--position-embedding-type', + type=str, + default='learned_absolute', +- choices=['learned_absolute', 'rope', 'mrope', 'relative', 'none'], ++ choices=['learned_absolute', 'rope', 'yarn', 'mrope', 'relative', 'none'], + help='Position embedding type.', + ) + group.add_argument( +diff --git a/megatron/training/training.py b/megatron/training/training.py +index a0817e834..7cd094dc5 100644 +--- a/megatron/training/training.py ++++ b/megatron/training/training.py +@@ -222,7 +222,9 @@ from megatron.training.utils import ( + try: + from torch_memory_saver import torch_memory_saver + +- torch_memory_saver.hook_mode = "torch" ++ # NOTE(wuhuan): keep the default hook mode; forcing "torch" triggers ++ # 'torch.AcceleratorError: CUDA error: invalid argument' on weight updates. ++ # torch_memory_saver.hook_mode = "torch" + HAVE_TORCH_MEMORY_SAVER = True + except ImportError: + HAVE_TORCH_MEMORY_SAVER = False +diff --git a/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py b/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py +index 7fcf295e..7ac11345 100644 +--- a/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py ++++ b/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py +@@ -404,15 +404,6 @@ class Qwen35VLMoEBridge(MegatronModelBridge): + k="mtp.layers.*.self_attn.k_proj.weight", + v="mtp.layers.*.self_attn.v_proj.weight", + ), +- GatedMLPMapping( +- megatron_param="language_model.mtp.layers.*.mtp_model_layer.mlp.experts.linear_fc1.weight*", +- gate="mtp.layers.*.mlp.experts.*.gate_proj.weight", +- up="mtp.layers.*.mlp.experts.*.up_proj.weight", +- ), +- AutoMapping( +- megatron_param="language_model.mtp.layers.*.mtp_model_layer.mlp.experts.linear_fc2.weight*", +- hf_param="mtp.layers.*.mlp.experts.*.down_proj.weight", +- ), + GatedMLPMapping( + megatron_param="language_model.mtp.layers.*.mtp_model_layer.mlp.shared_experts.linear_fc1.weight", + gate="mtp.layers.*.mlp.shared_expert.gate_proj.weight", +@@ -429,6 +420,43 @@ class Qwen35VLMoEBridge(MegatronModelBridge): + ] + ) + ++ # Detect MTP MoE expert weight format: Qwen3.5 stores per-expert ++ # (mtp.layers.0.mlp.experts.{i}.gate_proj.weight), Qwen3.6 stores packed ++ # (mtp.layers.0.mlp.experts.gate_up_proj). Same architecture string, ++ # different storage — must inspect HF keys. ++ mtp_experts_packed = False ++ if hasattr(self.hf_pretrained, "state") and hasattr(self.hf_pretrained.state, "source"): ++ hf_keys = set(self.hf_pretrained.state.source.get_all_keys()) ++ if "mtp.layers.0.mlp.experts.gate_up_proj" in hf_keys: ++ mtp_experts_packed = True ++ ++ if mtp_experts_packed: ++ # Qwen3.6: packed format (same as main decoder) ++ mapping_list.extend([ ++ FusedGatedExpertMapping( ++ megatron_param="language_model.mtp.layers.*.mtp_model_layer.mlp.experts.linear_fc1.weight*", ++ hf_param="mtp.layers.*.mlp.experts.gate_up_proj", ++ ), ++ FusedExpertMapping( ++ megatron_param="language_model.mtp.layers.*.mtp_model_layer.mlp.experts.linear_fc2.weight*", ++ hf_param="mtp.layers.*.mlp.experts.down_proj", ++ transpose_on_export=True, ++ ), ++ ]) ++ else: ++ # Qwen3.5: per-expert format (current behavior) ++ mapping_list.extend([ ++ GatedMLPMapping( ++ megatron_param="language_model.mtp.layers.*.mtp_model_layer.mlp.experts.linear_fc1.weight*", ++ gate="mtp.layers.*.mlp.experts.*.gate_proj.weight", ++ up="mtp.layers.*.mlp.experts.*.up_proj.weight", ++ ), ++ AutoMapping( ++ megatron_param="language_model.mtp.layers.*.mtp_model_layer.mlp.experts.linear_fc2.weight*", ++ hf_param="mtp.layers.*.mlp.experts.*.down_proj.weight", ++ ), ++ ]) ++ + return MegatronMappingRegistry(*mapping_list) + + diff --git a/docker/run-rocm-bind.sh b/docker/run-rocm-bind.sh index af793de6..bfff0cfc 100755 --- a/docker/run-rocm-bind.sh +++ b/docker/run-rocm-bind.sh @@ -5,7 +5,7 @@ set -euo pipefail SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." &>/dev/null && pwd)" -IMAGE="${IMAGE:-relax:rocm-relax-smoke}" +IMAGE="${IMAGE:-relax:rocm-gfx942}" CONTAINER_NAME="${CONTAINER_NAME:-relax_rocm_bind}" RELAX_DIR="${RELAX_DIR:-${REPO_ROOT}}" MODEL_DIR="${MODEL_DIR:-/mnt/dcgpuval/models}" diff --git a/docs/draft/dynamic-context-parallel.md b/docs/draft/dynamic-context-parallel.md new file mode 100644 index 00000000..4c0a43d7 --- /dev/null +++ b/docs/draft/dynamic-context-parallel.md @@ -0,0 +1,334 @@ +# Relax 接入 Dynamic Context Parallel 改造方案 + +> **状态**:Draft / 设计阶段 +> **背景**:verl 在 PR [#5057](https://github.com/volcengine/verl/pull/5057)(commit `7e9a07c4`,2026-03-31)落地了 dynamic CP,能让每个 micro-batch 按当前最大 seq_len 自适应选 cp_size,对 RL 训练的长短样本混合场景明显友好。本文档给出 Relax 端的改造方案。 +> **参考**:verl 实现详见 `/root/repos/verl/verl_cp.md` §6。 + +--- + +## 0. 前置事实(决定方案形态) + +| 维度 | verl 现状 | Relax 现状 | 影响 | +|---|---|---|---| +| 引擎封装 | `MegatronEngine` 类 | 无类,直接 `mpu.*`(散布在 `cp_utils.py`/`ppo_utils.py`/`loss.py`/`data.py`/`actor.py`/`model.py`) | verl 的 "DP=1 伪装" 做不到方法重写,需另想 | +| Batch 类型 | TensorDict + `non_tensor_data` | `dict[str, list[Tensor]]`(`utils/types.py:194`) | `local_cp_size` 必须显式作为字段流转 | +| Megatron-LM | 上游 PR #3405 后的 dev 分支 | Bridge pin `2faedbf6...` 拉的是 **main 分支** `f4a071039`,**不含** `dynamic_context_parallel` 形参与 `get_dynamic_data_context_parallel_groups` API(已实测);PR #3405 的 commit `cde56a469` 只在 dev 分支 | **必须先解决依赖**:要么 bump bridge pin → dev,要么打补丁把 #3405 + #2000 cherry-pick 进当前 main | +| 数据切分 | `preprocess_thd_engine` 一处 | `get_batch`(`data.py:106-335`)+ `cp_utils.py` 一组帮助函数 | 改面更广,但都在两个文件里 | +| Token budget | `max_token_len * sp_size` | `max_tokens_per_gpu * cp_size`(`data.py:533`,写死) | 必须改成 per-microbatch 的 effective cp_size | +| 损失尺度 | `loss * num_micro_batch`(一处) | `loss.py:1080-1087` 多处含 `mpu.get_*_world_size(with_context_parallel=True)` 因子 | 需引入 per-microbatch normalizer | +| VLM | thd + bridge VL 透传,align `tp*cp*2` | 同(`data.py:177-212`、`cp_utils.py:9-29`),同样 hardcode `tp*cp*2` | 同样需要 per-microbatch align | + +--- + +## 1. 设计原则 + +**支点**:`local_cp_size` 通过 batch dict 一路下传到 `get_batch`、`forward_step`、loss、postprocess。所有 `mpu.get_context_parallel_*()` 调用改为接受可选的 `cp_group`/`cp_size` 参数,当外部传入则用之,否则退回全局 mpu(保持向后兼容)。 + +**两段式落地**: +- **Phase 1 MVP**:cp_size 在 DP 内**统一**(取 max),按 micro-batch 自适应。无 sub-DP 路由、不动数据 partition。一个开关上线即可享受"短样本不付 CP 通信代价"的收益。 +- **Phase 2 verl 等价**:引入 sub-DP 路由 + 异构 cp_size。需要 `mpu.get_dynamic_data_context_parallel_groups`、`dynamic_cp_split_batch`、`dynamic_cp_merge_output`。 + +> 强烈建议 **Phase 1 先上线、跑稳、再做 Phase 2**。Phase 1 的总工作量 ≈ Phase 2 的 1/3,且 Phase 2 多出来的 sub-DP 路由对 RL 训练里的 advantage/log-prob 全局聚合有破坏性影响(verl 自己也留了 TODO 没完成)。 + +--- + +## 2. Phase 0:依赖准备 + +### 2.1 决定 Megatron-LM 来源(二选一) + +| 选项 | 做法 | 收益 | 代价 | +|---|---|---|---| +| A. Bump Bridge pin | 把 `docker/Dockerfile:78` 的 `MEGATRON_BRIDGE_COMMIT` 升到含 mcore-dev 的版本,并把 `switch_mcore.sh` 切到 `dev` | 自带 dynamic CP + 持续维护 | dev 分支稳定性差,需要回归全套现有训练任务 | +| B. 现有 main + 补丁 | 在 `docker/patch/megatron/` 下加一个补丁,cherry-pick `cde56a469` (#3405) 和 `2d6e946ba` (#2000 Dynamic CP part 2) | 影响面小,可控 | 维护补丁 conflict 风险 | + +**推荐 B**:cherry-pick 两个 commit 落到 `docker/patch/latest/megatron-dynamic-cp.patch`。改造期间任何升级 megatron 的人都看到 patch,风险显式化。 + +### 2.2 验证脚本 + +在 `relax/backends/megatron/initialize.py` 加运行时探测: + +```python +import inspect +HAS_DYNAMIC_CP = "dynamic_context_parallel" in inspect.signature(mpu.initialize_model_parallel).parameters +``` + +供后续条件性启用。 + +--- + +## 3. Phase 1(MVP)改动清单 + +### 3.1 配置(args + validate) + +```python +# relax/utils/arguments.py(新增 group "Dynamic CP") +parser.add_argument("--dynamic-context-parallel", action="store_true", + help="Enable per-microbatch adaptive CP size (requires Megatron-LM PR #3405).") +parser.add_argument("--max-seqlen-per-dp-cp-rank", type=int, default=None, + help="Upper bound of tokens per DP×CP rank. Required when --dynamic-context-parallel.") +``` + +```python +# relax/backends/megatron/arguments.py: validate_args +if args.dynamic_context_parallel: + assert HAS_DYNAMIC_CP, "Megatron-LM lacks PR #3405; bump pin or apply patch." + assert args.max_seqlen_per_dp_cp_rank is not None + assert args.context_parallel_size >= 1 + dp_size = compute_dp_size(...) + assert dp_size * args.max_seqlen_per_dp_cp_rank >= args.max_response_len + args.max_prompt_len +``` + +### 3.2 mpu 初始化(`initialize.py:39-55`) + +```python +extra = {} +if args.dynamic_context_parallel and HAS_DYNAMIC_CP: + extra["dynamic_context_parallel"] = True +mpu.initialize_model_parallel(..., **extra) +``` + +### 3.3 Bridge config 反向欺骗(`model_provider.py:175-215`,**仅 bridge 模式**) + +```python +if args.dynamic_context_parallel: + overrides["max_seqlen_per_dp_cp_rank"] = args.max_seqlen_per_dp_cp_rank + overrides["dynamic_context_parallel"] = False # 同 verl 注释里的 "bad coupling" 绕道 + overrides["context_parallel_size"] = mpu.get_data_parallel_world_size() # 让 bridge 误以为 cp=DP +``` + +raw 模式下不做这件事(直接用 `core_transformer_config_from_args(args)` 即可,但要在那里同样防止 `args.context_parallel_size` 被 transformer config 误读 —— 或者干脆 raw 模式 Phase 1 不支持,给出清晰报错)。 + +### 3.4 Per-microbatch CP 决策 + +新增 `relax/backends/megatron/dynamic_cp.py`: + +```python +def decide_local_cp_size(samples_in_microbatch, max_seqlen_per_dp_cp_rank, dp_size): + """同 verl: ceil(max_seqlen / cap), 向上取 2 的幂, clamp ≤ dp_size.""" + max_seq = max(len(s["input_ids"]) for s in samples_in_microbatch) + n = math.ceil(max_seq / max_seqlen_per_dp_cp_rank) + n = max(1, 1 << (n - 1).bit_length()) + return min(n, dp_size) + + +def gather_max_local_cp(local_cp: int, dp_group) -> int: + """Phase 1 MVP: DP 间统一取 MAX。""" + t = torch.tensor([local_cp], device="cuda") + torch.distributed.all_reduce(t, op=torch.distributed.ReduceOp.MAX, group=dp_group) + return int(t.item()) +``` + +### 3.5 数据流(`data.py`) + +#### a) 注入决策点(`get_data_iterator`,`data.py:449-573`) + +在 `get_seqlen_balanced_partitions` 之后、构造 `DataIterator` 之前: + +```python +if args.dynamic_context_parallel: + per_mb_cp_size = [] + for mb_samples in partitioned_microbatches: + local = decide_local_cp_size(mb_samples, args.max_seqlen_per_dp_cp_rank, + mpu.get_data_parallel_world_size(with_context_parallel=False)) + per_mb_cp_size.append(gather_max_local_cp(local, mpu.get_data_parallel_group())) +else: + per_mb_cp_size = [args.context_parallel_size] * len(partitioned_microbatches) + +# 把 per_mb_cp_size 挂到 DataIterator 上 +``` + +#### b) Token budget(`data.py:533`) + +```python +# 改前: max_tokens_per_gpu * cp_size +# 改后: +budget_cp = args.context_parallel_size if not args.dynamic_context_parallel \ + else 1 # MVP: 不预知,就按最坏的 cp=1 打包,让 dynamic_cp 决策时再涨 cp +get_minimum_num_micro_batch_size(samples[start:end], + args.max_tokens_per_gpu * budget_cp) +``` + +> 这里有个微妙取舍:dynamic CP 的 "动态" 是在 batch 已经分好后的 second pass。打包阶段用 `cp=1` 的 budget 意味着每 micro-batch 都按 "塞满单卡" 打 → 长样本最多触发 cp=8,能跑通;但短样本浪费空间。Phase 2 可以重排打包。 + +#### c) `get_batch`(`data.py:106-335`) + +签名加可选 `local_cp_size: Optional[int] = None`: + +```python +def get_batch(args, samples, ..., local_cp_size: Optional[int] = None): + if local_cp_size is not None: + cp_size = local_cp_size + cp_group = mpu.get_dynamic_data_context_parallel_groups(group_size=local_cp_size) + cp_rank = torch.distributed.get_rank(cp_group) + else: + cp_size = mpu.get_context_parallel_world_size() + cp_group = mpu.get_context_parallel_group() + cp_rank = mpu.get_context_parallel_rank() + ... + # 所有用 cp_size/cp_rank 的地方改用上面的本地变量 + # PackedSeqParams 上挂 cp_group, local_cp_size + packed_seq_params = PackedSeqParams( + ..., cp_group=cp_group, local_cp_size=local_cp_size, + ) +``` + +`slice_with_cp`(`cp_utils.py:210-251`)同样加 `cp_size, cp_rank` 参数。 + +#### d) `DataIterator.get_next` + +返回的 dict 多带一个键 `"local_cp_size"`,由 `get_batch` 透传到 `forward_step`。 + +### 3.6 Forward & loss + +#### a) `forward_step`(`model.py:222-303`、`399-489`) + +```python +local_cp_size = batch.get("local_cp_size") # None 表示静态 CP +batch_processed = get_batch(..., local_cp_size=local_cp_size) +output_tensor = model(...) +return output_tensor, partial(postprocess_fn, ..., local_cp_size=local_cp_size, + cp_group=batch_processed["cp_group"]) +``` + +#### b) `cp_utils.py` 全部公共 helper + +涉及函数:`get_logits_and_tokens_offset_with_cp`、`all_gather_with_cp`、`get_sum_of_sample_mean`、`slice_log_prob_with_cp`、`maybe_padded_total_lengths`。 + +签名加 `cp_size: Optional[int] = None, cp_group: Optional[ProcessGroup] = None`。内部 `mpu.get_context_parallel_*()` 改为: + +```python +cp_size = cp_size or mpu.get_context_parallel_world_size() +cp_group = cp_group or mpu.get_context_parallel_group() +cp_rank = torch.distributed.get_rank(cp_group) if cp_group else mpu.get_context_parallel_rank() +``` + +调用方(`loss.py`、`ppo_utils.py:298-336/402-423/463-515`、`actor.py:1047`、`advantages.py:159`)补传两个参数。 + +#### c) `loss.py:1080-1087` 损失尺度 + +```python +# 改前: +# loss = loss * num_microbatches / global_batch_size * mpu.get_data_parallel_world_size(with_context_parallel=True) +# loss = loss * mpu.get_context_parallel_world_size() + +# 改后: +effective_cp = local_cp_size if local_cp_size is not None else mpu.get_context_parallel_world_size() +# DP 维度 Phase 1 MVP 仍按全局 DP(cp 同 DP 内统一),不变 +dp_x_cp = mpu.get_data_parallel_world_size(with_context_parallel=False) * effective_cp +loss = loss * num_microbatches / global_batch_size * dp_x_cp # sample-mean +loss = loss * effective_cp # per-token +``` + +> ⚠️ **Phase 1 关键约束**:MVP 同一 micro-batch 内所有 DP rank 用相同 cp,所以 `mpu.get_data_parallel_world_size(with_context_parallel=True)` 在 Phase 1 ≡ `dp * effective_cp`,等价。Phase 2 引入异构后才会破。 + +### 3.7 VLM(`data.py:177-212`,`cp_utils.py:9-29`) + +把硬编码的 `tp * cp * 2` 替换为 `tp * effective_cp * 2`,其中 `effective_cp = local_cp_size or args.context_parallel_size`。`vlm_packed_seq_params` 同样要带 `cp_group`,让 Bridge 内部走对的 ring-CP 通信。 + +> 风险:Bridge 自己内部如何调度 dynamic CP 取决于 PR #3405 + Bridge 自身。建议 **Phase 1 先在纯文本上线,VLM + dynamic CP 单独作为 Phase 1.5 验证**。 + +### 3.8 不动的部分 + +- `actor.py` 里的 `data_system_client` 数据接收:rollout 阶段不感知 CP,CP 只是 trainer 内部事。 +- `RolloutBatch` 类型定义:保持 `dict[str, list[Tensor]]`,新键加在 micro-batch 那一层。 +- 所有 raw 模式相关代码:Phase 1 报错 "raw mode 暂不支持 dynamic CP"。 + +--- + +## 4. Phase 2(verl 完整等价)追加改动 + +待 Phase 1 稳定后再做。 + +### 4.1 引入 sub-DP 路由 + +新增 `dynamic_cp_split_batch` —— 在 `data.py` 的 `get_data_iterator` 末尾对每个 micro-batch: + +```python +if local_cp_size < dp_size: + local_dp_rank = dp_rank // local_cp_size + local_dp_size = dp_size // local_cp_size + # 把 partitioned_microbatches[i] 进一步切成 local_dp_size 份 + # 每个 sub-DP 拿自己那份 +``` + +需要的 mpu 新 API:`get_dynamic_data_context_parallel_groups(group_size=local_cp_size)`(PR #3405 已提供)。 + +### 4.2 引入 `dynamic_cp_merge_output` + +postprocess 阶段对需要跨 sub-DP all_gather 的输出(log_probs、entropy、advantages 等),同 verl 用 `all_gather_object` 在 `dp_group` 内按 stride 重组。 + +### 4.3 DP "伪装" + +由于 Relax 没有 `engine.get_data_parallel_size()` 方法,建议反过来:在 `compute_dp_size` 旁加一个 `get_logical_dp_size(args)`,dynamic CP 时返回 1。所有 dynamic-batching/loss 尺度计算改用 `get_logical_dp_size`,物理 DP 通信仍用 `mpu`。 + +### 4.4 损失跨 sub-DP 聚合 + +verl 在这里留了 TODO,Relax 不能照抄。建议至少在 `loss.py` 增加一次跨 sub-DP 的 weighted average(按 sub-group 的样本数权重),否则 advantage normalize / KL 估计会 biased。 + +### 4.5 数据 scheduler 长度感知 + +verl post-merge TODO 里要做的事 —— 在 `get_seqlen_balanced_partitions` 之前按长度排序并按桶分箱,同桶 micro-batch 用相同 cp。能让 DP 内 `MAX(local_cp)` 趋近 `MEAN(local_cp)`。 + +--- + +## 5. 测试方案 + +| 层级 | 测试 | 通过标准 | +|---|---|---| +| 单元 | `decide_local_cp_size` / `gather_max_local_cp` | 输入构造的 batch → 期望 cp_size | +| 单元 | `get_batch(local_cp_size=2)` vs `get_batch()` 在 cp=2 静态时 | bit-exact | +| 集成 | 8GPU 单机 SFT,TP=4 PP=1 CP=1,开/关 dynamic | loss 曲线在数值容差内对齐 | +| 集成 | 8GPU GRPO,长样本 + 短样本混合,dynamic vs 静态 cp=4 | reward 曲线一致,throughput dynamic 更高 | +| 回归 | 现有所有 e2e 训练脚本(`scripts/training/`),关闭 dynamic CP | 必须 bit-exact | +| VLM(Phase 1.5) | Qwen3-VL 8GPU | loss 对齐 | + +--- + +## 6. 风险与决策点 + +| # | 风险 | 缓解 | +|---|---|---| +| 1 | Megatron pin 升级或 patch 维护负担 | 倾向 patch 路线,写在 `docker/patch/latest/`,CI 自动 apply | +| 2 | Phase 1 损失尺度公式在 cp 同 DP 内统一时等价,但需要严谨证明 | 在 PR 描述里写出代数推导 + 单元测试覆盖两种 mode | +| 3 | VLM 的 Bridge 内部 CP 是否随 `local_cp_size` 自动适应未知 | Phase 1.5 单独验证;最坏情况 VLM 不支持 dynamic CP(同 verl bshd 的处理) | +| 4 | raw 模式 Phase 1 不支持 | 显式报错引导用户切到 bridge 模式 | +| 5 | `cp_utils.py` 全部 helper 改签名是 breaking 修改 | 给所有参数加 `Optional` 默认值,保证既有调用方一行不动 | +| 6 | Phase 2 的 sub-DP 路由对 RL 全局统计(advantage normalize)是破坏性的 | Phase 2 启动前,先盘清 `relax/components/advantages.py` 和 `relax/utils/training/ppo_utils.py` 里所有跨 DP 聚合点,逐一决定是否需要补 cross-sub-DP 聚合 | + +--- + +## 7. 工作量预估(人日) + +| 阶段 | 模块 | 估算 | +|---|---|---| +| Phase 0 | Megatron patch + 探测 | 1-2 | +| Phase 1 | 配置 + mpu/bridge + dynamic_cp.py | 2 | +| Phase 1 | data.py / cp_utils.py / forward_step / loss.py 改造 | 4-5 | +| Phase 1 | 单元 + 集成测试 + 回归 | 3-4 | +| **Phase 1 小计** | | **~10-13 人日** | +| Phase 1.5 | VLM 验证 + 修复 | 3-5 | +| Phase 2 | sub-DP 路由 + merge + DP 伪装 + 损失聚合 | 8-12 | +| Phase 2 | 测试 + 回归 | 5 | +| **总计** | | **~25-35 人日** | + +--- + +## 8. 落地起始的最小补丁清单(P1 第一周) + +1. `docker/patch/latest/megatron-dynamic-cp.patch` —— cherry-pick #3405 + #2000 +2. `relax/utils/arguments.py` —— 加 2 个 flag +3. `relax/backends/megatron/arguments.py` —— validate +4. `relax/backends/megatron/initialize.py` —— mpu init + `HAS_DYNAMIC_CP` +5. `relax/backends/megatron/model_provider.py` —— bridge 反向欺骗 +6. `relax/backends/megatron/dynamic_cp.py`(新文件) —— `decide_local_cp_size` / `gather_max_local_cp` + +走通这 6 处 → 已经可以 `dynamic_context_parallel=True` 跑起来(虽然 cp_size 还没真在 batch 里变化),后续再分别接通 data → forward → loss 三段。 + +--- + +## 9. 参考 + +- verl PR #5057: +- verl 实现详解:`/root/repos/verl/verl_cp.md` §6(来源 / mpu+bridge 双重欺骗 / split+merge / 限制汇总) +- Megatron-LM PR #3405:(dev 分支已合,main 未合) +- Megatron-LM PR #2000:Dynamic CP part 2(提供 `get_dynamic_data_context_parallel_groups`) diff --git a/docs/en/examples/deepeyes.md b/docs/en/examples/deepeyes.md index fe195479..ce065c0b 100644 --- a/docs/en/examples/deepeyes.md +++ b/docs/en/examples/deepeyes.md @@ -49,25 +49,12 @@ The HF Image dict format (`{"bytes": ...}`) is natively supported by Relax's ima ### Download the Model ```bash -hf download Qwen/Qwen3-VL-4B-Instruct \ - --local-dir /root/Qwen3-VL-4B-Instruct +hf download Qwen/Qwen3-VL-30B-A3B-Thinking \ + --local-dir /root/Qwen3-VL-30B-A3B-Thinking ``` -For the full-scale configuration, use `Qwen/Qwen3-VL-30B-A3B-Thinking`. - ## Quick Start -### 4B Model (8 GPUs) - -```bash -export MODEL_DIR=/root -export DATA_DIR=/root -export SAVE_DIR=/root/save - -cd /root/Relax -bash examples/deepeyes/run_deepeyes_4b.sh -``` - ### 30B-A3B Model (8 GPUs, MoE) The full-scale configuration requires a judge model for reward scoring: @@ -91,7 +78,7 @@ bash examples/deepeyes/run_deepeyes.sh ```bash WORKING_DIR="./" RAY_ADDRESS=:6379 \ MODEL_DIR=/root DATA_DIR=/root SAVE_DIR=/root/save \ - bash -x scripts/entrypoint/ray-job.sh examples/deepeyes/run_deepeyes_4b.sh + bash -x scripts/entrypoint/ray-job.sh examples/deepeyes/run_deepeyes.sh ``` ## Architecture @@ -101,7 +88,6 @@ WORKING_DIR="./" RAY_ADDRESS=:6379 \ ``` examples/deepeyes/ ├── run_deepeyes.sh # Launch script (Qwen3-VL-30B-A3B, full config) -├── run_deepeyes_4b.sh # Launch script (Qwen3-VL-4B, lightweight) ├── deepeyes_config.yaml # Task config (max_turns, env path) ├── rollout.py # Multi-turn rollout logic ├── env_deepeyes.py # DeepEyes tool-use environment diff --git a/docs/en/guide/configuration.md b/docs/en/guide/configuration.md index 0544f00d..8da5f1fa 100644 --- a/docs/en/guide/configuration.md +++ b/docs/en/guide/configuration.md @@ -88,6 +88,9 @@ For common configuration usage and examples, see the [Quick Start Guide](./quick | `--rollout-seed` | int | 42 | Random seed for Rollout, used for shuffling prompts and random sampling | | `--use-streaming-dataset` | flag | False | Use streaming dataset to save memory | | `--streaming-buffer-size` | int | 10000 | Buffer size for streaming dataset | +| `--prefetch-chunk-size` | int | 32 | Number of samples to dispatch to the thread-pool in each prefetch round. Larger values increase throughput but also memory pressure. Only effective when `--use-streaming-dataset` is set and the dataset contains multimodal data | +| `--prefetch-max-cached` | int | 256 | Maximum number of pre-loaded samples kept in the prefetch cache. When the cache is full the background prefetch thread pauses until consumers free space. Set to 0 to disable prefetching. Only effective when `--use-streaming-dataset` is set and the dataset contains multimodal data | +| `--prefetch-num-workers` | int | 1 | Number of parallel worker threads inside the prefetch buffer for I/O-bound media decoding (video/image). Set to 1 to serialise all decoding (safest for FFmpeg which is not fully thread-safe). Higher values increase parallelism but may trigger EAGAIN errors on some platforms. Only effective when prefetching is enabled | | `--data-source-path` | str | `relax.engine.rollout.data_source.RolloutDataSourceWithBuffer` | Rollout data source class path | | `--start-rollout-id` | int | None | Starting Rollout step. If not set, attempts to read from checkpoint specified by `--load` | @@ -168,6 +171,16 @@ For more parameters, refer to SGLang official documentation. | Parameter | Type | Default | Description | |-----------|------|---------|-------------| | `--sglang-mem-fraction-static` | float | - | SGLang static memory allocation ratio | +| `--sglang-profile` | flag | False | Enable torch profiling on SGLang engines during rollout. Profile traces will be saved per rollout step | +| `--sglang-profile-steps` | int (list) | None | List of absolute rollout step IDs (0-indexed) at which to enable SGLang profiling. Takes precedence over `--sglang-profile-step-start/end`. Example: `--sglang-profile-steps 3 10 50` | +| `--sglang-profile-step-start` | int | None | Start of the rollout step range for SGLang profiling (**inclusive**, 0-indexed). Used with `--sglang-profile-step-end` to specify a contiguous range. Ignored if `--sglang-profile-steps` is set | +| `--sglang-profile-step-end` | int | None | End of the rollout step range for SGLang profiling (**inclusive**, 0-indexed). Used with `--sglang-profile-step-start` to specify a contiguous range. Ignored if `--sglang-profile-steps` is set. E.g. start=2, end=4 profiles steps 2, 3, 4 | +| `--sglang-profile-output-dir` | str | None | Output directory for SGLang profile traces. Defaults to `traces//sglang_trace` | +| `--sglang-profile-num-steps` | int | 3 | Number of SGLang forward steps to profile per rollout. -1 profiles the entire rollout step until `stop_profile` is called | +| `--sglang-profile-activities` | str (list) | ["CPU", "GPU"] | Activities to profile (e.g., `CPU GPU`) | +| `--sglang-profile-by-stage` | flag | False | Profile by stage (prefill/decode) separately | +| `--sglang-profile-with-stack` | flag | False | Record call stack in profile traces | +| `--sglang-profile-record-shapes` | flag | False | Record tensor shapes in profile traces | ### Custom Rollout Functions @@ -498,7 +511,9 @@ For autoscaler YAML configuration details, see [`relax/utils/autoscaler/autoscal --- -## Debug Parameters +## Debug & Profiling Parameters + +### Debug | Parameter | Type | Default | Description | |-----------|------|---------|-------------| @@ -510,14 +525,33 @@ For autoscaler YAML configuration details, see [`relax/utils/autoscaler/autoscal | `--save-debug-train-data` | str | None | Save training data. Path supports `{rollout_id}` placeholder | | `--dump-details` | str | None | Export all training details for post-hoc analysis | | `--check-weight-update-equal` | flag | False | Check if weight updates are equal | -| `--memory-snapshot-dir` | str | . | Memory snapshot directory | -| `--memory-snapshot-num-steps` | int | None | Memory snapshot steps | +| `--enable-cuda-memory-check` | flag | False | Enable memory check around low-level NCCL communication calls. Logs available GPU memory before each collective and attaches memory info to exceptions on failure | + +### Training Performance Profiling + +These parameters control the PyTorch Profiler for training steps. Trace files are saved to `traces//train_trace/` by default. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `--use-pytorch-profiler` | flag | False | Enable PyTorch's built-in profiler to record CUDA kernels, CPU ops, and communication during training (from Megatron) | +| `--profile-step-start` | int | 10 | Step offset at which to start profiling (**inclusive**, from Megatron). Counts from 0 since the current training launch, not absolute rollout ID; resets on checkpoint resumption | +| `--profile-step-end` | int | 12 | Step offset at which to stop profiling (**inclusive**, from Megatron). Same counting semantics as above. E.g. start=10, end=12 profiles steps 10, 11, 12 (3 steps) | | `--profile-target` | str (list) | train_overall | Profiling targets: `train_overall`, `train_actor`, `train_log_probs` | | `--profile-with-stack` | flag | False | Record stack information in profiler traces | | `--profile-with-memory` | flag | False | Record memory information in profiler traces | | `--profile-with-flops` | flag | False | Estimate FLOPs in profiler traces | -| `--memory-recorder` | str | torch | Memory recorder: `torch`, `memray` | -| `--enable-cuda-memory-check` | flag | False | Enable memory check around low-level NCCL communication calls. Logs available GPU memory before each collective and attaches memory info to exceptions on failure | + +### GPU Memory Profiling + +These parameters control GPU memory snapshot collection for diagnosing memory leaks and OOM issues. Snapshot files can be viewed with PyTorch Memory Viz tools (`torch.cuda.memory._viz`). + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `--record-memory-history` | flag | False | Enable CUDA memory allocation history recording (from Megatron). Records call stacks and tensor info for each allocation/deallocation, and auto-dumps a snapshot on OOM | +| `--memory-snapshot-path` | str | snapshot.pickle | Memory snapshot filename (from Megatron) | +| `--memory-snapshot-dir` | str | None | Memory snapshot output directory. Defaults to `traces//memory_snapshot` | +| `--memory-snapshot-num-steps` | int | None | Proactively dump a memory snapshot after the specified number of steps (0-indexed, i.e., setting 3 means dump after step 2) | +| `--memory-recorder` | str | torch | Memory recorder backend: `torch` (PyTorch built-in), `memray` (requires `pip install memray`) | ### Network diff --git a/docs/en/guide/customize-training.md b/docs/en/guide/customize-training.md index d3b995f0..f6a535a8 100644 --- a/docs/en/guide/customize-training.md +++ b/docs/en/guide/customize-training.md @@ -61,7 +61,7 @@ After adding the file, source the corresponding model configuration in your trai #### 2. Megatron Bridge Model Adaptation -Relax uses [Megatron Bridge](https://github.com/redai-infra/megatron-bridge) for automatic HF ↔ Megatron weight conversion. If your model is not yet supported by Megatron Bridge, you need to add support on the Megatron Bridge side first — see its project documentation for details. +Relax uses [Megatron Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) for automatic HF ↔ Megatron weight conversion. If your model is not yet supported by Megatron Bridge, you need to add support on the Megatron Bridge side first — see its project documentation for details. ::: tip AI-Assisted Integration This project provides a Codewiz skill `model-integration` (located at `.codewiz/skills/model-integration/`), covering the complete integration workflow for Bridge / Raw / FSDP backends, weight converter specifications, TP sharding logic, and common pitfalls. Invoke it in Codewiz via `invoke skill model-integration` for step-by-step guidance. diff --git a/docs/en/guide/installation.md b/docs/en/guide/installation.md index 9ce463da..7242c81a 100644 --- a/docs/en/guide/installation.md +++ b/docs/en/guide/installation.md @@ -125,10 +125,17 @@ export MEGATRON="your megatron path" export PYTHONPATH=your_megatron_path:$PYTHONPATH ``` -Additionally, Relax depends on megatron bridge for weight conversion. If you need weight conversion, install it: +Additionally, Relax depends on [Megatron Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) for weight conversion. Follow the install steps in `docker/Dockerfile`: merge the Bridge sources with the Megatron-LM submodule into a single directory and add it to `PYTHONPATH`: ```bash -pip install git+https://github.com/redai-infra/megatron-bridge.git@relax/dev --no-build-isolation --no-deps +export MEGATRON_BRIDGE_COMMIT=2faedbf6fe3c422835a44b2b360cadcb2a116a54 +git clone https://github.com/NVIDIA-NeMo/Megatron-Bridge.git +cd Megatron-Bridge && git checkout ${MEGATRON_BRIDGE_COMMIT} && \ + git submodule update --init --recursive && ./scripts/switch_mcore.sh dev +mkdir -p /your/path/Megatron-LM +cp -r src/megatron /your/path/Megatron-LM/ +rsync -avP 3rdparty/Megatron-LM/megatron/ /your/path/Megatron-LM/megatron/ +export PYTHONPATH=/your/path/Megatron-LM:$PYTHONPATH ``` ## Next Steps diff --git a/docs/en/guide/performance-tuning.md b/docs/en/guide/performance-tuning.md index 2657222f..d7b121f7 100644 --- a/docs/en/guide/performance-tuning.md +++ b/docs/en/guide/performance-tuning.md @@ -4,29 +4,90 @@ A practical guide to maximizing training throughput in Relax. All parameters men --- -## Profiling Training Performance +## Profiling -Before tuning, identify the bottleneck. Relax integrates PyTorch Profiler to generate TensorBoard-compatible traces. +Before tuning, identify the bottleneck. Relax provides three complementary profiling tools that cover **inference engine**, **training backend**, and **GPU memory**. All trace files are saved under `traces//` by default, separated by subdirectory: -### Enabling the Profiler +| Tool | Target | Default Output Directory | Viewer | +|---|---|---|---| +| SGLang Profiling | CUDA kernel / operator analysis for rollout inference | `traces//sglang_trace/` | TensorBoard or `https://ui.perfetto.dev/` | +| Training Profiling | Operator analysis for Actor training / log-probs computation | `traces//train_trace/` | TensorBoard or `https://ui.perfetto.dev/` | +| Memory Profiling | GPU memory allocation history, OOM diagnosis | `traces//memory_snapshot/` | [PyTorch Memory Viz](https://pytorch.org/memory_viz) | -The profiler is controlled by `--profile-step-start` and `--profile-step-end` (Megatron native parameters) together with `--profile-target`: +### Trace File Naming + +- **Training traces** include `rank{global}_dp{dp}_tp{tp}_pp{pp}` in filenames, e.g. `train_overall_rank0_dp0_tp0_pp0.1713780123.pt.trace.json.gz` +- **Memory snapshots** also include rank tags, e.g. `memory_snapshot_time1713780123_rank0_dp0_tp0_pp0_snapshot.pickle` +- **SGLang traces** use `engine{i}` prefix to distinguish engine instances, e.g. `engine0-1713780123-TP-0.trace.json.gz` + +### 1. SGLang Inference Profiling + +Runs `torch.profiler` on all SGLang engines during rollout via the `/start_profile` and `/stop_profile` HTTP APIs. Does not interfere with training-side profiling. + +**Example usage** — profile every rollout step: ```bash python3 relax/entrypoints/train.py \ - --profile-target train_overall \ + --sglang-profile \ + --tb-experiment-name my-experiment \ + # ... other args +``` + +**Selective step range** — only profile steps 2, 3, 4 (start/end are both inclusive; recommended to avoid excessive trace files): + +```bash +python3 relax/entrypoints/train.py \ + --sglang-profile \ + --sglang-profile-step-start 2 \ + --sglang-profile-step-end 4 \ + --tb-experiment-name my-experiment \ + # ... other args +``` + +You can also use `--sglang-profile-steps` to specify a non-contiguous list (takes precedence over start/end): + +```bash +--sglang-profile-steps 2 5 10 +``` + +All step parameters use **absolute rollout IDs** (0-indexed), i.e., step 0, step 1, ... regardless of `--start-rollout-id`. + +**Advanced parameters**: + +| Parameter | Default | Description | +|---|---|---| +| `--sglang-profile-step-start` | None | Start of the profiling rollout step range (**inclusive**, 0-indexed) | +| `--sglang-profile-step-end` | None | End of the profiling rollout step range (**inclusive**, 0-indexed). E.g. start=2, end=4 profiles steps 2, 3, 4 | +| `--sglang-profile-steps` | None | Non-contiguous step list; takes precedence over start/end | +| `--sglang-profile-num-steps` | 3 | Number of SGLang forward steps to profile per rollout. -1 profiles the entire rollout step | +| `--sglang-profile-activities` | CPU GPU | Activities to profile | +| `--sglang-profile-by-stage` | False | Profile prefill / decode stages separately | +| `--sglang-profile-with-stack` | False | Record Python call stacks | +| `--sglang-profile-record-shapes` | False | Record tensor shape information | +| `--sglang-profile-output-dir` | None | Custom output directory. Defaults to `traces//sglang_trace` | + +### 2. Training Profiling (PyTorch Profiler) + +Profiles Actor training steps using `torch.profiler`, producing TensorBoard-compatible trace files. + +**Example usage** — profile steps 2, 3, 4 (start/end are both inclusive): + +```bash +python3 relax/entrypoints/train.py \ + --use-pytorch-profiler \ --profile-step-start 2 \ --profile-step-end 4 \ - --use-tensorboard \ - --tb-project-name /path/to/tb_logs \ + --tb-experiment-name my-experiment \ # ... other args ``` -You can specify multiple targets: `--profile-target train_overall train_actor train_log_probs`. +::: tip +`--profile-step-start` and `--profile-step-end` are both **inclusive** and represent **step offsets** from the current training launch, not absolute rollout IDs. The counter resets on checkpoint resumption. E.g. start=2, end=4 profiles steps 2, 3, 4 (3 steps). -### Profiler Detail Flags +Same inclusive semantics as `--sglang-profile-step-start/end`. +::: -Three flags control what additional information the profiler records: +**Detail flags**: | Flag | Effect | |---|---| @@ -34,18 +95,18 @@ Three flags control what additional information the profiler records: | `--profile-with-memory` | Track CUDA memory allocations/deallocations in the trace. Helps find memory spikes | | `--profile-with-flops` | Estimate FLOPs for each operator. Useful for calculating hardware utilization (MFU) | -Example with all detail flags: +**Full example**: ```bash python3 relax/entrypoints/train.py \ + --use-pytorch-profiler \ --profile-target train_overall \ --profile-step-start 2 \ --profile-step-end 4 \ --profile-with-stack \ --profile-with-memory \ --profile-with-flops \ - --use-tensorboard \ - --tb-project-name /path/to/tb_logs \ + --tb-experiment-name my-experiment \ # ... other args ``` @@ -53,10 +114,76 @@ python3 relax/entrypoints/train.py \ Enabling `--profile-with-stack` and `--profile-with-memory` adds overhead. Use them for diagnostic runs, not for production training. ::: -View the trace in TensorBoard: +### 3. GPU Memory Profiling + +Records CUDA memory allocation/deallocation history for diagnosing memory leaks and OOM issues. Automatically dumps a memory snapshot on OOM. + +**Minimal usage** — enable recording and proactively dump after step 2: ```bash -tensorboard --logdir /path/to/tb_logs +python3 relax/entrypoints/train.py \ + --record-memory-history \ + --memory-snapshot-num-steps 2 \ + --tb-experiment-name my-experiment \ + # ... other args +``` + +**Advanced parameters**: + +| Parameter | Default | Description | +|---|---|---| +| `--memory-snapshot-path` | snapshot.pickle | Snapshot filename suffix | +| `--memory-snapshot-dir` | None | Custom output directory. Defaults to `traces//memory_snapshot` | +| `--memory-snapshot-num-steps` | None | Proactively dump a snapshot after the specified number of steps (0-indexed; setting 3 dumps after step 2) | +| `--memory-recorder` | torch | Backend: `torch` (PyTorch built-in) or `memray` (requires `pip install memray`) | + +View snapshots: visit [PyTorch Memory Viz](https://pytorch.org/memory_viz) and drag in the generated `.pickle` file. + +### Combined Usage + +In practice, all three profiling tools can be enabled simultaneously for a comprehensive view. Here is a complete combined example: + +```bash +python3 relax/entrypoints/train.py \ + # --- SGLang Inference Profiling --- + --sglang-profile \ + --sglang-profile-step-start 2 \ + --sglang-profile-step-end 4 \ + # --- Training Profiling --- + --use-pytorch-profiler \ + --profile-step-start 2 \ + --profile-step-end 4 \ + # --- Memory Profiling --- + --record-memory-history \ + --memory-snapshot-num-steps 2 \ + # --- Experiment name (determines trace output directory) --- + --tb-experiment-name my-profiling-run \ + # ... other training args +``` + +The above configuration produces the following directory structure: + +``` +traces/my-profiling-run/ +├── sglang_trace/ # SGLang engine traces (subdirectory per rollout step) +│ ├── rollout_2/ +│ │ ├── engine0-...-TP-0.trace.json.gz +│ │ ├── engine0-...-TP-1.trace.json.gz +│ │ ├── engine1-...-TP-0.trace.json.gz +│ │ └── ... +│ ├── rollout_3/ +│ │ └── ... +│ └── rollout_4/ +│ ├── engine0-...-TP-0.trace.json.gz +│ └── ... +├── train_trace/ # Training traces +│ ├── train_overall_rank0_dp0_tp0_pp0.....pt.trace.json.gz +│ ├── train_overall_rank1_dp0_tp1_pp0.....pt.trace.json.gz +│ └── ... +└── memory_snapshot/ # Memory snapshots + ├── memory_snapshot_time..._rank0_dp0_tp0_pp0_snapshot.pickle + ├── memory_snapshot_time..._rank1_dp0_tp1_pp0_snapshot.pickle + └── ... ``` --- diff --git a/docs/en/guide/quick-start.md b/docs/en/guide/quick-start.md index 74ace281..d30478b3 100644 --- a/docs/en/guide/quick-start.md +++ b/docs/en/guide/quick-start.md @@ -126,6 +126,11 @@ python scripts/tools/process_avqa.py \ --input-dir /root/AVQA-R1-6K/AVQA_R1/train/omni_rl_format_train.json \ --output-dir /root/AVQA-R1-6K/AVQA_R1/train/omni_rl_format_train_convert.jsonl \ --md-dir /root/AVQA-R1-6K/AVQA_R1/train + +python scripts/tools/process_avqa.py \ + --input-dir /root/AVQA-R1-6K/AVQA_R1/valid/omni_rl_format_valid.json \ + --output-dir /root/AVQA-R1-6K/AVQA_R1/valid/small_valid.jsonl \ + --md-dir /root/AVQA-R1-6K/AVQA_R1/valid ``` The conversion script reads the raw JSON, extracts problem, options, image, and audio fields, and produces a `.jsonl` file with `prompt`, `image`, `audio`, and `label` columns. @@ -134,6 +139,10 @@ The conversion script reads the raw JSON, extracts problem, options, image, and ```bash hf download Qwen/Qwen3-Omni-30B-A3B-Instruct --local-dir /root/Qwen3-Omni-30B-A3B-Instruct + +# Qwen3-Omni ships its chat_template in a standalone chat_template.json that +# AutoTokenizer does not auto-load. Merge it into tokenizer_config.json (skipped if already present). +python -c "import json,sys; m=sys.argv[1]; p=f'{m}/tokenizer_config.json'; tc=json.load(open(p)); ('chat_template' in tc) or (tc.update(chat_template=json.load(open(f'{m}/chat_template.json'))['chat_template']) or json.dump(tc, open(p,'w'), indent=2, ensure_ascii=False))" /root/Qwen3-Omni-30B-A3B-Instruct ``` ### Launch Training @@ -184,6 +193,10 @@ The conversion script reads the original JSON file, extracts the question, optio ```bash hf download Qwen/Qwen3-Omni-30B-A3B-Instruct --local-dir /root/Qwen3-Omni-30B-A3B-Instruct + +# Qwen3-Omni ships its chat_template in a standalone chat_template.json that +# AutoTokenizer does not auto-load. Merge it into tokenizer_config.json (skipped if already present). +python -c "import json,sys; m=sys.argv[1]; p=f'{m}/tokenizer_config.json'; tc=json.load(open(p)); ('chat_template' in tc) or (tc.update(chat_template=json.load(open(f'{m}/chat_template.json'))['chat_template']) or json.dump(tc, open(p,'w'), indent=2, ensure_ascii=False))" /root/Qwen3-Omni-30B-A3B-Instruct ``` ### Launch Training diff --git a/docs/zh/examples/deepeyes.md b/docs/zh/examples/deepeyes.md index 950b151b..959799fe 100644 --- a/docs/zh/examples/deepeyes.md +++ b/docs/zh/examples/deepeyes.md @@ -49,25 +49,12 @@ HF Image dict 格式(`{"bytes": ...}`)被 Relax 的图像加载管线原生 ### 下载模型 ```bash -hf download Qwen/Qwen3-VL-4B-Instruct \ - --local-dir /root/Qwen3-VL-4B-Instruct +hf download Qwen/Qwen3-VL-30B-A3B-Thinking \ + --local-dir /root/Qwen3-VL-30B-A3B-Thinking ``` -完整配置使用 `Qwen/Qwen3-VL-30B-A3B-Thinking`。 - ## 快速开始 -### 4B 模型(8 GPU) - -```bash -export MODEL_DIR=/root -export DATA_DIR=/root -export SAVE_DIR=/root/save - -cd /root/Relax -bash examples/deepeyes/run_deepeyes_4b.sh -``` - ### 30B-A3B 模型(8 GPU,MoE) 完整配置需要 judge 模型进行奖励评分: @@ -91,7 +78,7 @@ bash examples/deepeyes/run_deepeyes.sh ```bash WORKING_DIR="./" RAY_ADDRESS=:6379 \ MODEL_DIR=/root DATA_DIR=/root SAVE_DIR=/root/save \ - bash -x scripts/entrypoint/ray-job.sh examples/deepeyes/run_deepeyes_4b.sh + bash -x scripts/entrypoint/ray-job.sh examples/deepeyes/run_deepeyes.sh ``` ## 架构 @@ -101,7 +88,6 @@ WORKING_DIR="./" RAY_ADDRESS=:6379 \ ``` examples/deepeyes/ ├── run_deepeyes.sh # 启动脚本(Qwen3-VL-30B-A3B,完整配置) -├── run_deepeyes_4b.sh # 启动脚本(Qwen3-VL-4B,轻量配置) ├── deepeyes_config.yaml # 任务配置(max_turns、环境路径) ├── rollout.py # 多轮 rollout 逻辑 ├── env_deepeyes.py # DeepEyes 工具使用环境 diff --git a/docs/zh/guide/configuration.md b/docs/zh/guide/configuration.md index bc38d308..a88fa6c1 100644 --- a/docs/zh/guide/configuration.md +++ b/docs/zh/guide/configuration.md @@ -88,6 +88,9 @@ | `--rollout-seed` | int | 42 | Rollout 的随机种子,用于打乱 Prompt 和随机采样 | | `--use-streaming-dataset` | flag | False | 使用流式数据集以节省内存 | | `--streaming-buffer-size` | int | 10000 | 流式数据集的缓冲区大小 | +| `--prefetch-chunk-size` | int | 32 | 每轮预取时分派到线程池的样本数。较大的值可以提高吞吐量但也会增加内存压力。仅在设置了 `--use-streaming-dataset` 且数据集包含多模态数据时生效 | +| `--prefetch-max-cached` | int | 256 | 预取缓存中保留的最大预加载样本数。缓存满时后台预取线程会暂停,直到消费者释放空间。设为 0 可禁用预取。仅在设置了 `--use-streaming-dataset` 且数据集包含多模态数据时生效 | +| `--prefetch-num-workers` | int | 1 | 预取缓冲区中用于 I/O 密集型媒体解码(视频/图像)的并行工作线程数。设为 1 可序列化所有解码操作(对 FFmpeg 非线程安全问题最安全)。较高值可提高并行度,但在某些平台上可能触发 EAGAIN 错误。仅在启用预取时生效 | | `--data-source-path` | str | `relax.engine.rollout.data_source.RolloutDataSourceWithBuffer` | Rollout 数据源类路径 | | `--start-rollout-id` | int | None | 起始 Rollout 步数。未设置时会尝试从 `--load` 的检查点中读取 | @@ -168,6 +171,16 @@ | 参数 | 类型 | 默认值 | 说明 | |------|------|--------|------| | `--sglang-mem-fraction-static` | float | - | SGLang 静态内存分配比例 | +| `--sglang-profile` | flag | False | 启用 SGLang 引擎的 torch profiling。在 Rollout 推理期间触发,每步保存 profile trace | +| `--sglang-profile-steps` | int (列表) | None | 指定要进行 SGLang profiling 的绝对 rollout step ID(0-indexed)列表。优先级高于 `--sglang-profile-step-start/end`。例如 `--sglang-profile-steps 3 10 50` | +| `--sglang-profile-step-start` | int | None | SGLang profiling 的起始 rollout step(**inclusive**,0-indexed)。与 `--sglang-profile-step-end` 配合指定连续范围。设置了 `--sglang-profile-steps` 时被忽略 | +| `--sglang-profile-step-end` | int | None | SGLang profiling 的结束 rollout step(**inclusive**,0-indexed)。与 `--sglang-profile-step-start` 配合指定连续范围。设置了 `--sglang-profile-steps` 时被忽略。例如 start=2, end=4 会采集 step 2, 3, 4 | +| `--sglang-profile-output-dir` | str | None | SGLang profile trace 的输出目录。默认使用 `traces//sglang_trace` | +| `--sglang-profile-num-steps` | int | 3 | 每轮 Rollout 中要 profile 的 SGLang 前向步数。-1 表示 profile 整个 Rollout 步,直到调用 `stop_profile` | +| `--sglang-profile-activities` | str (列表) | ["CPU", "GPU"] | 要 profile 的活动类型(例如 `CPU GPU`) | +| `--sglang-profile-by-stage` | flag | False | 按阶段(prefill/decode)分别进行 profile | +| `--sglang-profile-with-stack` | flag | False | 在 profile trace 中记录调用栈 | +| `--sglang-profile-record-shapes` | flag | False | 在 profile trace 中记录张量形状 | ### 自定义 Rollout 函数 @@ -498,7 +511,9 @@ Autoscaler YAML 配置详情请参见 [`relax/utils/autoscaler/autoscaler.yaml`] --- -## 调试参数 +## 调试与性能分析参数 + +### 调试 | 参数 | 类型 | 默认值 | 说明 | |------|------|--------|------| @@ -510,14 +525,33 @@ Autoscaler YAML 配置详情请参见 [`relax/utils/autoscaler/autoscaler.yaml`] | `--save-debug-train-data` | str | None | 保存训练数据,路径支持 `{rollout_id}` 占位符 | | `--dump-details` | str | None | 导出所有训练细节用于事后分析 | | `--check-weight-update-equal` | flag | False | 检查权重更新是否相等 | -| `--memory-snapshot-dir` | str | . | 内存快照目录 | -| `--memory-snapshot-num-steps` | int | None | 内存快照步数 | +| `--enable-cuda-memory-check` | flag | False | 在底层 NCCL 通信调用周围启用内存检查。在每次集合通信前记录可用 GPU 显存,通信失败时将内存信息附加到异常中 | + +### 训练性能 Profiling + +以下参数控制训练过程的 PyTorch Profiler 采集。Trace 文件默认保存到 `traces//train_trace/` 目录下。 + +| 参数 | 类型 | 默认值 | 说明 | +|------|------|--------|------| +| `--use-pytorch-profiler` | flag | False | 启用 PyTorch 内置 profiler 记录训练步骤的 CUDA kernel、CPU op 和通信操作(来自 Megatron) | +| `--profile-step-start` | int | 10 | 开始 profiling 的步数偏移(**inclusive**,来自 Megatron)。指从本次训练启动后的第 N 步开始采集,非绝对 rollout ID;断点续训时计数从 0 重新开始 | +| `--profile-step-end` | int | 12 | 停止 profiling 的步数偏移(**inclusive**,来自 Megatron)。含义同上。例如 start=10, end=12 会采集 step 10, 11, 12(共 3 步) | | `--profile-target` | str (列表) | train_overall | 性能分析目标:`train_overall`、`train_actor`、`train_log_probs` | | `--profile-with-stack` | flag | False | 在 profiler trace 中记录调用栈信息 | | `--profile-with-memory` | flag | False | 在 profiler trace 中记录内存信息 | | `--profile-with-flops` | flag | False | 在 profiler trace 中估算 FLOPs | -| `--memory-recorder` | str | torch | 内存记录器:`torch`、`memray` | -| `--enable-cuda-memory-check` | flag | False | 在底层 NCCL 通信调用周围启用内存检查。在每次集合通信前记录可用 GPU 显存,通信失败时将内存信息附加到异常中 | + +### GPU 内存 Profiling + +以下参数控制 GPU 内存快照采集,用于诊断显存泄漏和 OOM 问题。Snapshot 文件可用 PyTorch Memory Viz 工具(`torch.cuda.memory._viz`)查看。 + +| 参数 | 类型 | 默认值 | 说明 | +|------|------|--------|------| +| `--record-memory-history` | flag | False | 启用 CUDA 内存分配历史记录(来自 Megatron)。开启后会记录每次分配/释放的调用栈和张量信息,并在发生 OOM 时自动 dump snapshot | +| `--memory-snapshot-path` | str | snapshot.pickle | 内存快照文件名(来自 Megatron) | +| `--memory-snapshot-dir` | str | None | 内存快照保存目录。默认使用 `traces//memory_snapshot` | +| `--memory-snapshot-num-steps` | int | None | 在指定步数后主动 dump 内存快照(0-indexed,即设为 3 表示在第 2 步后 dump) | +| `--memory-recorder` | str | torch | 内存记录器后端:`torch`(PyTorch 内置)、`memray`(需要 `pip install memray`) | ### 网络 diff --git a/docs/zh/guide/customize-training.md b/docs/zh/guide/customize-training.md index 5cd75deb..bbde0fae 100644 --- a/docs/zh/guide/customize-training.md +++ b/docs/zh/guide/customize-training.md @@ -61,7 +61,7 @@ MODEL_ARGS=( #### 2. Megatron Bridge 模型适配 -Relax 通过 [Megatron Bridge](https://github.com/redai-infra/megatron-bridge) 实现 HF ↔ Megatron 的自动权重转换。若您的模型尚未被 Megatron Bridge 支持,需要先在 Megatron Bridge 侧完成适配,详见其项目文档。 +Relax 通过 [Megatron Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) 实现 HF ↔ Megatron 的自动权重转换。若您的模型尚未被 Megatron Bridge 支持,需要先在 Megatron Bridge 侧完成适配,详见其项目文档。 ::: tip AI 辅助接入 本项目提供了 Codewiz skill `model-integration`(位于 `.codewiz/skills/model-integration/`),涵盖 Bridge / Raw / FSDP 三种后端的完整接入流程、权重转换器编写规范、TP 分片逻辑及常见陷阱,可在 Codewiz 中通过 `invoke skill model-integration` 调用以获得逐步指导。 diff --git a/docs/zh/guide/installation.md b/docs/zh/guide/installation.md index c697222c..2d29ddb2 100644 --- a/docs/zh/guide/installation.md +++ b/docs/zh/guide/installation.md @@ -87,10 +87,17 @@ export MEGATRON="your megatron path" export PYTHONPATH=your_megatron_path:$PYTHONPATH ``` -此外 Relax 依赖 megatron bridge 进行权重转换,若需要转换请安装: +此外 Relax 依赖 [Megatron Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) 进行权重转换。安装方式参考 `docker/Dockerfile`,将 Bridge 源码与 Megatron-LM submodule 合并到同一目录后加入 `PYTHONPATH`: ```bash -pip install git+https://github.com/redai-infra/megatron-bridge.git@relax/dev --no-build-isolation --no-deps +export MEGATRON_BRIDGE_COMMIT=2faedbf6fe3c422835a44b2b360cadcb2a116a54 +git clone https://github.com/NVIDIA-NeMo/Megatron-Bridge.git +cd Megatron-Bridge && git checkout ${MEGATRON_BRIDGE_COMMIT} && \ + git submodule update --init --recursive && ./scripts/switch_mcore.sh dev +mkdir -p /your/path/Megatron-LM +cp -r src/megatron /your/path/Megatron-LM/ +rsync -avP 3rdparty/Megatron-LM/megatron/ /your/path/Megatron-LM/megatron/ +export PYTHONPATH=/your/path/Megatron-LM:$PYTHONPATH ``` ## 下一步 diff --git a/docs/zh/guide/performance-tuning.md b/docs/zh/guide/performance-tuning.md index f2c2b08e..a7b5d9f9 100644 --- a/docs/zh/guide/performance-tuning.md +++ b/docs/zh/guide/performance-tuning.md @@ -6,27 +6,88 @@ Relax 训练吞吐量最大化实践指南。本文提到的所有参数均可 ## 性能分析 -调优前先定位瓶颈。Relax 集成了 PyTorch Profiler,可生成兼容 TensorBoard 的 trace 文件。 +调优前先定位瓶颈。Relax 内置三套互补的 profiling 工具,覆盖 **推理引擎**、**训练后端** 和 **GPU 内存** 三个维度。所有 trace 文件默认保存在 `traces//` 目录下,按子目录区分: -### 启用 Profiler +| 工具 | 目标 | 默认输出目录 | 查看方式 | +|---|---|---|---| +| SGLang Profiling | Rollout 推理的 CUDA kernel / 算子分析 | `traces//sglang_trace/` | TensorBoard or `https://ui.perfetto.dev/` | +| Training Profiling | Actor 训练 / log-probs 计算的算子分析 | `traces//train_trace/` | TensorBoard or `https://ui.perfetto.dev/` | +| Memory Profiling | GPU 显存分配历史,OOM 诊断 | `traces//memory_snapshot/` | [PyTorch Memory Viz](https://pytorch.org/memory_viz) | -Profiler 通过 `--profile-step-start` 和 `--profile-step-end`(Megatron 原生参数)配合 `--profile-target` 控制: +### Trace 文件命名规则 + +- **训练 trace** 文件名包含 `rank{global}_dp{dp}_tp{tp}_pp{pp}` 标识,例如 `train_overall_rank0_dp0_tp0_pp0.1713780123.pt.trace.json.gz` +- **内存快照** 文件名同样包含 rank 标识,例如 `memory_snapshot_time1713780123_rank0_dp0_tp0_pp0_snapshot.pickle` +- **SGLang trace** 文件以 `engine{i}` 为前缀区分不同引擎实例,例如 `engine0-1713780123-TP-0.trace.json.gz` + +### 1. SGLang 推理 Profiling + +对 Rollout 阶段所有 SGLang 引擎进行 `torch.profiler` 采集。通过 HTTP API `/start_profile` 和 `/stop_profile` 触发,不影响训练侧的 profiler。 + +**示例用法** — 每个 rollout step 都采集: ```bash python3 relax/entrypoints/train.py \ - --profile-target train_overall \ + --sglang-profile \ + --tb-experiment-name my-experiment \ + # ... 其他参数 +``` + +**指定 rollout step 范围** — 仅在 step 2、3、4 采集(start/end 均 inclusive,推荐用法,避免大量 trace 文件): + +```bash +python3 relax/entrypoints/train.py \ + --sglang-profile \ + --sglang-profile-step-start 2 \ + --sglang-profile-step-end 4 \ + --tb-experiment-name my-experiment \ + # ... 其他参数 +``` + +也可以用 `--sglang-profile-steps` 指定不连续的 step 列表(优先级高于 start/end): + +```bash +--sglang-profile-steps 2 5 10 +``` + +所有 step 参数均使用 **绝对 rollout ID**(0-indexed),即第 0 轮、第 1 轮 ... 与 `--start-rollout-id` 无关。 + +**进阶参数**: + +| 参数 | 默认值 | 说明 | +|---|---|---| +| `--sglang-profile-step-start` | None | profiling 起始 rollout step(**inclusive**,0-indexed) | +| `--sglang-profile-step-end` | None | profiling 结束 rollout step(**inclusive**,0-indexed)。例如 start=2, end=4 采集 step 2, 3, 4 | +| `--sglang-profile-steps` | None | 不连续 step 列表,优先级高于 start/end | +| `--sglang-profile-num-steps` | 3 | 每轮 Rollout 中采集的 SGLang 前向步数。-1 表示整轮采集 | +| `--sglang-profile-activities` | CPU GPU | 要采集的活动类型 | +| `--sglang-profile-by-stage` | False | 按 prefill / decode 阶段分别采集 | +| `--sglang-profile-with-stack` | False | 记录 Python 调用栈 | +| `--sglang-profile-record-shapes` | False | 记录张量形状信息 | +| `--sglang-profile-output-dir` | None | 自定义输出目录。默认 `traces//sglang_trace` | + +### 2. 训练 Profiling(PyTorch Profiler) + +对 Actor 训练步骤进行 `torch.profiler` 采集,生成兼容 TensorBoard 的 trace 文件。 + +**示例用法** — 采集第 2、3、4 步(start/end 均 inclusive): + +```bash +python3 relax/entrypoints/train.py \ + --use-pytorch-profiler \ --profile-step-start 2 \ --profile-step-end 4 \ - --use-tensorboard \ - --tb-project-name /path/to/tb_logs \ + --tb-experiment-name my-experiment \ # ... 其他参数 ``` -可以指定多个分析目标:`--profile-target train_overall train_actor train_log_probs`。 +::: tip +`--profile-step-start` 和 `--profile-step-end` 均为 **inclusive**,是从本次训练启动后的 **步数偏移**,不是绝对 rollout ID。断点续训时计数从 0 重新开始。例如 start=2, end=4 采集 step 2, 3, 4(共 3 步)。 -### Profiler 详细信息标志 +语义与 `--sglang-profile-step-start/end` 相同(两端均 inclusive)。 +::: -以下三个标志控制 Profiler 记录的额外信息: +**详细信息标志**: | 标志 | 作用 | |---|---| @@ -34,18 +95,18 @@ python3 relax/entrypoints/train.py \ | `--profile-with-memory` | 在 trace 中跟踪 CUDA 显存分配/释放。用于发现显存尖峰 | | `--profile-with-flops` | 估算每个算子的 FLOPs。用于计算硬件利用率 (MFU) | -启用全部详细信息标志的示例: +**完整示例**: ```bash python3 relax/entrypoints/train.py \ + --use-pytorch-profiler \ --profile-target train_overall \ --profile-step-start 2 \ --profile-step-end 4 \ --profile-with-stack \ --profile-with-memory \ --profile-with-flops \ - --use-tensorboard \ - --tb-project-name /path/to/tb_logs \ + --tb-experiment-name my-experiment \ # ... 其他参数 ``` @@ -53,10 +114,76 @@ python3 relax/entrypoints/train.py \ 启用 `--profile-with-stack` 和 `--profile-with-memory` 会增加额外开销。建议仅在诊断时使用,不用于生产训练。 ::: -使用 TensorBoard 查看 trace: +### 3. GPU 内存 Profiling + +记录 CUDA 显存分配/释放历史,用于诊断显存泄漏和 OOM 问题。在发生 OOM 时会自动 dump 内存快照。 + +**最小用法** — 开启记录 + 在第 2 步后主动 dump: ```bash -tensorboard --logdir /path/to/tb_logs +python3 relax/entrypoints/train.py \ + --record-memory-history \ + --memory-snapshot-num-steps 2 \ + --tb-experiment-name my-experiment \ + # ... 其他参数 +``` + +**进阶参数**: + +| 参数 | 默认值 | 说明 | +|---|---|---| +| `--memory-snapshot-path` | snapshot.pickle | 快照文件名后缀 | +| `--memory-snapshot-dir` | None | 自定义输出目录。默认 `traces//memory_snapshot` | +| `--memory-snapshot-num-steps` | None | 在指定步数后主动 dump 快照(0-indexed,设 3 表示第 2 步后 dump) | +| `--memory-recorder` | torch | 后端选择:`torch`(PyTorch 内置)、`memray`(需 `pip install memray`) | + +查看快照:访问 [PyTorch Memory Viz](https://pytorch.org/memory_viz),拖入生成的 `.pickle` 文件。 + +### 三种 Profiling 联合使用 + +实际诊断中,可同时开启三种 profiling 以获得全面视图。以下是一个完整的联合使用示例: + +```bash +python3 relax/entrypoints/train.py \ + # --- SGLang 推理 Profiling --- + --sglang-profile \ + --sglang-profile-step-start 2 \ + --sglang-profile-step-end 4 \ + # --- 训练 Profiling --- + --use-pytorch-profiler \ + --profile-step-start 2 \ + --profile-step-end 4 \ + # --- 内存 Profiling --- + --record-memory-history \ + --memory-snapshot-num-steps 2 \ + # --- 实验名(决定 trace 输出目录)--- + --tb-experiment-name my-profiling-run \ + # ... 其他训练参数 +``` + +上述配置会产出如下目录结构: + +``` +traces/my-profiling-run/ +├── sglang_trace/ # SGLang 引擎 trace(按 rollout step 分目录) +│ ├── rollout_2/ +│ │ ├── engine0-...-TP-0.trace.json.gz +│ │ ├── engine0-...-TP-1.trace.json.gz +│ │ ├── engine1-...-TP-0.trace.json.gz +│ │ └── ... +│ ├── rollout_3/ +│ │ └── ... +│ └── rollout_4/ +│ ├── engine0-...-TP-0.trace.json.gz +│ └── ... +├── train_trace/ # 训练 trace +│ ├── train_overall_rank0_dp0_tp0_pp0.....pt.trace.json.gz +│ ├── train_overall_rank1_dp0_tp1_pp0.....pt.trace.json.gz +│ └── ... +└── memory_snapshot/ # 内存快照 + ├── memory_snapshot_time..._rank0_dp0_tp0_pp0_snapshot.pickle + ├── memory_snapshot_time..._rank1_dp0_tp1_pp0_snapshot.pickle + └── ... ``` --- diff --git a/docs/zh/guide/quick-start.md b/docs/zh/guide/quick-start.md index a5acab1d..571fa413 100644 --- a/docs/zh/guide/quick-start.md +++ b/docs/zh/guide/quick-start.md @@ -126,6 +126,11 @@ python scripts/tools/process_avqa.py \ --input-dir /root/AVQA-R1-6K/AVQA_R1/train/omni_rl_format_train.json \ --output-dir /root/AVQA-R1-6K/AVQA_R1/train/omni_rl_format_train_convert.jsonl \ --md-dir /root/AVQA-R1-6K/AVQA_R1/train + +python scripts/tools/process_avqa.py \ + --input-dir /root/AVQA-R1-6K/AVQA_R1/valid/omni_rl_format_valid.json \ + --output-dir /root/AVQA-R1-6K/AVQA_R1/valid/small_valid.jsonl \ + --md-dir /root/AVQA-R1-6K/AVQA_R1/valid ``` 转换脚本读取原始 JSON 文件,提取问题、选项、图片和音频字段,生成包含 `prompt`、`image`、`audio` 和 `label` 列的 `.jsonl` 文件。 @@ -134,6 +139,10 @@ python scripts/tools/process_avqa.py \ ```bash hf download Qwen/Qwen3-Omni-30B-A3B-Instruct --local-dir /root/Qwen3-Omni-30B-A3B-Instruct + +# Qwen3-Omni 的 chat_template 单独存放在 chat_template.json 中, +# AutoTokenizer 不会自动加载,需要合并到 tokenizer_config.json(已存在则跳过) +python -c "import json,sys; m=sys.argv[1]; p=f'{m}/tokenizer_config.json'; tc=json.load(open(p)); ('chat_template' in tc) or (tc.update(chat_template=json.load(open(f'{m}/chat_template.json'))['chat_template']) or json.dump(tc, open(p,'w'), indent=2, ensure_ascii=False))" /root/Qwen3-Omni-30B-A3B-Instruct ``` ### 启动训练 @@ -184,6 +193,10 @@ python scripts/tools/process_nextqa.py \ ```bash hf download Qwen/Qwen3-Omni-30B-A3B-Instruct --local-dir /root/Qwen3-Omni-30B-A3B-Instruct + +# Qwen3-Omni 的 chat_template 单独存放在 chat_template.json 中, +# AutoTokenizer 不会自动加载,需要合并到 tokenizer_config.json(已存在则跳过) +python -c "import json,sys; m=sys.argv[1]; p=f'{m}/tokenizer_config.json'; tc=json.load(open(p)); ('chat_template' in tc) or (tc.update(chat_template=json.load(open(f'{m}/chat_template.json'))['chat_template']) or json.dump(tc, open(p,'w'), indent=2, ensure_ascii=False))" /root/Qwen3-Omni-30B-A3B-Instruct ``` ### 启动训练 diff --git a/examples/deepeyes/env_deepeyes.py b/examples/deepeyes/env_deepeyes.py index dcf9fddf..07763d4a 100644 --- a/examples/deepeyes/env_deepeyes.py +++ b/examples/deepeyes/env_deepeyes.py @@ -27,12 +27,16 @@ class DeepeyesEnv(BaseInteractionEnv): MIN_DIMENSION = 28 - def __init__(self, *, max_turns: int | None = None, image=None): + def __init__(self, *, max_turns: int | None = None, image=None, normalize_bbox: bool = True): self.max_turns = max_turns self.turn = 0 self.tool_calls: list[dict[str, Any]] = [] self.current_image = image self.origin_image = image + # Whether to convert bbox coordinates from normalized [0, 1000] to absolute pixels. + # Qwen-VL / Qwen2-VL / Qwen3-VL output 0-1000 normalized coords → set True (default). + # Qwen2.5-VL outputs absolute pixel coords → set False. + self.normalize_bbox = normalize_bbox def reset(self): self.turn = 0 @@ -119,13 +123,21 @@ def _maybe_resize_bbox(self, bbox_2d: list[float]) -> Optional[list[float]]: image_height = self.current_image.height left, top, right, bottom = bbox_2d - # 1. Clamp the initial bounding box to the image dimensions. + # 1. Convert normalized [0, 1000] coordinates to absolute pixel coordinates. + # Qwen-VL / Qwen2-VL / Qwen3-VL use 0-1000 normalized coords; Qwen2.5-VL uses absolute pixels. + if self.normalize_bbox: + left = left / 1000.0 * image_width + top = top / 1000.0 * image_height + right = right / 1000.0 * image_width + bottom = bottom / 1000.0 * image_height + + # 2. Clamp the bounding box to the image dimensions. left = max(0.0, float(left)) top = max(0.0, float(top)) right = min(float(image_width), float(right)) bottom = min(float(image_height), float(bottom)) - # 2. If clamped bbox is invalid, return immediately. + # 3. If clamped bbox is invalid, return immediately. if not self._validate_bbox(left, top, right, bottom): return None @@ -133,7 +145,7 @@ def _maybe_resize_bbox(self, bbox_2d: list[float]) -> Optional[list[float]]: height = bottom - top width = right - left - # 3. If the box is too small, attempt to resize it. + # 4. If the box is too small, attempt to resize it. if height < self.MIN_DIMENSION or width < self.MIN_DIMENSION: logger.info(f"Bbox {width}x{height} is smaller than {self.MIN_DIMENSION}, attempting resize.") center_x = (left + right) / 2.0 @@ -182,7 +194,7 @@ def _maybe_resize_bbox(self, bbox_2d: list[float]) -> Optional[list[float]]: # Use floor and ceil for final integer coordinates. current_bbox = [floor(new_left), floor(new_top), ceil(new_right), ceil(new_bottom)] - # 4. Final validation on the resulting bounding box (either original or resized). + # 5. Final validation on the resulting bounding box (either original or resized). final_left, final_top, final_right, final_bottom = current_bbox if not self._validate_bbox(final_left, final_top, final_right, final_bottom): logger.warning(f"Final bbox is invalid after processing: {current_bbox}") @@ -288,7 +300,8 @@ def build_env(sample: Sample | None = None, args: Any | None = None, **_: Any) - max_turns = args.max_turns if max_turns is None: raise ValueError("max_turns must be set via --custom-config-path in the custom config file.") + normalize_bbox = getattr(args, "normalize_bbox", True) image = _extract_initial_image(sample) if image is None: logger.warning("No image found in sample.multimodal_inputs or metadata.") - return DeepeyesEnv(max_turns=max_turns, image=image) + return DeepeyesEnv(max_turns=max_turns, image=image, normalize_bbox=normalize_bbox) diff --git a/examples/deepeyes/qwen_vl.py b/examples/deepeyes/qwen_vl.py index 449e94d3..59e0abff 100644 --- a/examples/deepeyes/qwen_vl.py +++ b/examples/deepeyes/qwen_vl.py @@ -279,6 +279,53 @@ def get_mm_data(self, prompt, embeddings, img_grid_thw): "mrope_position_delta": mrope_position_delta, } + @staticmethod + def _strip_image_token(input_ids, image_token_id: int = 151655): + """Collapse consecutive ``<|image_pad|>`` tokens into a single + placeholder. + + Transform:: + + <|vision_start|><|image_pad|><|image_pad|>...<|image_pad|><|vision_end|> + -> <|vision_start|><|image_pad|><|vision_end|> + + Why: the caller may pass pre-tokenized ``input_ids`` in which each image + has already been expanded into N ``<|image_pad|>`` tokens (one per visual + patch). However ``load_mm_data`` downstream expects exactly *one* + ``<|image_pad|>`` placeholder per image and re-expands it itself based on + the actual patch count. Without this collapse the pipeline would see + ``N x M`` image-pad tokens and miscount positions, breaking mrope + bookkeeping. Raw text prompts are returned unchanged. + + Args: + input_ids: List of token ids, or any non-list value (passed through). + image_token_id: Id of ``<|image_pad|>`` (151655 for Qwen-VL family). + + Returns: + ``input_ids`` with each run of consecutive ``image_token_id`` reduced + to a single occurrence, or the input untouched if it is not a list. + """ + # Raw text prompts (str) and other non-list inputs require no rewrite. + if not isinstance(input_ids, list): + return input_ids + + import numpy as np + + input_id_arr = np.array(input_ids) + + # mask[i] == True means "keep token at index i"; start by keeping all. + mask = np.ones(len(input_id_arr), dtype=bool) + + # Boolean array marking every <|image_pad|> position. + is_value = input_id_arr == image_token_id + + # A token at index i is a redundant duplicate iff both it and its left + # neighbour are <|image_pad|>. Dropping those keeps the first occurrence + # in each run and removes the rest. Index 0 has no left neighbour so it + # is always kept (mask[0] stays True). + mask[1:] &= ~(is_value[1:] & is_value[:-1]) + return input_id_arr[mask].tolist() + async def process_mm_data_async( self, image_data: List[Union[str, bytes]], @@ -299,7 +346,7 @@ async def process_mm_data_async( original_input_ids = input_text base_output = self.load_mm_data( - prompt=input_text, + prompt=self._strip_image_token(input_text), image_data=image_data, video_data=request_obj.video_data, audio_data=request_obj.audio_data, diff --git a/examples/deepeyes/reward_deepeyes.py b/examples/deepeyes/reward_deepeyes.py index 4c431741..47a36488 100644 --- a/examples/deepeyes/reward_deepeyes.py +++ b/examples/deepeyes/reward_deepeyes.py @@ -24,49 +24,49 @@ def get_gpt4_score_ICE(): [Question]: Is the countertop tan or blue? [Standard Answer]: The countertop is tan. [Model_answer] : tan -Judgement: 1 +Judgment: 1 """ # noqa example_2 = """ [Question]: On which side of the picture is the barrier? [Standard Answer]: The barrier is on the left side of the picture. [Model_answer] : left -Judgement: 1 +Judgment: 1 """ # noqa example_3 = """ [Question]: Is the kite brown and large? [Standard Answer]: Yes, the kite is brown and large. [Model_answer] : Yes -Judgement: 1 +Judgment: 1 """ # noqa example_4 = """ [Question]: Are the spots on a giraffe? [Standard Answer]: No, the spots are on a banana. [Model_answer] : no -Judgement: 1 +Judgment: 1 """ # noqa example_5 = """ [Question]: Who is wearing pants? [Standard Answer]: The boy is wearing pants. [Model_answer] : The person in the picture is wearing pants. -Judgement: 1 +Judgment: 1 """ # noqa example_6 = """ [Question]: Is the man phone both blue and closed? [Standard Answer]: Yes, the man phone is both blue and closed. [Model_answer] : No. -Judgement: 0 +Judgment: 0 """ # noqa example_7 = """ [Question]: What color is the towel in the center of the picture? [Standard Answer]: The towel in the center of the picture is blue. [Model_answer] : The towel in the center of the picture is pink. -Judgement: 0 +Judgment: 0 """ # noqa return [example_1, example_2, example_3, example_4, example_5, example_6, example_7] @@ -76,7 +76,7 @@ def get_chat_template(): chat_template = """ Below are two answers to a question. Question is [Question], [Standard Answer] is the standard answer to the question, and [Model_answer] is the answer extracted from a model's output to this question. Determine whether these two answers are consistent. Note that [Model Answer] is consistent with [Standard Answer] whenever they are essentially the same. If the meaning is expressed in the same way, it is considered consistent, for example, 'pink' and 'it is pink'. -If they are consistent, Judement is 1; if they are different, Judement is 0. Just output Judement and don't output anything else.\n\n +If they are consistent, Judgment is 1; if they are different, Judgment is 0. Just output Judgment and don't output anything else.\n\n """ return chat_template @@ -91,7 +91,7 @@ def get_prompt(predict_str, ground_truth, question): [Question]: {question} [Standard Answer]: {ground_truth} [Model_answer] : {predict_str} -Judgement:""" +Judgment:""" full_prompt = f"{demo_prompt}{test_prompt}" return full_prompt @@ -189,8 +189,8 @@ def compute_score(predict_str: str, ground_truth: str, extra_info: dict | None = response = "error" # print(response) - if "Judgement:" in response: - response = response.split("Judgement:")[-1].strip() + if "Judgment:" in response: + response = response.split("Judgment:")[-1].strip() if "1" in response: acc_reward = 1.0 elif "0" in response: diff --git a/examples/deepeyes/run_deepeyes_qwen35_9B_async.sh b/examples/deepeyes/run_deepeyes_qwen35_9B_async.sh new file mode 100755 index 00000000..67870d77 --- /dev/null +++ b/examples/deepeyes/run_deepeyes_qwen35_9B_async.sh @@ -0,0 +1,227 @@ +#!/bin/bash + +# Copyright (c) 2026 Relax Authors. All Rights Reserved. +# +# Qwen3.5-9B 8xGPU single-node fully-async DeepEyes training script. +# +# Resource layout (8 GPUs, fully-async): +# actor: 4 GPUs (TP=4) +# rollout: 2 GPUs (1 engine × 2 GPUs) +# reference: 1 GPU (TP=1, weight-only) +# actor_fwd: 1 GPU +# +# Usage: +# MODEL_DIR=/path/to/models DATA_DIR=/path/to/data SAVE_DIR=/path/to/save \ +# bash examples/deepeyes/run_deepeyes_qwen35_9B_async.sh + +set -ex +set -o pipefail + +############################################################################### +# ENVIRONMENT # +############################################################################### + +TIMESTAMP=$(date "+%Y-%m-%d-%H:%M:%S") + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +# Auto-source local environment when not launched via an external entrypoint +if [ -z "${RELAX_ENTRYPOINT_MODE:-}" ]; then + source "${SCRIPT_DIR}/../../scripts/entrypoint/local.sh" +fi +source "${MODEL_CONFIG_DIR}/qwen35-9B.sh" + +############################################################################### +# DIRS # +############################################################################### + +PROJECT_NAME="${PROJECT_NAME:=Relax/dev/deepeyes}" +EXP_NAME="qwen35-9B-deepeyes-async-${TIMESTAMP}" + +# Require MODEL_DIR, DATA_DIR, SAVE_DIR from environment or set defaults +if [ -z "${MODEL_DIR:-}" ] || [ -z "${DATA_DIR:-}" ] || [ -z "${SAVE_DIR:-}" ]; then + echo "ERROR: MODEL_DIR, DATA_DIR, and SAVE_DIR must be set." + echo "Example: MODEL_DIR=/path/to/models DATA_DIR=/path/to/data SAVE_DIR=/path/to/save bash $0" + exit 1 +fi +mkdir -p ${SAVE_DIR} + +############################################################################### +# JUDGE MODEL API # +############################################################################### + +source "${SCRIPT_DIR}/sglang_judge_service.sh" + +############################################################################### +# MODEL CONFIG # +############################################################################### + +CKPT_ARGS=( + --hf-checkpoint ${MODEL_DIR}/Qwen3.5-9B + --ref-load ${MODEL_DIR}/Qwen3.5-9B + --save ${SAVE_DIR}/Qwen3.5-9B-DeepEyes-Checkpoint + --megatron-to-hf-mode bridge + --save-interval 100 + --max-actor-ckpt-to-keep 1 +) + +############################################################################### +# DATASETS # +############################################################################### + +TRAIN_FILES=( + "'${DATA_DIR}/deepeyes-v1/data_0.1.2_visual_toolbox_v2.parquet@[0:5000]'" + "'${DATA_DIR}/deepeyes-v1/data_v0.8_visual_toolbox_v2.parquet@[0:5000]'" +) +TEST_FILES=("${DATA_DIR}/deepeyes-v1/data_thinklite_reasoning_acc.parquet@[0:256]") +PROMPT_SET="[$(IFS=,; echo "${TRAIN_FILES[*]}")]" + +############################################################################### +# ROLLOUT CONFIG # +############################################################################### + +NUM_ROLLOUT="${NUM_ROLLOUT:=2000}" + +ROLLOUT_ARGS=( + --prompt-data "${PROMPT_SET}" + --input-key prompt + --label-key reward_model + --multimodal-keys '{"image":"images"}' + --reward-key score + --metadata-key extra_info + --apply-chat-template + --custom-generate-function-path examples.deepeyes.rollout.generate + --custom-rm-path examples.deepeyes.reward_deepeyes.reward_func + --custom-config-path examples/deepeyes/deepeyes_config.yaml + --num-rollout ${NUM_ROLLOUT} + --rollout-batch-size 32 + --n-samples-per-prompt 8 + --rollout-max-response-len 2048 + --rollout-max-prompt-len 2048 + --rollout-temperature 1 + --global-batch-size 256 + --use-fault-tolerance + --rollout-shuffle + --use-streaming-dataset +) + +############################################################################### +# EVAL CONFIG # +############################################################################### + +EVAL_ARGS=( + --eval-interval 100 + --eval-prompt-data vstar ${TEST_FILES} + --n-samples-per-eval-prompt 8 + --eval-max-response-len 2048 + --eval-top-p 0.7 +) + +############################################################################### +# ALGORITHM CONFIG # +############################################################################### + +GRPO_ARGS=( + --advantage-estimator grpo + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 + --eps-clip-c 3 + --use-tis +) + +############################################################################### +# OPTIMIZER CONFIG # +############################################################################### + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer +) + +############################################################################### +# SGLANG CONFIG # +############################################################################### + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 2 + --sglang-mem-fraction-static 0.6 +) + +############################################################################### +# LOGGING CONFIG # +############################################################################### + +LOG_ARGS=( + --use-clearml + --use-metrics-service + --tb-project-name ${PROJECT_NAME} + --tb-experiment-name ${EXP_NAME} +) + +############################################################################### +# MEGATRON CONFIG # +############################################################################### + +MEGATRON_ARGS=( + --tensor-model-parallel-size 4 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 + --no-rope-fusion + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash +) + +############################################################################### +# RESOURCE CONFIG # +############################################################################### + +# Fully-async: actor(4 GPU) + rollout(2 GPU) + reference(1 GPU) + actor_fwd(1 GPU) = 8 GPU +RAY_RESOURCE_ARGS=( + --resource '{"actor": [1, 4], "rollout": [1, 2], "reference": [1, 1], "actor_fwd": [1, 1], "advantages": [1, 0]}' + --max-staleness 2 + --num-data-storage-units 1 + --num-iters-per-train-update 8 + --ref-actor-config '{"tensor_model_parallel_size": 1, "max_tokens_per_gpu": 16384, "sequence_parallel": false, "only_load_weight": true}' + --fully-async + --use-health-check +) + +############################################################################### +# LAUNCH JOB # +############################################################################### + +mkdir -p logs + +ray job submit ${RAY_NO_WAIT:+--no-wait} --address="http://127.0.0.1:8265" \ + -- python3 -m relax.entrypoints.train \ + "${RAY_RESOURCE_ARGS[@]}" \ + "${MODEL_ARGS[@]}" \ + "${CKPT_ARGS[@]}" \ + "${ROLLOUT_ARGS[@]}" \ + "${GRPO_ARGS[@]}" \ + "${OPTIMIZER_ARGS[@]}" \ + "${SGLANG_ARGS[@]}" \ + "${LOG_ARGS[@]}" \ + "${MEGATRON_ARGS[@]}" \ + "${EVAL_ARGS[@]}" \ + 2>&1 | tee logs/${EXP_NAME}.log diff --git a/examples/deepeyes/run_deepeyes_r3.sh b/examples/deepeyes/run_deepeyes_r3.sh index e817ea4f..6b237634 100644 --- a/examples/deepeyes/run_deepeyes_r3.sh +++ b/examples/deepeyes/run_deepeyes_r3.sh @@ -60,11 +60,11 @@ CKPT_ARGS=( # DATASETS # ############################################################################### -TRAIN_FILES=() -for i in {0..9}; do - TRAIN_FILES+=("'${DATA_DIR}/deepeyes/train/v0.1.2.parquet/partition=${i}/3ce23f4945e8498085ac5f72f0afc133-0.parquet'") -done -TEST_FILES=("${DATA_DIR}/deepeyes/test.parquet") +TRAIN_FILES=( + "'${DATA_DIR}/deepeyes-v1/data_0.1.2_visual_toolbox_v2.parquet@[0:5000]'" + "'${DATA_DIR}/deepeyes-v1/data_v0.8_visual_toolbox_v2.parquet@[0:5000]'" +) +TEST_FILES=("${DATA_DIR}/deepeyes-v1/data_thinklite_reasoning_acc.parquet@[0:256]") PROMPT_SET="[$(IFS=,; echo "${TRAIN_FILES[*]}")]" ############################################################################### @@ -109,6 +109,7 @@ ROUTING_REPLAY_ARGS=( ############################################################################### EVAL_ARGS=( + --skip-eval-before-train --eval-interval 100 --eval-prompt-data vstar ${TEST_FILES} --n-samples-per-eval-prompt 8 diff --git a/relax/backends/megatron/__init__.py b/relax/backends/megatron/__init__.py index eff76ce7..00e57ec6 100644 --- a/relax/backends/megatron/__init__.py +++ b/relax/backends/megatron/__init__.py @@ -2,7 +2,6 @@ import logging -import torch from relax.utils.external.torch_memory_saver import TORCH_MEMORY_SAVER_AVAILABLE, torch_memory_saver @@ -16,8 +15,8 @@ def new_init(self, *args, **kwargs): if TORCH_MEMORY_SAVER_AVAILABLE and torch_memory_saver._impl is not None: torch_memory_saver._impl._binary_wrapper.cdll.tms_set_interesting_region(False) old_init(self, *args, **kwargs) - torch.cuda.synchronize() - if TORCH_MEMORY_SAVER_AVAILABLE and torch_memory_saver._impl is not None: + device_utils.synchronize() + if torch_memory_saver._impl is not None: torch_memory_saver._impl._binary_wrapper.cdll.tms_set_interesting_region(True) deep_ep.Buffer.__init__ = new_init diff --git a/relax/backends/megatron/actor.py b/relax/backends/megatron/actor.py index b1d1b1bf..3e5d3a65 100644 --- a/relax/backends/megatron/actor.py +++ b/relax/backends/megatron/actor.py @@ -21,6 +21,7 @@ from relax.distributed.checkpoint_service.client.engine import create_client from relax.distributed.ray.train_actor import TrainRayActor +from relax.utils import device as device_utils from relax.utils import tracking_utils from relax.utils.async_utils import run from relax.utils.data.stream_dataloader import ( @@ -66,6 +67,32 @@ logger = logging.getLogger(__name__) +def _warmup_actor_reduce_scatter_once(rollout_id: int) -> None: + if rollout_id != 0: + return + + group = mpu.get_data_parallel_group(with_context_parallel=True) + world_size = dist.get_world_size(group=group) + device = torch.device("cuda", torch.cuda.current_device()) + output = torch.empty(8, device=device, dtype=torch.float32) + input_tensor = torch.ones(8 * world_size, device=device, dtype=torch.float32) + + logger.info( + "Running actor dummy reduce-scatter warmup for rollout_id=%s, rank=%s, group_world_size=%s", + rollout_id, + dist.get_rank(group=group), + world_size, + ) + try: + opts = dist.ReduceScatterOptions() + opts.reduceOp = dist.ReduceOp.SUM + group.reduce_scatter_tensor_coalesced([output], [input_tensor], opts).wait() + except AttributeError: + dist.reduce_scatter_tensor(output, input_tensor, op=dist.ReduceOp.SUM, group=group) + torch.cuda.synchronize() + logger.info("Actor dummy reduce-scatter warmup completed for rollout_id=%s", rollout_id) + + class MegatronTrainRayActor(TrainRayActor): def init( self, @@ -605,6 +632,13 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch) -> None: except Exception as e: logger.warning(f"Error triggering evaluation for rollout_id {rollout_id}: {e}") + # On the final training step the rollout component has already exited + # its main loop, so nothing else awaits the eval handler. Block here + # until eval finishes; otherwise the controller's atexit shutdown + # races with eval and tears down the SGLang engines mid-flight. + if is_train_done: + self._wait_for_previous_eval() + def compute_ref_log_prob(self, rollout_id: int) -> None: if self.args.use_routing_replay: os.environ["ROUTING_REPLAY_STAGE"] = "fallthrough" @@ -685,6 +719,8 @@ def train_async(self, rollout_id) -> None: if self.args.use_routing_replay: os.environ["ROUTING_REPLAY_STAGE"] = "replay_backward" + _warmup_actor_reduce_scatter_once(rollout_id) + logger.info(f"start to get rollout_id: {rollout_id} data from transfer queue for train step.") data_fields = [ "tokens", @@ -757,12 +793,19 @@ def train_async(self, rollout_id) -> None: logger.warning( f"Error during async weight update: {e}, maybe cause by rollout server failure. Will continue without async update for this step." ) + # On the final training step the rollout component has already + # exited its main loop, so the eval just triggered above will not + # be awaited anywhere. Block until it finishes; otherwise the + # controller's atexit shutdown races with eval and tears down the + # SGLang engines mid-flight. + if (rollout_id + 1) == self.args.num_rollout: + self._wait_for_previous_eval() if self.args.use_routing_replay: RoutingReplay.clear_all() total_lengths = rollout_data["total_lengths"] all_total_lengths = [None] * mpu.get_data_parallel_world_size(with_context_parallel=False) dist.all_gather_object( - all_total_lengths, total_lengths, group=mpu.get_data_parallel_group(with_context_parallel=True) + all_total_lengths, total_lengths, group=mpu.get_data_parallel_group(with_context_parallel=False) ) all_total_lengths = sum(all_total_lengths, []) # flatten Timer().seq_lens = all_total_lengths @@ -921,7 +964,7 @@ def _check_services_health(self) -> tuple[bool, bool]: flags = torch.tensor( [int(rollout_only), int(actor_fwd_only)], dtype=torch.int32, - device=torch.cuda.current_device(), + device=device_utils.make_current_torch_device(), ) dist.all_reduce(flags, op=dist.ReduceOp.MAX, group=get_gloo_group()) rollout_only = bool(flags[0].item()) @@ -1031,11 +1074,19 @@ def load_other_checkpoint(self, model_tag: str, path: str) -> None: self._active_model_tag = model_tag def all_consumed(self, task_name, rollout_id): - if mpu.get_tensor_model_parallel_rank() == 0 and mpu.get_pipeline_model_parallel_rank() == 0: + # Only (TP=0, PP=0, CP=0) queries the transfer queue; otherwise different cp_ranks + # may observe different consumption status due to concurrent fetches and diverge, + # leaving some ranks idle while others enter the next collective and hang. + if ( + mpu.get_tensor_model_parallel_rank() == 0 + and mpu.get_pipeline_model_parallel_rank() == 0 + and mpu.get_context_parallel_rank() == 0 + ): status = [run(self.data_system_client.async_check_consumption_status(task_name, f"train_{rollout_id}"))] else: status = [True] - status = torch.tensor(status, device=torch.cuda.current_device()) + status = torch.tensor(status, device=device_utils.make_current_torch_device()) + dist.broadcast(status, group=mpu.get_context_parallel_group(), group_src=0) dist.broadcast(status, group=mpu.get_tensor_model_parallel_group(), group_src=0) dist.broadcast(status, group=mpu.get_pipeline_model_parallel_group(), group_src=0) diff --git a/relax/backends/megatron/arguments.py b/relax/backends/megatron/arguments.py index 7ad1ae56..a2fd4f30 100644 --- a/relax/backends/megatron/arguments.py +++ b/relax/backends/megatron/arguments.py @@ -2,9 +2,16 @@ from megatron.training.arguments import parse_args as _megatron_parse_args from megatron.training.arguments import validate_args as _megatron_validate_args -from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding + + +try: + from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding as vocab_size_with_padding +except ModuleNotFoundError: + from megatron.core.tokenizers.utils.build_tokenizer import vocab_size_with_padding + from transformers import AutoConfig +from relax.utils import device as device_utils from relax.utils.logging_utils import get_logger @@ -17,17 +24,18 @@ def validate_args(args): """Run megatron's own validate_args plus slime-specific megatron validations.""" - import torch - - if not torch.cuda.is_available(): + if not device_utils.is_available(): from unittest.mock import patch - class _CudaProperty: + class _DeviceProperty: major = 9 minor = 0 + # Megatron internally calls torch.cuda.get_device_properties / get_device_capability. + # When no real device is available, device_utils.get_device_name() returns "cpu", + # so we must patch torch.cuda specifically — that's what Megatron actually invokes. with ( - patch("torch.cuda.get_device_properties", return_value=_CudaProperty()), + patch("torch.cuda.get_device_properties", return_value=_DeviceProperty()), patch("torch.cuda.get_device_capability", return_value=(9, 0)), ): _megatron_validate_args(args) @@ -53,6 +61,13 @@ class _CudaProperty: "decoder_first_pipeline_num_layers and decoder_last_pipeline_num_layers should be None when " "pipeline_model_parallel_size is 1." ) + + # Megatron-Bridge requires --calculate-per-token-loss when context parallelism is enabled. + # See https://github.com/NVIDIA-NeMo/Megatron-Bridge + if args.context_parallel_size > 1: + assert args.calculate_per_token_loss, ( + "--calculate-per-token-loss must be set when context_parallel_size > 1 (required by Megatron-Bridge)." + ) return args @@ -82,15 +97,19 @@ def equal(x, y): if hasattr(hf_config, "text_config"): hf_config = hf_config.text_config - for hf_config_name, megatron_config_name, compare_fn in [ - ("hidden_size", "hidden_size", equal), - ("num_attention_heads", "num_attention_heads", equal), - ("num_hidden_layers", "num_layers", equal), - ("intermediate_size", "ffn_hidden_size", equal), - ("tie_word_embeddings", "untie_embeddings_and_output_weights", lambda x, y: not x == y), - ("rms_norm_eps", "norm_epsilon", equal), - ("rope_theta", "rotary_base", equal), - ]: + for hf_config_name, megatron_config_name, compare_fn in ( + [ + ("hidden_size", "hidden_size", equal), + ("num_attention_heads", "num_attention_heads", equal), + ("num_hidden_layers", "num_layers", equal), + ("intermediate_size", "ffn_hidden_size", equal), + ("tie_word_embeddings", "untie_embeddings_and_output_weights", lambda x, y: not x == y), + ("rope_theta", "rotary_base", equal), + ] + + [("rms_norm_eps", "norm_epsilon", equal)] + if hasattr(args, "norm_epsilon") + else [("rms_norm_eps", "layernorm_epsilon", equal)] + ): if hasattr(hf_config, hf_config_name): if not compare_fn(getattr(hf_config, hf_config_name), getattr(args, megatron_config_name)): errors.append( @@ -118,7 +137,7 @@ def _set_default_megatron_args(args): args.rope_type = "yarn" if args.multi_latent_attention else "rope" if args.vocab_size and not args.padded_vocab_size: - args.padded_vocab_size = _vocab_size_with_padding(args.vocab_size, args) + args.padded_vocab_size = vocab_size_with_padding(args.vocab_size, args) if not args.tokenizer_model and not args.tokenizer_type: logger.info("--tokenizer-model not set, use --hf-checkpoint as tokenizer model.") diff --git a/relax/backends/megatron/cp_utils.py b/relax/backends/megatron/cp_utils.py index deb43ba8..f38855d1 100644 --- a/relax/backends/megatron/cp_utils.py +++ b/relax/backends/megatron/cp_utils.py @@ -6,11 +6,35 @@ from megatron.core import mpu +def maybe_padded_total_lengths( + total_lengths: list[int], + qkv_format: str, + is_vl_model: bool, +) -> list[int] | None: + """Per-sample tp*cp*2 padded lengths for the bridge VL+CP+thd path. + + Bridge's `preprocess_packed_seqs` (Qwen3-VL et al.) pads each sample to a + multiple of `tp*cp*2` before zigzag-splitting along CP, so the local logits + returned by the bridge index per-sample chunks at `padded_len // (2*cp)`. + Relax helpers that re-derive those chunks must agree. + + Returns None for non-VL/non-CP/non-thd paths so callers fall back to the + standard `ceil(total_length / (2*cp))` formula. + """ + cp_size = mpu.get_context_parallel_world_size() + if not (is_vl_model and cp_size > 1 and qkv_format == "thd"): + return None + tp_size = mpu.get_tensor_model_parallel_world_size() + align = tp_size * cp_size * 2 + return [(t + align - 1) // align * align for t in total_lengths] + + def get_logits_and_tokens_offset_with_cp( total_length: int, response_length: int, qkv_format: str = "thd", max_seq_len: int | None = None, + padded_total_length: int | None = None, ): """All offsets start from the begining of the prompt.""" cp_rank = mpu.get_context_parallel_rank() @@ -18,7 +42,13 @@ def get_logits_and_tokens_offset_with_cp( assert cp_size > 1 prompt_length = total_length - response_length - if qkv_format == "thd": + if padded_total_length is not None: + # Bridge VL+CP+thd: per-sample padded length is already aligned to tp*cp*2. + assert padded_total_length % (2 * cp_size) == 0, ( + f"padded_total_length={padded_total_length} not divisible by 2*cp={2 * cp_size}" + ) + chunk_size = padded_total_length // (2 * cp_size) + elif qkv_format == "thd": chunk_size = (total_length + 2 * cp_size - 1) // (2 * cp_size) else: assert max_seq_len is not None, "max_seq_len must be provided for qkv_format=bshd" @@ -55,6 +85,7 @@ def get_sum_of_sample_mean( calculate_per_token_loss: bool = False, qkv_format: str = "thd", max_seq_lens: list[int] | None = None, + padded_total_lengths: list[int] | None = None, ) -> Callable[[torch.Tensor], torch.Tensor]: """Calculate correct sample mean for CP.""" cp_size = mpu.get_context_parallel_world_size() @@ -84,9 +115,10 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor: zip(total_lengths, response_lengths, loss_masks, strict=False) ): max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None + padded_total_length = padded_total_lengths[i] if padded_total_lengths is not None else None prompt_length = total_length - response_length _, _, _, tokens_offset = get_logits_and_tokens_offset_with_cp( - total_length, response_length, qkv_format, max_seq_len + total_length, response_length, qkv_format, max_seq_len, padded_total_length ) loss_mask_0 = loss_mask[tokens_offset[0][0] - prompt_length : tokens_offset[0][1] - prompt_length] loss_mask_1 = loss_mask[tokens_offset[1][0] - prompt_length : tokens_offset[1][1] - prompt_length] @@ -116,7 +148,12 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor: return sum_of_sample_mean if not calculate_per_token_loss else sum_of_token -def all_gather_with_cp(tensor: torch.Tensor, total_length: int, response_length: int) -> torch.Tensor: +def all_gather_with_cp( + tensor: torch.Tensor, + total_length: int, + response_length: int, + padded_total_length: int | None = None, +) -> torch.Tensor: """Gather tensors across all ranks in the context parallel group. The first dimension of the output tensor will be the `response_length`. @@ -127,7 +164,9 @@ def all_gather_with_cp(tensor: torch.Tensor, total_length: int, response_length: if cp_size == 1: return tensor - _, _, logits_offset, _ = get_logits_and_tokens_offset_with_cp(total_length, response_length) + _, _, logits_offset, _ = get_logits_and_tokens_offset_with_cp( + total_length, response_length, padded_total_length=padded_total_length + ) prompt_length = total_length - response_length @@ -218,6 +257,7 @@ def slice_log_prob_with_cp( response_length: int, qkv_format: str = "thd", max_token_len: int | None = None, + padded_total_length: int | None = None, ) -> list[float] | torch.Tensor: assert len(log_prob) == response_length @@ -228,7 +268,7 @@ def slice_log_prob_with_cp( prompt_length = total_length - response_length _, _, logits_offset, _ = get_logits_and_tokens_offset_with_cp( - total_length, response_length, qkv_format, max_token_len + total_length, response_length, qkv_format, max_token_len, padded_total_length ) chunk_1 = log_prob[logits_offset[0][0] - (prompt_length - 1) : logits_offset[0][1] - (prompt_length - 1)] diff --git a/relax/backends/megatron/data.py b/relax/backends/megatron/data.py index 7ce42fe7..06ac5e9f 100644 --- a/relax/backends/megatron/data.py +++ b/relax/backends/megatron/data.py @@ -12,6 +12,7 @@ from megatron.core.packed_seq_params import PackedSeqParams from torch.nn.utils.rnn import pad_sequence +from relax.utils import device as device_utils from relax.utils import tracking_utils from relax.utils.data.data import get_minimum_num_micro_batch_size from relax.utils.data.seqlen_balancing import get_seqlen_balanced_partitions @@ -22,7 +23,7 @@ from relax.utils.training.flops_utils import calculate_fwd_flops from relax.utils.types import RolloutBatch -from .cp_utils import get_sum_of_sample_mean, slice_with_cp +from .cp_utils import get_sum_of_sample_mean, maybe_padded_total_lengths, slice_with_cp logger = get_logger(__name__) @@ -151,11 +152,59 @@ def get_batch( if qkv_format == "bshd": max_seqlen = batch["max_seq_lens"][0] assert max([t.size(0) for t in tokens]) <= max_seqlen + + # For VL models with CP > 1, Bridge expects UNSPLIT tokens (it handles CP + # splitting internally after vision embedding). Save padded-but-unsplit + # tokens so model.py can pass them to Bridge instead of the CP-split ones. + if cp_size > 1: + chunk_size = (max_seqlen + 2 * cp_size - 1) // (2 * cp_size) + padded_len = 2 * cp_size * chunk_size + unsplit = [F.pad(t, (0, padded_len - t.size(0)), value=pad_token_id) for t in tokens] + batch["unsplit_tokens"] = torch.stack(unsplit) + tokens = [slice_with_cp(t, pad_token_id, qkv_format, max_seqlen) for t in tokens] tokens = torch.stack(tokens) packed_seq_params = None elif qkv_format == "thd": + # VL + CP > 1: bridge's Qwen3VLModel.forward expects per-sample + # BSHD-padded input_ids + attention_mask, and re-derives the THD + # packing internally with align_size = tp*cp*2. Provide unsplit + # inputs and a matching packed_seq_params so the caller-side cu_seqlens + # agrees with what the bridge derives from attention_mask. + # Mirrors verl's build_vlm_attn_mask_thd + preprocess_thd_engine. + is_vl_model = batch.get("multimodal_train_inputs") is not None + if is_vl_model and cp_size > 1: + tp_size = mpu.get_tensor_model_parallel_world_size() + align_size = tp_size * cp_size * 2 + device = device_utils.make_current_torch_device() + + seqlens = torch.tensor([t.size(0) for t in tokens], dtype=torch.int32, device=device) + seqlens_padded = (seqlens + align_size - 1) // align_size * align_size + cu_seqlens_padded = torch.zeros(len(tokens) + 1, dtype=torch.int32, device=device) + cu_seqlens_padded[1:] = torch.cumsum(seqlens_padded, dim=0) + max_seqlen_padded = int(seqlens_padded.max().item()) + + unsplit_tokens = pad_sequence(tokens, batch_first=True, padding_value=pad_token_id) + unsplit_attention_mask = torch.zeros_like(unsplit_tokens, dtype=torch.bool) + for i, s in enumerate(seqlens.tolist()): + unsplit_attention_mask[i, :s] = True + + batch["unsplit_tokens"] = unsplit_tokens + batch["unsplit_attention_mask"] = unsplit_attention_mask + batch["vlm_packed_seq_params"] = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens_padded, + cu_seqlens_kv=cu_seqlens_padded, + max_seqlen_q=max_seqlen_padded, + max_seqlen_kv=max_seqlen_padded, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + ) + # Per-sample tp*cp*2-aligned lengths consumed by loss helpers so + # their per-sample chunking matches bridge's preprocess_packed_seqs. + batch["padded_total_lengths"] = seqlens_padded.tolist() + if allgather_cp: # DSA mode: concatenate all sequences first, then slice once with CP. # We also pad the *global* concatenated stream to make per-rank chunks equal. @@ -173,7 +222,9 @@ def get_batch( tokens = F.pad(tokens, (0, pad), value=pad_token_id) cu_seqlens_list.append(cu_seqlens_list[-1] + pad) - cu_seqlens = torch.tensor(cu_seqlens_list, dtype=torch.int, device=torch.cuda.current_device()) + cu_seqlens = torch.tensor( + cu_seqlens_list, dtype=torch.int, device=device_utils.make_current_torch_device() + ) tokens = tokens.chunk(cp_size, dim=0)[cp_rank] else: tokens = [slice_with_cp(t, pad_token_id, qkv_format) for t in tokens] @@ -191,7 +242,9 @@ def get_batch( cu_seqlens.append(cu_seqlens[-1] + pad) # thd requires the cu_seqlens to be of the origin length - cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int).cuda() * cp_size + cu_seqlens = ( + torch.tensor(cu_seqlens, dtype=torch.int).to(device_utils.make_current_torch_device()) * cp_size + ) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() packed_seq_params = PackedSeqParams( @@ -440,7 +493,9 @@ def get_data_iterator( # across DP ranks so that all ranks execute the same number of training steps # (required by collective operations in the training loop). if getattr(args, "balance_data", False): - steps_tensor = torch.tensor([num_steps_per_rollout], dtype=torch.int, device=torch.cuda.current_device()) + steps_tensor = torch.tensor( + [num_steps_per_rollout], dtype=torch.int, device=device_utils.make_current_torch_device() + ) dist.all_reduce(steps_tensor, op=dist.ReduceOp.MAX, group=dp_group) num_steps_per_rollout = steps_tensor.item() @@ -472,7 +527,9 @@ def _generate_data_iterator(rollout_data, micro_batch_size, micro_batch_indices= get_minimum_num_micro_batch_size(samples[start:end], args.max_tokens_per_gpu * cp_size) ) - num_microbatches = torch.tensor(num_microbatches, dtype=torch.int, device=torch.cuda.current_device()) + num_microbatches = torch.tensor( + num_microbatches, dtype=torch.int, device=device_utils.make_current_torch_device() + ) dist.all_reduce(num_microbatches, op=dist.ReduceOp.MAX, group=dp_group) if vpp_size > 1: @@ -541,6 +598,11 @@ def log_rollout_data( loss_masks = rollout_data["loss_masks"] total_lengths = rollout_data["total_lengths"] max_seq_lens = rollout_data.get("max_seq_lens", None) + padded_total_lengths = maybe_padded_total_lengths( + total_lengths, + args.qkv_format, + rollout_data.get("multimodal_train_inputs") is not None, + ) # OPD dynamic metric: overlap ratio on top-k token sets. student_topk_ids = rollout_data.get("topk_token_ids", None) @@ -574,6 +636,7 @@ def log_rollout_data( loss_masks_t, qkv_format=args.qkv_format, max_seq_lens=max_seq_lens, + padded_total_lengths=padded_total_lengths, ) overlap_ratio_value = cp_size * sum_of_sample_mean(overlap_ratio_flat) / len(loss_masks_t) log_dict["opd_overlap_ratio"] = overlap_ratio_value.item() @@ -619,6 +682,7 @@ def log_rollout_data( loss_masks, qkv_format=args.qkv_format, max_seq_lens=max_seq_lens, + padded_total_lengths=padded_total_lengths, ) val = cp_size * sum_of_sample_mean(val) / len(loss_masks) else: @@ -693,12 +757,22 @@ def quantile(total_value, n_quantiles, data) -> dict: correct_total_lengths = [] correct_loss_masks = [] correct_entropy = [] + correct_padded_total_lengths_full = maybe_padded_total_lengths( + total_lengths, + args.qkv_format, + rollout_data.get("multimodal_train_inputs") is not None, + ) + correct_padded_total_lengths: list[int] | None = ( + [] if correct_padded_total_lengths_full is not None else None + ) for i, raw_reward in enumerate(raw_rewards): if raw_reward == 1: correct_response_lengths.append(response_lengths[i]) correct_total_lengths.append(total_lengths[i]) correct_loss_masks.append(loss_masks[i]) correct_entropy.append(-rollout_data["log_probs"][i]) + if correct_padded_total_lengths is not None: + correct_padded_total_lengths.append(correct_padded_total_lengths_full[i]) num_correct_responses = len(correct_total_lengths) rollout_data["correct_response_lengths"] = correct_response_lengths correct_response_length_percentile = quantile( @@ -708,7 +782,10 @@ def quantile(total_value, n_quantiles, data) -> dict: rollout_data[f"correct_length/{p}"] = [val] * num_correct_responses if len(correct_entropy) > 0: sum_of_sample_mean = get_sum_of_sample_mean( - correct_total_lengths, correct_response_lengths, correct_loss_masks + correct_total_lengths, + correct_response_lengths, + correct_loss_masks, + padded_total_lengths=correct_padded_total_lengths, ) correct_entropy = sum_of_sample_mean(torch.cat(correct_entropy, dim=0)) rollout_data["correct_entropy"] = [correct_entropy.item()] * num_correct_responses diff --git a/relax/backends/megatron/initialize.py b/relax/backends/megatron/initialize.py index 563d730e..a33129c5 100644 --- a/relax/backends/megatron/initialize.py +++ b/relax/backends/megatron/initialize.py @@ -51,12 +51,19 @@ def _initialize_distributed(args, get_embedding_ranks=None, get_position_embeddi order="tp-cp-ep-dp-pp" if not args.use_tp_pp_dp_mapping else "tp-cp-ep-pp-dp", get_embedding_ranks=get_embedding_ranks, get_position_embedding_ranks=get_position_embedding_ranks, - create_gloo_process_groups=args.enable_gloo_process_groups, + create_gloo_process_groups=args.use_gloo_process_groups, ) def init(args): set_args(args) + + if getattr(args, "disable_jit_fuser", False): + from megatron.core.jit import disable_jit_fuser + + disable_jit_fuser() + logger.info("JIT fuser disabled (torch.compile → no-op).") + if args.enable_experimental: logger.info("Enable megatron experimental") set_experimental_flag(True) diff --git a/relax/backends/megatron/kernels/int4_qat/setup.py b/relax/backends/megatron/kernels/int4_qat/setup.py index b8bfc7dc..2fc5ba13 100644 --- a/relax/backends/megatron/kernels/int4_qat/setup.py +++ b/relax/backends/megatron/kernels/int4_qat/setup.py @@ -6,6 +6,8 @@ # Get CUDA arch list +# NOTE: This setup script is CUDA-only as it compiles .cu kernel files via CUDAExtension. +# Non-CUDA backends (NPU, XPU, PPU) should provide their own kernel implementations. arch_list = [] if torch.cuda.is_available(): for i in range(torch.cuda.device_count()): diff --git a/relax/backends/megatron/loss.py b/relax/backends/megatron/loss.py index 47530f03..72967e66 100644 --- a/relax/backends/megatron/loss.py +++ b/relax/backends/megatron/loss.py @@ -28,6 +28,7 @@ all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean, + maybe_padded_total_lengths, slice_log_prob_with_cp, ) @@ -40,6 +41,7 @@ def get_responses( total_lengths: list[int], response_lengths: list[int], max_seq_lens: list[int] | None = None, + padded_total_lengths: list[int] | None = None, ) -> Iterator[tuple[torch.Tensor, torch.Tensor]]: """Yield response-aligned `(logits_chunk, tokens_chunk)` pairs per sample. @@ -84,6 +86,7 @@ def get_responses( zip(unconcat_tokens, total_lengths, response_lengths, strict=False) ): max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None + padded_total_length = padded_total_lengths[i] if padded_total_lengths is not None else None if cp_size == 1: if qkv_format == "bshd": @@ -120,7 +123,7 @@ def get_responses( else: # TODO: this is super ugly... do better abstraction. chunk_size, chunks_offset, logits_offset, tokens_offset = get_logits_and_tokens_offset_with_cp( - total_length, response_length, qkv_format, max_seq_len + total_length, response_length, qkv_format, max_seq_len, padded_total_length ) logits_0, logits_1 = logits[end : end + chunk_size], logits[end + chunk_size : end + 2 * chunk_size] @@ -151,6 +154,7 @@ def _allgather_cp_redistribute( total_lengths: list[int], response_lengths: list[int], max_seq_lens: list[int] | None = None, + padded_total_lengths: list[int] | None = None, ) -> None: """Redistribute response tensors from allgather-CP layout to zigzag ring- attn layout. @@ -217,8 +221,16 @@ def _allgather_cp_redistribute( zip(all_cat.split(response_lengths, dim=0), total_lengths, response_lengths, strict=False) ): max_seq_len = max_seq_lens[idx] if max_seq_lens is not None else None + padded_total_length = padded_total_lengths[idx] if padded_total_lengths is not None else None new_values.append( - slice_log_prob_with_cp(full_resp, total_length, response_length, args.qkv_format, max_seq_len) + slice_log_prob_with_cp( + full_resp, + total_length, + response_length, + args.qkv_format, + max_seq_len, + padded_total_length, + ) ) res[key] = new_values @@ -236,6 +248,7 @@ def get_log_probs_and_entropy( topk_k: int | None = None, non_loss_data: bool = True, max_seq_lens: list[int] | None = None, + padded_total_lengths: list[int] | None = None, ) -> tuple[torch.Tensor, dict[str, list[torch.Tensor]]]: """Compute per-token log-probabilities (and optionally entropy) on responses. @@ -274,6 +287,7 @@ def get_log_probs_and_entropy( total_lengths=total_lengths, response_lengths=response_lengths, max_seq_lens=max_seq_lens, + padded_total_lengths=padded_total_lengths, ): log_prob, entropy = calculate_log_probs_and_entropy( logits_chunk, @@ -307,6 +321,7 @@ def get_log_probs_and_entropy( total_lengths=total_lengths, response_lengths=response_lengths, max_seq_lens=max_seq_lens, + padded_total_lengths=padded_total_lengths, ) return torch.empty((0,), device=logits.device), res @@ -322,6 +337,7 @@ def get_values( with_entropy: bool = False, non_loss_data: bool = True, max_seq_lens: list[int] | None = None, + padded_total_lengths: list[int] | None = None, ) -> tuple[torch.Tensor, dict[str, list[torch.Tensor]]]: """Extract per-token value predictions over response tokens. @@ -352,6 +368,7 @@ def get_values( total_lengths=total_lengths, response_lengths=response_lengths, max_seq_lens=max_seq_lens, + padded_total_lengths=padded_total_lengths, ): assert logits_chunk.size(-1) == 1, f"{logits_chunk.shape}" value_list.append(logits_chunk.squeeze(-1)) @@ -368,6 +385,7 @@ def get_values( total_lengths=total_lengths, response_lengths=response_lengths, max_seq_lens=max_seq_lens, + padded_total_lengths=padded_total_lengths, ) return torch.empty((0,), device=logits.device), res @@ -444,6 +462,11 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) loss_masks: list[torch.Tensor] = rollout_data.get("loss_masks") total_lengths: list[int] = rollout_data.get("total_lengths") max_seq_lens: list[int] | None = rollout_data.get("max_seq_lens", None) + padded_total_lengths: list[int] | None = maybe_padded_total_lengths( + total_lengths, + args.qkv_format, + rollout_data.get("multimodal_train_inputs") is not None, + ) # return when not the last pp stage. if not mpu.is_pipeline_last_stage(): @@ -539,9 +562,10 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) response_len = response_lengths[i] prompt_len = total_len - response_len max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None + padded_total_length = padded_total_lengths[i] if padded_total_lengths is not None else None _, _, _, token_offsets = get_logits_and_tokens_offset_with_cp( - total_len, response_len, args.qkv_format, max_seq_len + total_len, response_len, args.qkv_format, max_seq_len, padded_total_length ) # Convert global offsets to response-space offsets @@ -678,6 +702,7 @@ def policy_loss_function( response_lengths = batch["response_lengths"] total_lengths = batch["total_lengths"] max_seq_lens = batch.get("max_seq_lens", None) + padded_total_lengths = batch.get("padded_total_lengths", None) _, log_probs_and_entropy = get_log_probs_and_entropy( logits, @@ -687,6 +712,7 @@ def policy_loss_function( response_lengths=response_lengths, with_entropy=True, max_seq_lens=max_seq_lens, + padded_total_lengths=padded_total_lengths, ) log_probs = log_probs_and_entropy["log_probs"] @@ -697,16 +723,20 @@ def policy_loss_function( full_log_probs = None full_old_log_probs = None if need_full_log_probs: + if padded_total_lengths is None: + padded_iter = [None] * len(log_probs) + else: + padded_iter = padded_total_lengths full_log_probs = [ - all_gather_with_cp(log_prob, total_length, response_length) - for log_prob, total_length, response_length in zip( - log_probs, total_lengths, response_lengths, strict=False + all_gather_with_cp(log_prob, total_length, response_length, padded_total_length) + for log_prob, total_length, response_length, padded_total_length in zip( + log_probs, total_lengths, response_lengths, padded_iter, strict=False ) ] full_old_log_probs = [ - all_gather_with_cp(old_log_prob, total_length, response_length) - for old_log_prob, total_length, response_length in zip( - old_log_probs, total_lengths, response_lengths, strict=False + all_gather_with_cp(old_log_prob, total_length, response_length, padded_total_length) + for old_log_prob, total_length, response_length, padded_total_length in zip( + old_log_probs, total_lengths, response_lengths, padded_iter, strict=False ) ] @@ -788,6 +818,7 @@ def policy_loss_function( args.calculate_per_token_loss, args.qkv_format, max_seq_lens, + padded_total_lengths, ) # Determine pg_loss reducer: use custom if specified, otherwise default @@ -908,6 +939,7 @@ def value_loss_function( total_lengths=batch["total_lengths"], response_lengths=batch["response_lengths"], max_seq_lens=batch.get("max_seq_lens", None), + padded_total_lengths=batch.get("padded_total_lengths", None), ) values = torch.cat([value.flatten() for value in values["values"]], dim=0) @@ -967,6 +999,7 @@ def sft_loss_function( response_lengths=response_lengths, with_entropy=False, max_seq_lens=batch.get("max_seq_lens", None), + padded_total_lengths=batch.get("padded_total_lengths", None), ) log_probs = log_probs_and_entropy["log_probs"] @@ -1024,6 +1057,7 @@ def loss_function( args.calculate_per_token_loss, args.qkv_format, batch.get("max_seq_lens", None), + batch.get("padded_total_lengths", None), ) match args.loss_type: diff --git a/relax/backends/megatron/model.py b/relax/backends/megatron/model.py index e892ec8b..71af4b00 100644 --- a/relax/backends/megatron/model.py +++ b/relax/backends/megatron/model.py @@ -153,7 +153,7 @@ def setup_model_and_optimizer( optimizer = get_megatron_optimizer( config=config, model_chunks=model, - use_gloo_process_groups=args.enable_gloo_process_groups, + use_gloo_process_groups=args.use_gloo_process_groups, ) opt_param_scheduler = get_optimizer_param_scheduler(args, optimizer) return model, optimizer, opt_param_scheduler @@ -257,14 +257,39 @@ def forward_step( packed_seq_params = batch["packed_seq_params"] total_lengths = batch["total_lengths"] response_lengths = batch["response_lengths"] + + is_vl_model = batch.get("multimodal_train_inputs", None) is not None + mm_kwargs = batch["multimodal_train_inputs"] if is_vl_model else {} + + # VL + CP > 1: pass unsplit tokens so Bridge handles CP split after + # vision embedding (aligns with Bridge's qwen3_vl_step.py contract). + if is_vl_model and "unsplit_tokens" in batch: + forward_input_ids = batch["unsplit_tokens"] + forward_packed_seq_params = None + else: + forward_input_ids = tokens + forward_packed_seq_params = packed_seq_params + + # thd VL+CP: bridge needs per-sample attention_mask + matching thd + # packed_seq_params (align_size = tp*cp*2). loss_mask is None because + # labels=None means GPTModel won't run internal loss; Relax's loss is + # computed externally from full_loss_masks. + if is_vl_model and "vlm_packed_seq_params" in batch: + forward_attention_mask = batch["unsplit_attention_mask"] + forward_packed_seq_params = batch["vlm_packed_seq_params"] + forward_loss_mask = None + else: + forward_attention_mask = None + forward_loss_mask = batch["full_loss_masks"] + output_tensor = model( - input_ids=tokens, + input_ids=forward_input_ids, position_ids=None, - attention_mask=None, + attention_mask=forward_attention_mask, labels=None, - packed_seq_params=packed_seq_params, - loss_mask=batch["full_loss_masks"], - **(batch["multimodal_train_inputs"] if batch.get("multimodal_train_inputs", None) is not None else {}), + packed_seq_params=forward_packed_seq_params, + loss_mask=forward_loss_mask, + **mm_kwargs, ) return output_tensor, partial( @@ -275,6 +300,7 @@ def forward_step( response_lengths=response_lengths, with_entropy=args.use_rollout_entropy, max_seq_lens=batch.get("max_seq_lens", None), + padded_total_lengths=batch.get("padded_total_lengths", None), ) # Turn on evaluation mode which disables dropout. @@ -429,19 +455,31 @@ def forward_step( loss_mask=batch["full_loss_masks"], ) else: + is_vl_model = batch.get("multimodal_train_inputs", None) is not None + use_unsplit = is_vl_model and "unsplit_tokens" in batch + forward_kwargs = { - "input_ids": batch["tokens"], + "input_ids": batch["unsplit_tokens"] if use_unsplit else batch["tokens"], "position_ids": None, "attention_mask": None, "labels": None, - "packed_seq_params": batch["packed_seq_params"], + "packed_seq_params": None if use_unsplit else batch["packed_seq_params"], "loss_mask": batch["full_loss_masks"], } + # thd VL+CP: bridge needs per-sample attention_mask + matching thd + # packed_seq_params (align_size = tp*cp*2). loss_mask is None + # because labels=None means GPTModel won't run internal loss; + # Relax's loss is computed externally from full_loss_masks. + if is_vl_model and "vlm_packed_seq_params" in batch: + forward_kwargs["attention_mask"] = batch["unsplit_attention_mask"] + forward_kwargs["packed_seq_params"] = batch["vlm_packed_seq_params"] + forward_kwargs["loss_mask"] = None + if args.enable_mtp_training: forward_kwargs["mtp_kwargs"] = {"mtp_labels": batch["tokens"]} - if batch.get("multimodal_train_inputs", None) is not None: + if is_vl_model: forward_kwargs.update(batch["multimodal_train_inputs"]) output_tensor = model(**forward_kwargs) diff --git a/relax/backends/megatron/model_provider.py b/relax/backends/megatron/model_provider.py index 770309a4..21815a7a 100644 --- a/relax/backends/megatron/model_provider.py +++ b/relax/backends/megatron/model_provider.py @@ -86,6 +86,87 @@ def forward( return logits, None +# CP-PROBE: one-shot forward-pre-hook on the first attention module to verify that +# context parallelism actually splits the sequence dimension at the attention input. +# Compare seq_len across CP=1 vs CP=2 runs — it must halve. Remove after verifying. +_CP_PROBE_INSTALLED = False + + +def _install_cp_probe(model: torch.nn.Module) -> None: + global _CP_PROBE_INSTALLED + if _CP_PROBE_INSTALLED: + return + + from megatron.core import mpu + + cp_size = mpu.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + tp_rank = mpu.get_tensor_model_parallel_rank() + state = {"n": 0} + + target_classes = ( + "DotProductAttention", + "TEDotProductAttention", + "FusedAttention", + "FlashAttention", + ) + + def hook(module, args, kwargs): + if state["n"] >= 2 or tp_rank != 0: + return + shapes: dict[str, object] = {} + for name in ( + "query", + "key", + "value", + "q", + "k", + "v", + "hidden_states", + "query_layer", + "key_layer", + "value_layer", + ): + t = kwargs.get(name) + if torch.is_tensor(t): + shapes[name] = tuple(t.shape) + for i, t in enumerate(args): + if torch.is_tensor(t): + shapes[f"arg{i}"] = tuple(t.shape) + for name in ("cu_seqlens_q", "cu_seqlens_kv"): + t = kwargs.get(name) + if torch.is_tensor(t): + shapes[name] = t.tolist() # one-shot sync, OK for probe + logger.debug(f"[CP-PROBE] cp_rank={cp_rank}/{cp_size} module={type(module).__name__} shapes={shapes}") + state["n"] += 1 + + skip_prefixes = ("vision_model", "visual", "vit", "image_encoder", "projector", "audio") + + def is_llm_backbone(n: str) -> bool: + return not any(p in n for p in skip_prefixes) + + matches = [(n, m) for n, m in model.named_modules() if type(m).__name__ in target_classes] + llm_matches = [(n, m) for n, m in matches if is_llm_backbone(n)] + chosen = llm_matches or matches # fallback to vision if no LLM backbone in this stage + + if chosen: + name, m = chosen[0] + m.register_forward_pre_hook(hook, with_kwargs=True) + logger.debug( + f"[CP-PROBE] hook installed on '{name}' ({type(m).__name__}) " + f"cp_size={cp_size} cp_rank={cp_rank} " + f"(total_attn_modules={len(matches)}, llm_backbone={len(llm_matches)})" + ) + _CP_PROBE_INSTALLED = True + return + + candidates = [(n, type(m).__name__) for n, m in model.named_modules() if "attention" in n.lower()][:8] + logger.warning( + f"[CP-PROBE] no attention module matched on this stage (cp_rank={cp_rank}); " + f"attention-like candidates: {candidates}" + ) + + def get_model_provider_func( args: argparse.Namespace, role: Literal["actor", "critic"] = "actor", @@ -108,6 +189,7 @@ def wrapped_model_provider( model.output_layer = LinearForLastLayer( input_size=model.config.hidden_size, output_size=1, config=model.config ) + _install_cp_probe(model) return model return wrapped_model_provider @@ -150,6 +232,7 @@ def wrapped_model_provider( "freeze_vision_projection", # https://github.com/redai-infra/Megatron-Bridge/commit/960bb5f18800d3e1fb9815e95daa185ab06c09ea "vision_dp_when_tp", + "calculate_per_token_loss", ] args_dict = vars(args) @@ -199,7 +282,14 @@ def wrapped_model_provider( pickle.dump(provider, f) logger.info(f"Provider config saved to {pkl_path}") - return provider.provide + original_provide = provider.provide + + def provide_with_cp_probe(*p_args, **p_kwargs): + model = original_provide(*p_args, **p_kwargs) + _install_cp_probe(model) + return model + + return provide_with_cp_probe def model_provider(pre_process: bool = True, post_process: bool = True, vp_stage: int | None = None) -> GPTModel: """Builds the model. @@ -306,13 +396,14 @@ def model_provider(pre_process: bool = True, post_process: bool = True, vp_stage if post_process and role == "critic": model.output_layer = LinearForLastLayer(input_size=config.hidden_size, output_size=1, config=config) + _install_cp_probe(model) return model return model_provider def wrap_model_provider_with_freeze(original_provider, args): - def wrapped_provider(pre_process=True, post_process=True, vp_stage=None): + def wrapped_provider(pre_process=True, post_process=True, vp_stage=None, **kwargs): sig = inspect.signature(original_provider) if "vp_stage" in sig.parameters: model = original_provider(pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) diff --git a/relax/backends/megatron/weight_conversion/processors/quantizer_compressed_tensors.py b/relax/backends/megatron/weight_conversion/processors/quantizer_compressed_tensors.py index 7c1b5d20..77eabb4e 100644 --- a/relax/backends/megatron/weight_conversion/processors/quantizer_compressed_tensors.py +++ b/relax/backends/megatron/weight_conversion/processors/quantizer_compressed_tensors.py @@ -5,6 +5,8 @@ import torch import torch.nn as nn +from relax.utils import device as device_utils + try: import fake_int4_quant_cuda @@ -91,7 +93,7 @@ def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, ze awq_linear.bias = linear.bias.clone().half() pack_num = 32 // awq_linear.w_bit - device = torch.device(f"cuda:{torch.cuda.current_device()}") + device = device_utils.make_current_torch_device() repeat_scales = scales.to(device).t().repeat_interleave(group_size, 1) if isinstance(zeros, torch.Tensor): @@ -284,7 +286,7 @@ def quantize_params_compressed_tensors(converted_named_params, quantization_conf qw, s, zp = pack_layer(param, group_size, is_symmetric) qweight_name = name.replace(".weight", ".weight_packed") scale_name = name.replace(".weight", ".weight_scale") - weight_shape = torch.tensor(param.shape, dtype=torch.int32, device="cuda") + weight_shape = torch.tensor(param.shape, dtype=torch.int32, device=device_utils.get_device_name()) weight_shape_name = name.replace(".weight", ".weight_shape") if zp is not None: zp_name = name.replace(".weight", ".weight_zero_point") diff --git a/relax/backends/megatron/weight_update/common.py b/relax/backends/megatron/weight_update/common.py index 1be3e90f..e9e2dde3 100644 --- a/relax/backends/megatron/weight_update/common.py +++ b/relax/backends/megatron/weight_update/common.py @@ -100,7 +100,12 @@ def all_gather_param(args, name: str, param: torch.nn.Parameter) -> torch.Tensor param_partitions = [torch.empty_like(param.data) for _ in range(tp_size)] dist.all_gather(param_partitions, param.data, group=tp_group) partition_dim = param.partition_dim - assert param.partition_stride == 1, "partition_stride != 1 is not supported" + # NOTE: Megatron-LM (megatron/core/transformer/mlp.py) now explicitly sets partition_stride=2 + # for GLU/SwiGLU linear_fc1 layers to indicate interleaved [gate, up] TP layout. + # The rechunk logic below (chunk(2) + reorder) already handles this correctly, + # so we only assert stride==1 for non-GLU parameters. + if "linear_fc1" not in name: + assert param.partition_stride == 1, f"{param.partition_stride=} != 1 is not supported" # TODO: here we did an extra copy during concat, maybe merge this with convert_to_hf is better? # TODO: check only GLU is used. if "linear_fc1.weight" in name and "vision_model" not in name: @@ -165,7 +170,7 @@ def all_gather_params_async( param = direct_param else: # Process the gathered partitions (same logic as original all_gather_param) - assert partition_dim is not None, "partition_stride != 1 is not supported" + assert partition_dim is not None, "partition_dim must be set for TP-sharded params" # TODO: here we did an extra copy during concat, maybe merge this with convert_to_hf is better? # TODO: check only GLU is used. if "linear_fc1.weight" in info.name and "vision_model" not in info.name: diff --git a/relax/backends/megatron/weight_update/hf_weight_iterator_bridge.py b/relax/backends/megatron/weight_update/hf_weight_iterator_bridge.py index 831718f9..8cb1550e 100644 --- a/relax/backends/megatron/weight_update/hf_weight_iterator_bridge.py +++ b/relax/backends/megatron/weight_update/hf_weight_iterator_bridge.py @@ -22,7 +22,6 @@ def __init__(self, *args, **kwargs): self._bridge = AutoBridge.from_hf_pretrained(self.args.hf_checkpoint, trust_remote_code=True) def get_hf_weight_chunks(self, megatron_local_weights): - # TODO support quantization (e.g. modify megatron-bridge to provide megatron param name) renamed_megatron_local_weights = {strip_param_name_prefix(k): v for k, v in megatron_local_weights.items()} with megatron_bridge_utils.patch_megatron_model(self.model): conversion_tasks = self._bridge.get_conversion_tasks(self.model) @@ -31,7 +30,34 @@ def get_hf_weight_chunks(self, megatron_local_weights): named_weights = self._bridge.export_hf_weights(self.model, cpu=False, conversion_tasks=conversion_tasks) def iter_quantized_named_weights(): - for hf_param_name, weight, megatron_param_name in named_weights: + hf_to_megatron_mapping = None + + for item in named_weights: + # Compatibility shim: old megatron-bridge yields 3-tuples + # ``(hf_param_name, weight, megatron_param_name)`` while + # the official bridge yields 2-tuples ``(hf_param_name, weight)``. + # Dispatch per-item so the same code path supports both. + if len(item) == 3: + hf_param_name, weight, megatron_param_name = item + elif len(item) == 2: + hf_param_name, weight = item + if hf_to_megatron_mapping is None: + hf_to_megatron_mapping = _build_hf_to_megatron_mapping(conversion_tasks) + # With PP > 1, export_hf_weights yields params from ALL + # PP ranks (via internal PP broadcast), but + # hf_to_megatron_mapping only contains params from this + # rank's conversion tasks. For remote PP rank params + # we fall back to hf_param_name — this is safe because + # remove_padding checks megatron-style names and + # quantize_params_fp8 regex won't match HF-style names. + megatron_param_name = hf_to_megatron_mapping.get(hf_param_name, hf_param_name) + else: + raise ValueError( + f"Unexpected named_weights tuple length {len(item)} from " + f"megatron-bridge.export_hf_weights(); expected 2 (new) or 3 (old). " + f"Item: {item!r}" + ) + processed_weight = postprocess_hf_param( args=self.args, megatron_param_name=megatron_param_name, @@ -56,7 +82,61 @@ def iter_quantized_named_weights(): ) +def _build_hf_to_megatron_mapping(conversion_tasks): + """Build a mapping from HF parameter names to megatron parameter names. + + Only relevant for the official megatron-bridge whose ``export_hf_weights`` + yields 2-tuples ``(hf_name, weight)`` and no longer carries the megatron + name in the tuple. We reconstruct the mapping by reading + ``task.mapping.hf_param`` — a pure metadata attribute that requires NO + collective communication. This is critical for PP > 1 where different + ranks hold different parameter subsets; calling ``megatron_to_hf()`` (which + contains PP broadcast / TP gather) with inconsistent tasks across ranks + would deadlock. + + ``mapping.hf_param`` is either: + - ``str``: simple 1-to-1 mappings (AutoMapping, DirectMapping, …) + - ``dict``: multi-output mappings (QKVMapping ``{"q","k","v"}``, + GatedMLPMapping ``{"gate","up"}``) + + This mirrors the approach shown in the official ``get_conversion_tasks`` + docstring of megatron-bridge's ``AutoBridge``. + + Note: with PP > 1, each rank only holds a subset of conversion tasks, so + the returned mapping is **incomplete** — it covers only the params that + belong to this PP rank. ``export_hf_weights`` yields params from ALL PP + ranks (via internal PP broadcast), so callers must handle missing keys + gracefully (e.g. fall back to the HF param name). + """ + hf_to_megatron_mapping = {} + + for task in conversion_tasks: + megatron_param_name = task.param_name + hf_param = task.mapping.hf_param + + if isinstance(hf_param, str): + hf_to_megatron_mapping[hf_param] = megatron_param_name + elif isinstance(hf_param, dict): + for hf_name in hf_param.values(): + hf_to_megatron_mapping[hf_name] = megatron_param_name + else: + raise TypeError( + f"Unexpected mapping.hf_param type {type(hf_param).__name__} " + f"for megatron param '{megatron_param_name}': {hf_param!r}" + ) + + return hf_to_megatron_mapping + + def _process_conversion_tasks(vanilla_conversion_tasks, new_weight_dict): + """Replace param_weight in each conversion task with the latest trained + weights. + + build_conversion_tasks() returns ``List[None | WeightConversionTask]`` + where None entries correspond to global params that have no mapping. We + filter them out here so that downstream consumers never see None. + """ + def _handle_one(task): if task.param_weight is None: return task @@ -70,7 +150,9 @@ def _handle_one(task): new_param_weight = new_param_weight.cuda() return dataclasses.replace(task, param_weight=new_param_weight) - return _MapWithLen(_handle_one, vanilla_conversion_tasks) + # Filter out None tasks (params with no mapping in build_conversion_tasks) + valid_tasks = [t for t in vanilla_conversion_tasks if t is not None] + return _MapWithLen(_handle_one, valid_tasks) class _MapWithLen: diff --git a/relax/backends/megatron/weight_update/hf_weight_iterator_direct.py b/relax/backends/megatron/weight_update/hf_weight_iterator_direct.py index 02da4654..32a382fb 100644 --- a/relax/backends/megatron/weight_update/hf_weight_iterator_direct.py +++ b/relax/backends/megatron/weight_update/hf_weight_iterator_direct.py @@ -7,6 +7,7 @@ from megatron.core import mpu from tqdm import tqdm +from relax.utils import device as device_utils from relax.utils.distributed_utils import get_gloo_group from relax.utils.types import ParamInfo @@ -55,13 +56,15 @@ def _get_megatron_full_params( if dist.get_rank() == info.src_rank: params.append( torch.nn.Parameter( - megatron_local_weights[info.name].to(device=torch.cuda.current_device(), non_blocking=True), + megatron_local_weights[info.name].to( + device=device_utils.make_current_torch_device(), non_blocking=True + ), requires_grad=False, ) ) else: - params.append(torch.empty(info.shape, dtype=info.dtype, device=torch.cuda.current_device())) - torch.cuda.synchronize() + params.append(torch.empty(info.shape, dtype=info.dtype, device=device_utils.make_current_torch_device())) + device_utils.synchronize() # broadcast params across pp ranks if pp_size > 1: diff --git a/relax/backends/megatron/weight_update/update_weight_from_distributed.py b/relax/backends/megatron/weight_update/update_weight_from_distributed.py index 2563c252..a255db01 100644 --- a/relax/backends/megatron/weight_update/update_weight_from_distributed.py +++ b/relax/backends/megatron/weight_update/update_weight_from_distributed.py @@ -12,12 +12,17 @@ from ray.actor import ActorHandle from tqdm import tqdm +from relax.utils import device as device_utils from relax.utils.distributed_utils import get_gloo_group, init_process_group +from relax.utils.logging_utils import get_logger from ..weight_conversion import convert_to_hf from .common import all_gather_param, named_params_and_buffers +logger = get_logger(__name__) + + class UpdateWeightFromDistributed: """Update distributed engines via NCCL. @@ -210,7 +215,7 @@ def _update_expert_bucket_weights_from_distributed( handles = [] for i, (_name, param) in enumerate(named_tensors): params = [ - torch.empty_like(param.data, device=torch.cuda.current_device()) + torch.empty_like(param.data, device=device_utils.make_current_torch_device()) for _ in range(mpu.get_expert_model_parallel_world_size()) ] handle = dist.all_gather(params, param.data, group=mpu.get_expert_model_parallel_group(), async_op=True) @@ -261,6 +266,7 @@ def connect_rollout_engines_from_distributed( group_name: str, rollout_engines: Sequence[ActorHandle], engine_gpu_counts: Sequence[int] | None = None, + max_retries: int = 3, ) -> dist.ProcessGroup: """Create NCCL group: training rank 0 + all engine GPUs. Blocks until joined. @@ -273,37 +279,55 @@ def connect_rollout_engines_from_distributed( engine_gpu_counts = [args.rollout_num_gpus_per_engine] * len(rollout_engines) master_address = ray._private.services.get_node_ip_address() - with socket.socket() as sock: - sock.bind(("", 0)) - master_port = sock.getsockname()[1] world_size = sum(engine_gpu_counts) + 1 # +1 for training rank 0 - # Compute cumulative rank offsets: engine i starts at cumulative[i] + 1. cumulative = [0] for c in engine_gpu_counts: cumulative.append(cumulative[-1] + c) - refs = [ - engine.init_weights_update_group.remote( - master_address, - master_port, - cumulative[i] + 1, - world_size, - group_name, - backend="nccl", - ) - for i, engine in enumerate(rollout_engines) - ] - model_update_groups = init_process_group( - backend="nccl", - init_method=f"tcp://{master_address}:{master_port}", - world_size=world_size, - rank=0, - group_name=group_name, - timeout=timedelta(minutes=args.distributed_timeout_minutes), - ) - ray.get(refs) - return model_update_groups + last_error = None + dist_backend = device_utils.get_dist_backend() + for attempt in range(1, max_retries + 1): + with socket.socket() as sock: + sock.bind(("", 0)) + master_port = sock.getsockname()[1] + + refs = [ + engine.init_weights_update_group.remote( + master_address, + master_port, + cumulative[i] + 1, + world_size, + group_name, + backend=dist_backend, + ) + for i, engine in enumerate(rollout_engines) + ] + try: + model_update_groups = init_process_group( + backend=dist_backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=0, + group_name=group_name, + timeout=timedelta(minutes=args.distributed_timeout_minutes), + ) + ray.get(refs) + return model_update_groups + except Exception as e: + last_error = e + logger.warning( + f"Failed to connect rollout engines (attempt {attempt}/{max_retries}, port={master_port}): {e}", + exc_info=(attempt == max_retries), + ) + try: + ray.get(refs, timeout=5) + except Exception: + pass + if attempt < max_retries: + time.sleep(5.0 * attempt) + + raise RuntimeError(f"Failed to connect rollout engines after {max_retries} attempts") from last_error def disconnect_rollout_engines_from_distributed(args, group_name, model_update_groups, rollout_engines): @@ -311,6 +335,8 @@ def disconnect_rollout_engines_from_distributed(args, group_name, model_update_g refs = [engine.destroy_weights_update_group.remote(group_name) for engine in rollout_engines] dist.destroy_process_group(model_update_groups) ray.get(refs) + # Wait for NCCL socket ports to be released by the OS + time.sleep(2.0) def update_weights_from_distributed( diff --git a/relax/backends/sglang/arguments.py b/relax/backends/sglang/arguments.py index 18dacc9e..76ea6d48 100644 --- a/relax/backends/sglang/arguments.py +++ b/relax/backends/sglang/arguments.py @@ -105,6 +105,86 @@ def add_sglang_arguments(parser): parser.set_defaults(router_balance_abs_threshold=10, router_balance_rel_threshold=1.2) parser.add_argument("--sglang-server-concurrency", type=int, default=512) + # SGLang profiling arguments — triggers /start_profile and /stop_profile HTTP API + # on all SGLang engines during rollout inference. + # Can also be used standalone via: python tools/profile_rollout.py + parser.add_argument( + "--sglang-profile", + action="store_true", + default=False, + help="Enable torch profiling on SGLang engines during rollout. Profile traces will be saved per rollout step.", + ) + parser.add_argument( + "--sglang-profile-output-dir", + type=str, + default=None, + help=("Output directory for SGLang profile traces. Defaults to traces//sglang_trace."), + ) + parser.add_argument( + "--sglang-profile-num-steps", + type=int, + default=3, + help="Number of SGLang forward steps to profile per rollout. " + "If -1, profiles the entire rollout step until stop_profile is called.", + ) + parser.add_argument( + "--sglang-profile-activities", + type=str, + nargs="+", + default=["CPU", "GPU"], + help="Activities to profile (e.g., CPU GPU).", + ) + parser.add_argument( + "--sglang-profile-by-stage", + action="store_true", + default=False, + help="Profile by stage (prefill/decode) separately.", + ) + parser.add_argument( + "--sglang-profile-with-stack", + action="store_true", + default=False, + help="Record call stack in profile traces.", + ) + parser.add_argument( + "--sglang-profile-record-shapes", + action="store_true", + default=False, + help="Record tensor shapes in profile traces.", + ) + parser.add_argument( + "--sglang-profile-steps", + type=int, + nargs="+", + default=None, + help=( + "List of absolute rollout step IDs (0-indexed) at which to enable SGLang profiling. " + "Takes precedence over --sglang-profile-step-start/end when set. " + "Example: --sglang-profile-steps 3 10 50" + ), + ) + parser.add_argument( + "--sglang-profile-step-start", + type=int, + default=None, + help=( + "Start of the rollout step range for SGLang profiling (inclusive, 0-indexed). " + "Used together with --sglang-profile-step-end to specify a contiguous range. " + "Ignored if --sglang-profile-steps is set." + ), + ) + parser.add_argument( + "--sglang-profile-step-end", + type=int, + default=None, + help=( + "End of the rollout step range for SGLang profiling (inclusive, 0-indexed). " + "Used together with --sglang-profile-step-start to specify a contiguous range. " + "Ignored if --sglang-profile-steps is set. " + "Example: --sglang-profile-step-start 2 --sglang-profile-step-end 4 profiles steps 2, 3, 4." + ), + ) + old_add_argument = parser.add_argument skipped_args = [ diff --git a/relax/backends/sglang/routing_replay_patch.py b/relax/backends/sglang/routing_replay_patch.py index 651e7f54..7d515565 100644 --- a/relax/backends/sglang/routing_replay_patch.py +++ b/relax/backends/sglang/routing_replay_patch.py @@ -52,6 +52,8 @@ import torch +from relax.utils import device as device_utils + logger = logging.getLogger(__name__) @@ -95,8 +97,8 @@ def _patched_init(self, *args, **kwargs): self._pinned_loc = torch.zeros(max_batch, dtype=torch.int64, device="cpu", pin_memory=True) # Dedicated copy stream + event. - self._copy_stream = torch.cuda.Stream(device=dev_buf.device) - self._copy_event = torch.cuda.Event() + self._copy_stream = device_utils.Stream(device=dev_buf.device) + self._copy_event = device_utils.Event() # Pending scatter state. self._pending_n = 0 # 0 means nothing pending @@ -142,7 +144,7 @@ def _patched_sync(self, forward_batch, can_run_graph, cuda_graph_batch): # In overlap-scheduler mode this is the *forward_stream*; without # overlap it is the default stream. We need this reference so that # copy_stream can order itself after the GPU→GPU staging copy below. - active_stream = torch.cuda.current_stream(self.device_cache.buffer.device) + active_stream = device_utils.current_stream(self.device_cache.buffer.device) # 1) GPU→GPU snapshot on the active stream — fast, no sync. self._staging_buffer[:n_tok].copy_(self.device_cache.buffer[local_start_pos:local_end_pos]) @@ -150,7 +152,7 @@ def _patched_sync(self, forward_batch, can_run_graph, cuda_graph_batch): # 2) On copy stream: async copies to pinned CPU buffers. # copy_stream waits on active_stream so the staging snapshot # above completes before we start reading it. - with torch.cuda.stream(self._copy_stream): + with device_utils.stream_context(self._copy_stream): self._copy_stream.wait_stream(active_stream) # 2a) Routing data: staging[:n_tok, :, :topk] → pinned_staging self._pinned_staging[:n_tok, :, :topk].copy_(self._staging_buffer[:n_tok, :, :topk], non_blocking=True) diff --git a/relax/backends/sglang/sglang_engine.py b/relax/backends/sglang/sglang_engine.py index e4a15be6..b23a2b71 100644 --- a/relax/backends/sglang/sglang_engine.py +++ b/relax/backends/sglang/sglang_engine.py @@ -20,6 +20,7 @@ from relax.distributed.checkpoint_service.client.engine import create_client from relax.distributed.ray.ray_actor import RayActor +from relax.utils import device as device_utils from relax.utils.async_utils import run from relax.utils.device_utils import to_local_visible_device_index from relax.utils.http_utils import get_host_info @@ -43,7 +44,22 @@ def get_base_gpu_id(args, rank): def _to_local_gpu_id(physical_gpu_id: int) -> int: - return to_local_visible_device_index(physical_gpu_id) + visible_env = device_utils.get_visible_devices_env_var() + cvd = os.environ.get(visible_env) + if not cvd: + return physical_gpu_id # no remapping + # Visible devices can be like "4,5,6,7" + visible = [int(x) for x in cvd.split(",") if x.strip() != ""] + # In a remapped process, valid torch device indices are 0..len(visible)-1 + if physical_gpu_id in visible: + return visible.index(physical_gpu_id) + # If we're already getting local IDs, allow them + if 0 <= physical_gpu_id < len(visible): + return physical_gpu_id + raise RuntimeError( + f"Device id {physical_gpu_id} is not valid under {visible_env}={cvd}. " + f"Expected one of {visible} (physical) or 0..{len(visible) - 1} (local)." + ) def _patched_run_scheduler_process(*args, **kwargs): @@ -721,40 +737,6 @@ def post_process_weights( }, ) - def start_profile( - self, - # The output directory - output_dir: str | None = None, - # If set, it profile as many as this number of steps. - # If it is set, profiling is automatically stopped after this step, and - # the caller doesn't need to run stop_profile. - start_step: int | None = None, - num_steps: int | None = None, - activities: list[str] | None = None, - profile_by_stage: bool = False, - with_stack: bool | None = None, - record_shapes: bool | None = None, - ): - response = requests.post( - f"http://{self.server_host}:{self.server_port}/start_profile", - json={ - "output_dir": output_dir, - "start_step": start_step, - "num_steps": num_steps, - "activities": activities, - "profile_by_stage": profile_by_stage, - "with_stack": with_stack, - "record_shapes": record_shapes, - }, - ) - response.raise_for_status() - return response - - def stop_profile(self): - response = requests.post(f"http://{self.server_host}:{self.server_port}/stop_profile", json={}) - response.raise_for_status() - return response - def simulate_crash(self): if self.args.rollout_external or not getattr(self, "process", None): logger.info( diff --git a/relax/core/controller.py b/relax/core/controller.py index 0c9b5f68..2e608867 100644 --- a/relax/core/controller.py +++ b/relax/core/controller.py @@ -16,6 +16,7 @@ from relax.core.registry import ALGOS, ROLES, process_role from relax.core.service import Service, create_placement_group from relax.distributed.checkpoint_service.coordinator.service import create_dcs_deployment +from relax.utils import device as device_utils from relax.utils.async_utils import run, shutdown_async_loop from relax.utils.health_system import HealthManager from relax.utils.logging_utils import get_logger @@ -238,7 +239,8 @@ def _validate_gpu_resources(self, roles_to_create, colocate, actor_rollout_pg_ro total_required = sum(num_gpus for _, _, num_gpus, _ in roles_to_create) cluster_resources = ray.cluster_resources() - total_available = int(cluster_resources.get("GPU", 0)) + accel_resource = device_utils.get_ray_accelerator_name() + total_available = int(cluster_resources.get(accel_resource, 0)) logger.info( f"Resource validation: required GPUs={total_required}, cluster GPUs={total_available}, colocate={colocate}" diff --git a/relax/core/service.py b/relax/core/service.py index 52ae3eba..eb9441c2 100644 --- a/relax/core/service.py +++ b/relax/core/service.py @@ -12,6 +12,7 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from relax.distributed.ray.placement_group import InfoActor, sort_key +from relax.utils import device as device_utils from relax.utils.logging_utils import get_logger from relax.utils.utils import get_serve_url, recovery_load_path @@ -295,7 +296,8 @@ def _ensure_placement_group(self) -> Optional[Any]: def create_placement_group(num_gpus): """Create a placement group with the specified number of GPUs.""" - bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_gpus)] + accel_resource = device_utils.get_ray_accelerator_name() + bundles = [{accel_resource: 1, "CPU": 1} for _ in range(num_gpus)] pg = placement_group(bundles, strategy="PACK") num_bundles = len(bundles) ray.get(pg.ready()) diff --git a/relax/distributed/checkpoint_service/backends/device_direct.py b/relax/distributed/checkpoint_service/backends/device_direct.py index 7ac0717c..1f3e8cba 100644 --- a/relax/distributed/checkpoint_service/backends/device_direct.py +++ b/relax/distributed/checkpoint_service/backends/device_direct.py @@ -15,6 +15,7 @@ import asyncio import logging +import os import re import socket import time @@ -39,6 +40,7 @@ from relax.distributed.checkpoint_service.backends.base import CommBackend, TensorFusion from relax.distributed.checkpoint_service.config import BackendType, RoleInfo from relax.distributed.checkpoint_service.utils import load_weight +from relax.utils import device as device_utils from relax.utils.distributed_utils import get_gloo_group, init_process_group from relax.utils.external.megatron_bridge_compat import ensure_megatron_bridge_importable from relax.utils.logging_utils import get_logger @@ -99,7 +101,7 @@ def __init__( self.coordinator_url = coordinator_url self.lock = lock self.timeout_seconds = timeout_seconds - self.device = next(model[0].parameters()).device if model else torch.cuda.current_device() + self.device = next(model[0].parameters()).device if model else device_utils.current_device() self._comm_stream: Optional[Any] = None # CUDA stream self._thread_pool = ThreadPoolExecutor(max_workers=4) @@ -107,6 +109,9 @@ def __init__( # For recv, we need to know tensor shapes in advance or use a metadata channel self._pending_recvs: Dict[str, asyncio.Future] = {} + self._group_name: Optional[str] = None + self._rollout_group_generation = 0 + self._rollout_topology_signature: Optional[tuple[tuple[int, str, int, int], ...]] = None self._model_update_groups = None self._model_update_groups_for_actor_fwd_ref = None @@ -115,12 +120,13 @@ def __init__( # Ray actors for rollout communication self.rollout_engines: Dict[int, Any] = {} # rank -> Ray actor handle - torch.cuda.set_device(self.device) + device_utils.set_device(self.device) # Bridge-based HF weight converter (lazy-initialized on first use) self._use_bridge = getattr(args, "megatron_to_hf_mode", None) == "bridge" self._bridge_task_map: Optional[Dict[str, Any]] = None # global_param_name -> WeightConversionTask self._bridge_mapping_registry = None # MegatronMappingRegistry for dynamic lookups + self._bridge_expert_transposes_down: bool = True # set in _init_bridge_tasks def _init_bridge_tasks(self) -> None: """Lazily initialize Bridge conversion tasks and build a lookup table. @@ -205,6 +211,16 @@ def _init_bridge_tasks(self) -> None: inner_tp._detected_type = inner_tp._detect_parallelism_type(task.megatron_module) inner_tp._mapping = inner_tp._get_or_create_mapping(inner_tp._detected_type) + # Detect whether the Bridge's ExpertMLPDownProjMapping applies a + # transpose in megatron_to_hf (Qwen3-VL does, Qwen3.5 does not). + # Used by _convert_to_hf_bridge to decide whether to undo the transpose. + self._bridge_expert_transposes_down = False + for task in self._bridge_task_map.values(): + cls = type(task.mapping) + if cls.__name__ == "ExpertMLPDownProjMapping": + self._bridge_expert_transposes_down = "megatron_to_hf" in cls.__dict__ + break + logger.info(f"Bridge task map initialized with {len(self._bridge_task_map)} local tasks") @staticmethod @@ -392,16 +408,21 @@ def _noop_gather_from_ep_ranks(self_m, megatron_weights, megatron_module, hf_par # ── Post-process expert weights ────────────────────────────────── # Bridge's ExpertMLPGateUpProjMapping and ExpertMLPDownProjMapping - # (used by Qwen3-VL MoE) apply an extra ``.transpose(-1, -2)`` in - # their ``megatron_to_hf`` methods, assuming Megatron stores expert - # weights in column-major order. However, the raw ``convert_to_hf`` - # does NOT transpose expert weights — Megatron's expert weights are - # already in the same layout as HF. We must undo Bridge's transpose - # to match the format that SGLang / ``convert_to_hf`` expects. + # apply transformations that differ by model family: + # + # **Qwen3-VL** (qwen3_vl_bridge.py): + # gate_up_proj: transpose each half then stack → [2, D_out, D_in] + # down_proj: transpose → [D_in, D_out] + # We must undo the transpose. + # + # **Qwen3.5** (qwen35_vl_bridge.py): + # gate_up_proj: cat without transpose → [2*H, D] (2-D) + # down_proj: no transpose (AutoMapping) → [H, D] (2-D) + # No un-transpose needed; just split the fused tensor. # # Additionally, Bridge outputs fused names without expert_id: - # - ``...experts.gate_up_proj`` with shape [2, D_out, D_in] - # - ``...experts.down_proj`` with shape [D_in, D_out] + # - ``...experts.gate_up_proj`` + # - ``...experts.down_proj`` # We split into per-expert format with correct names and shapes: # - ``...experts.{E}.gate_proj.weight`` [H, D] # - ``...experts.{E}.up_proj.weight`` [H, D] @@ -412,19 +433,28 @@ def _noop_gather_from_ep_ranks(self_m, megatron_weights, megatron_module, hf_par postprocessed: list[tuple[str, torch.Tensor]] = [] for hf_name, tensor in converted_named_tensors: if hf_name.endswith(".experts.gate_up_proj"): - # Bridge output: [2, D_out, D_in] (transposed by Bridge) - # Undo transpose on each slice: [D_out, D_in] -> [D_in, D_out] - gate_tensor = tensor[0].transpose(-1, -2).contiguous() - up_tensor = tensor[1].transpose(-1, -2).contiguous() base = hf_name[: -len(".gate_up_proj")] + if tensor.ndim == 3: + # Qwen3-VL style: [2, D_out, D_in] (transposed by Bridge) + # Undo transpose on each slice: [D_out, D_in] -> [D_in, D_out] + gate_tensor = tensor[0].transpose(-1, -2).contiguous() + up_tensor = tensor[1].transpose(-1, -2).contiguous() + else: + # Qwen3.5 style: [2*H, D] (cat, no transpose by Bridge) + # Split along dim 0 into two [H, D] tensors + gate_tensor, up_tensor = tensor.chunk(2, dim=0) postprocessed.append((f"{base}.{expert_id}.gate_proj.weight", gate_tensor)) postprocessed.append((f"{base}.{expert_id}.up_proj.weight", up_tensor)) elif hf_name.endswith(".experts.down_proj"): - # Bridge output: transposed — undo to match raw convert_to_hf base = hf_name[: -len(".down_proj")] - postprocessed.append( - (f"{base}.{expert_id}.down_proj.weight", tensor.transpose(-1, -2).contiguous()) - ) + if tensor.ndim == 2 and not self._bridge_expert_transposes_down: + # Qwen3.5 style: AutoMapping, no transpose — already [H, D] + postprocessed.append((f"{base}.{expert_id}.down_proj.weight", tensor)) + else: + # Qwen3-VL style: transposed — undo to match raw convert_to_hf + postprocessed.append( + (f"{base}.{expert_id}.down_proj.weight", tensor.transpose(-1, -2).contiguous()) + ) else: postprocessed.append((hf_name, tensor)) converted_named_tensors = postprocessed @@ -542,6 +572,21 @@ def _update_rollout_engines(self): _MASTER_PORT_MIN = 11000 _MASTER_PORT_MAX = 11999 + def _get_rollout_topology_signature(self) -> tuple[tuple[int, str, int, int], ...]: + default_gpus = self.args.rollout_num_gpus_per_engine + signature = [] + for rank, role_info in sorted(self.rollout_topology.items(), key=lambda kv: int(kv[0])): + metadata = role_info.get("metadata") if isinstance(role_info, dict) else {} + signature.append( + ( + int(rank), + str(role_info.get("ip")), + int(role_info.get("port")), + int((metadata or {}).get("num_gpus_per_engine", default_gpus)), + ) + ) + return tuple(signature) + @staticmethod def _find_free_port_in_range(port_min: int, port_max: int) -> int: """Find a free port within [port_min, port_max] by attempting to bind. @@ -552,13 +597,23 @@ def _find_free_port_in_range(port_min: int, port_max: int) -> int: ports = list(range(port_min, port_max + 1)) random.shuffle(ports) + skipped_ports: list[tuple[int, int | None, str]] = [] for port in ports: try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind(("", port)) + logger.info( + "Selected rollout weight update master port %s from range [%s, %s] after skipping %s ports", + port, + port_min, + port_max, + len(skipped_ports), + ) + if skipped_ports: + logger.info("Skipped rollout weight update master ports sample: %s", skipped_ports[:10]) return port - except OSError: + except OSError as exc: + skipped_ports.append((port, exc.errno, exc.strerror or str(exc))) continue raise RuntimeError(f"No free port available in range [{port_min}, {port_max}]") @@ -574,27 +629,51 @@ def init_process_group_for_rollout(self, topology_data: Optional[Dict] = None) - if self._is_pp_src_rank: pp_rank = mpu.get_pipeline_model_parallel_rank() master_address = ray._private.services.get_node_ip_address() - self._group_name = f"slime-pp_{pp_rank}" + base_group_name = f"slime-pp_{pp_rank}" if topology_data is None: raise RuntimeError("topology_data is required for init_process_group_for_rollout") self.rollout_topology = topology_data.get("nodes", {}).get("rollout", {}) + rollout_topology_signature = self._get_rollout_topology_signature() self._create_rollout_engines(self.rollout_topology) self._update_rollout_engines() + reuse_enabled = os.getenv("RELAX_ROLLOUT_PG_REUSE", "1").lower() not in ("0", "false", "no") + if ( + reuse_enabled + and self._model_update_groups is not None + and self._rollout_topology_signature == rollout_topology_signature + ): + logger.info( + "Reusing rollout process group: group_name=%s, topology_signature=%s", + self._group_name, + rollout_topology_signature, + ) + return + if self._model_update_groups is not None: + old_group_name = self._group_name or base_group_name try: - logger.info("Destroying old process group...") - destroy_payload = {"group_name": self._group_name} + logger.info( + "Destroying old rollout process group: group_name=%s, local_group=%s, rollout_ranks=%s", + old_group_name, + self._model_update_groups, + sorted(self.rollout_engines), + ) + destroy_payload = {"group_name": old_group_name} futures = self._batch_request("/destroy_weights_update_group", destroy_payload) dist.destroy_process_group(self._model_update_groups) ray.get(futures) + logger.info("Destroyed old rollout process group: group_name=%s", old_group_name) self._model_update_groups = None + # Wait for NCCL socket ports to be released by the OS + time.sleep(2.0) except Exception as e: logger.warning(f"Error destroying old process group: {e}") self._model_update_groups = None + self._rollout_topology_signature = None default_gpus = self.args.rollout_num_gpus_per_engine cumulative_offset = 1 @@ -606,32 +685,59 @@ def init_process_group_for_rollout(self, topology_data: Optional[Dict] = None) - cumulative_offset += gpus_for_node world_size = cumulative_offset - master_port = self._find_free_port_in_range(self._MASTER_PORT_MIN, self._MASTER_PORT_MAX) - - # Prepare init payloads for each rollout node - init_payloads = {} - for rank, role_info in self.rollout_topology.items(): - init_payloads[int(rank)] = { - "master_address": master_address, - "master_port": master_port, - "rank_offset": rank_offsets[int(rank)], - "world_size": world_size, - "group_name": self._group_name, - "backend": self.backend_type, - } + max_retries = 3 + last_error = None + for attempt in range(1, max_retries + 1): + master_port = self._find_free_port_in_range(self._MASTER_PORT_MIN, self._MASTER_PORT_MAX) + + init_payloads = {} + for rank, role_info in self.rollout_topology.items(): + init_payloads[int(rank)] = { + "master_address": master_address, + "master_port": master_port, + "rank_offset": rank_offsets[int(rank)], + "world_size": world_size, + "group_name": self._group_name, + "backend": self.backend_type, + } - logger.info(f"Sending init_weights_update_group to {len(self.rollout_topology)} rollout nodes...") - futures = self._batch_request("/init_weights_update_group", init_payloads, get_rank=True) + logger.info( + f"Sending init_weights_update_group to {len(self.rollout_topology)} rollout nodes " + f"(attempt {attempt}/{max_retries}, port={master_port})..." + ) + futures = self._batch_request("/init_weights_update_group", init_payloads, get_rank=True) - self._model_update_groups = init_process_group( - backend=self.backend_type, - init_method=f"tcp://{master_address}:{master_port}", - world_size=world_size, - rank=0, - group_name=self._group_name, - timeout=timedelta(seconds=180), - ) - ray.get(futures) + try: + self._model_update_groups = init_process_group( + backend=self.backend_type, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=0, + group_name=self._group_name, + timeout=timedelta(seconds=180), + ) + ray.get(futures) + last_error = None + break + except Exception as e: + last_error = e + logger.warning( + f"Failed to init process group for rollout (attempt {attempt}/{max_retries}, " + f"port={master_port}): {e}", + exc_info=(attempt == max_retries), + ) + self._model_update_groups = None + try: + ray.get(futures, timeout=5) + except Exception: + pass + if attempt < max_retries: + time.sleep(5.0 * attempt) + + if last_error is not None: + raise RuntimeError( + f"Failed to init process group for rollout after {max_retries} attempts" + ) from last_error def init_process_groups_for_actor_fwd_ref(self, topology_data) -> None: """Initialize process groups used for actor -> actor_fwd weight sync. @@ -806,7 +912,7 @@ def update_weights_for_rollout(self, rollout_only=False, actor_fwd_only=False) - # allocator keeps large reserved blocks that are internally # fragmented, which can cause OOM when the optimizer later tries # to allocate contiguous Adam state buffers. - torch.cuda.empty_cache() + device_utils.empty_cache() def _update_weight_from_distributed( self, @@ -952,6 +1058,15 @@ def _update_bucket_weights_from_distributed( "weight_version": str(self.weight_version), "flush_cache": False, } + logger.info( + "Sending rollout distributed weight bucket: group_name=%s, weight_version=%s, tensors=%s, " + "first_tensor=%s, rollout_ranks=%s", + self._group_name, + self.weight_version, + len(converted_named_tensors), + converted_named_tensors[0][0] if converted_named_tensors else None, + sorted(self.rollout_engines), + ) # Send weight update to all rollout nodes via Ray actors futures = self._batch_request("/update_weights_from_distributed", weight_payload) @@ -962,6 +1077,12 @@ def _update_bucket_weights_from_distributed( for handle in handles: handle.wait() ray.get(futures) # Ensure remote update completes + logger.info( + "Completed rollout distributed weight bucket: group_name=%s, weight_version=%s, tensors=%s", + self._group_name, + self.weight_version, + len(converted_named_tensors), + ) ray.get(self.lock.release.remote()) if pbar is not None: diff --git a/relax/distributed/checkpoint_service/utils.py b/relax/distributed/checkpoint_service/utils.py index 59dc98b0..51d6acbb 100644 --- a/relax/distributed/checkpoint_service/utils.py +++ b/relax/distributed/checkpoint_service/utils.py @@ -87,8 +87,13 @@ def chunk_param( tp_rank = mpu.get_tensor_model_parallel_rank() # 4. Verify stride + # NOTE: Megatron-LM (megatron/core/transformer/mlp.py) sets partition_stride=2 + # for GLU/SwiGLU linear_fc1 layers. The rechunk logic below handles this correctly. partition_dim = target_param.partition_dim - assert getattr(target_param, "partition_stride", 1) == 1, "partition_stride != 1 is not supported" + if "linear_fc1" not in name: + assert getattr(target_param, "partition_stride", 1) == 1, ( + f"{name}: partition_stride={getattr(target_param, 'partition_stride', 1)} != 1 is not supported" + ) # 5. Workaround grouped MoE partition bug for linear_fc2.weight effective_partition_dim = partition_dim diff --git a/relax/distributed/ray/rollout.py b/relax/distributed/ray/rollout.py index 18f2505a..56901516 100644 --- a/relax/distributed/ray/rollout.py +++ b/relax/distributed/ray/rollout.py @@ -23,6 +23,7 @@ from relax.backends.sglang.sglang_engine import SGLangEngine from relax.engine.rollout.base_types import call_rollout_fn +from relax.utils import device as device_utils from relax.utils import tracking_utils from relax.utils.health_monitor import RolloutHealthMonitor from relax.utils.http_utils import SLIME_HOST_IP_ENV, _wrap_ipv6, find_available_port, get_host_info, init_http_client @@ -30,6 +31,7 @@ from relax.utils.metrics.metric_checker import MetricChecker from relax.utils.metrics.metric_utils import ( compute_pass_rate, + compute_rollout_explicit_reward_metrics, compute_rollout_step, compute_statistics, dict_add_prefix, @@ -1425,7 +1427,8 @@ async def _scale_out_ray_native(self, request: ScaleOutRequest) -> None: per_replica_pgs = [] for i in range(request.num_replicas): num_gpus = gpus_per_engine - bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_gpus)] + accel_resource = device_utils.get_ray_accelerator_name() + bundles = [{accel_resource: 1, "CPU": 1} for _ in range(num_gpus)] pg = ray.util.placement_group(bundles, strategy="PACK") per_replica_pgs.append(pg) @@ -2065,13 +2068,14 @@ async def _sync_single_engine_weights( ) try: + dist_backend = device_utils.get_dist_backend() init_seed_ref = seed_engine.init_weights_send_group_for_remote_instance.remote( master_address=master_address, ports=ports_str, group_rank=0, world_size=2, group_name=group_name, - backend="nccl", + backend=dist_backend, ) init_new_ref = new_engine.init_weights_send_group_for_remote_instance.remote( master_address=master_address, @@ -2079,7 +2083,7 @@ async def _sync_single_engine_weights( group_rank=1, world_size=2, group_name=group_name, - backend="nccl", + backend=dist_backend, ) init_results = await asyncio.wait_for( asyncio.gather(init_seed_ref, init_new_ref), @@ -3619,6 +3623,7 @@ def compute_metrics_from_samples(args, samples): log_dict = {} log_dict |= dict_add_prefix(compute_statistics(response_lengths), "response_len/") + log_dict |= compute_rollout_explicit_reward_metrics(args, samples) log_dict |= _compute_zero_std_metrics(args, samples) log_dict |= _compute_reward_cat_metrics(args, samples) log_dict["repetition_frac"] = np.mean([int(has_repetition(s.response)) for s in samples]).item() diff --git a/relax/distributed/ray/train_actor.py b/relax/distributed/ray/train_actor.py index ca6923d4..9689a28b 100644 --- a/relax/distributed/ray/train_actor.py +++ b/relax/distributed/ray/train_actor.py @@ -11,7 +11,7 @@ import relax.utils.training.eval_config from relax.distributed.ray.ray_actor import RayActor -from relax.utils.device_utils import get_visible_devices, to_local_visible_device_index +from relax.utils import device as device_utils from relax.utils.distributed_utils import init_gloo_group from relax.utils.logging_utils import get_logger from relax.utils.memory_utils import clear_memory, print_memory @@ -37,16 +37,17 @@ def _configure_visible_devices_for_current_actor() -> None: joined_ids = ",".join(assigned_gpu_ids) if torch.version.hip is not None: - os.environ["CUDA_VISIBLE_DEVICES"] = joined_ids + # On ROCm, Ray must not rewrite HIP_VISIBLE_DEVICES. Keep the full + # visible device list and select the assigned device with set_device(). + os.environ.pop("CUDA_VISIBLE_DEVICES", None) os.environ.pop("ROCR_VISIBLE_DEVICES", None) - os.environ.pop("HIP_VISIBLE_DEVICES", None) else: os.environ["CUDA_VISIBLE_DEVICES"] = joined_ids def get_local_gpu_id(): - visible_devices = get_visible_devices() - if not visible_devices: + cvd = os.environ.get(device_utils.get_visible_devices_env_var(), None) + if cvd is None: return ray.get_gpu_ids()[0] return to_local_visible_device_index(int(ray.get_gpu_ids()[0])) @@ -82,14 +83,7 @@ def init(self, args, role, with_ref=False, with_opd_teacher=False): torch.serialization.add_safe_globals([relax.utils.training.eval_config.EvalDatasetConfig]) local_rank = int(os.environ.get("LOCAL_RANK", 0)) - logger.info( - "Initializing TrainRayActor rank=%s local_rank=%s visible_devices=%s hip=%s", - self._rank, - local_rank, - get_visible_devices(), - torch.version.hip is not None, - ) - torch.cuda.set_device(local_rank) + device_utils.set_device(f"{device_utils.get_device_name()}:{local_rank}") backend = args.distributed_backend @@ -102,27 +96,8 @@ def init(self, args, role, with_ref=False, with_opd_teacher=False): args.rank = dist.get_rank() args.world_size = dist.get_world_size() - try: - if torch.version.hip is not None: - logger.info("Detected ROCm/HIP environment, skipping NUMA affinity setup") - # will find the coresponding API to implement ROCm version as below - else: - import pynvml - - pynvml.nvmlInit() - - local_rank = int(os.environ["RANK"]) % args.num_gpus_per_node - - handle = pynvml.nvmlDeviceGetHandleByIndex(local_rank) - pynvml.nvmlDeviceSetCpuAffinity(handle) - - logger.info(f"Set NUMA affinity for GPU {local_rank}") - pynvml.nvmlShutdown() - - except ImportError: - logger.info("Warning: pynvml not available, skipping NUMA affinity setup") - except Exception as e: - logger.info(f"Warning: Failed to set NUMA affinity: {e}") + numa_local_rank = int(os.environ["RANK"]) % args.num_gpus_per_node + device_utils.set_numa_affinity(numa_local_rank) def clear_memory(self): print_memory("before TrainRayActor.clear_memory") diff --git a/relax/distributed/ray/utils.py b/relax/distributed/ray/utils.py index 1918eaa1..697e7b50 100644 --- a/relax/distributed/ray/utils.py +++ b/relax/distributed/ray/utils.py @@ -1,10 +1,11 @@ +# Copyright (c) 2026 Relax Authors. All Rights Reserved. # Adapted from https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/trainer/ray/utils.py#L1 import os import ray -import torch from relax.distributed.ray.ray_actor import RayActor +from relax.utils import device as device_utils # Refer to @@ -17,6 +18,7 @@ # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/intel_gpu.py#L97-L98 NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [ "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", + "RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES", "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES", "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES", "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES", @@ -31,8 +33,8 @@ def ray_noset_visible_devices(env_vars=os.environ): def get_physical_gpu_id(): - device = torch.cuda.current_device() - props = torch.cuda.get_device_properties(device) + device = device_utils.current_device() + props = device_utils.get_device_properties(device) return str(props.uuid) diff --git a/relax/engine/rollout/data_source.py b/relax/engine/rollout/data_source.py index b5791348..784eb067 100644 --- a/relax/engine/rollout/data_source.py +++ b/relax/engine/rollout/data_source.py @@ -1,7 +1,6 @@ # Copyright (c) 2026 Relax Authors. All Rights Reserved. import abc -import copy import os from pathlib import Path @@ -18,6 +17,25 @@ logger = get_logger(__name__) +def _shallow_copy_sample(src: Sample) -> Sample: + """Create a lightweight copy of a Sample that *shares* heavy read-only + payloads (``multimodal_inputs``) with the source.""" + new = Sample.__new__(Sample) + new.__dict__.update(src.__dict__) + # Shallow-copy mutable containers that downstream code mutates in-place. + new.tokens = list(src.tokens) + new.rollout_tokens = list(src.rollout_tokens) + new.weight_versions = list(src.weight_versions) + new.metadata = dict(src.metadata) + # Per-sample accumulators — create fresh instances. + new.spec_info = Sample.SpecInfo() + new.prefix_cache_info = Sample.PrefixCacheInfo() + # ``multimodal_inputs`` is read-only downstream — share the reference. + # ``multimodal_train_inputs`` is *set* (not mutated) per-sample by the + # processor, so sharing the initial ``None`` is fine. + return new + + def _create_dataset(args, tokenizer, processor, multimodal_config=None): """Factory function to create dataset based on configuration. @@ -39,8 +57,15 @@ def _create_dataset(args, tokenizer, processor, multimodal_config=None): from relax.utils.data.streaming_dataset import StreamingDataset buffer_size = getattr(args, "streaming_buffer_size", 10000) - - logger.info(f"Using StreamingDataset with buffer_size={buffer_size}") + prefetch_chunk_size = getattr(args, "prefetch_chunk_size", 32) + prefetch_max_cached = getattr(args, "prefetch_max_cached", 256) + prefetch_num_workers = getattr(args, "prefetch_num_workers", 1) + + logger.info( + f"Using StreamingDataset with buffer_size={buffer_size}, " + f"prefetch_chunk_size={prefetch_chunk_size}, prefetch_max_cached={prefetch_max_cached}, " + f"prefetch_num_workers={prefetch_num_workers}" + ) return StreamingDataset( path=args.prompt_data, tokenizer=tokenizer, @@ -57,6 +82,9 @@ def _create_dataset(args, tokenizer, processor, multimodal_config=None): use_audio_in_video=args.use_audio_in_video, seed=args.rollout_seed, buffer_size=buffer_size, + prefetch_chunk_size=prefetch_chunk_size, + prefetch_max_cached=prefetch_max_cached, + prefetch_num_workers=prefetch_num_workers, multimodal_config=multimodal_config, ) else: @@ -172,7 +200,7 @@ def get_samples(self, num_samples): for prompt_sample in prompt_samples: group = [] for _ in range(self.args.n_samples_per_prompt): - sample = copy.deepcopy(prompt_sample) + sample = _shallow_copy_sample(prompt_sample) sample.group_index = self.sample_group_index sample.index = self.sample_index self.sample_index += 1 diff --git a/relax/engine/rollout/sglang_rollout.py b/relax/engine/rollout/sglang_rollout.py index 7e0d30b6..567ef9b7 100644 --- a/relax/engine/rollout/sglang_rollout.py +++ b/relax/engine/rollout/sglang_rollout.py @@ -36,6 +36,7 @@ from relax.utils.http_utils import get, post from relax.utils.logging_utils import get_logger from relax.utils.misc import SingletonMeta, load_function +from relax.utils.profile_utils import start_sglang_profile, stop_sglang_profile from relax.utils.timer import Timer from relax.utils.training.eval_config import EvalDatasetConfig from relax.utils.training.train_dump_utils import save_debug_rollout_data @@ -117,12 +118,16 @@ def reset(self) -> None: ) # tasks that should not be aborted (abort_count >= partial_rollout_max_aborted_count) self.aborted = False self.evaluating = getattr(self, "evaluating", 0) # preserve eval state across resets + # Pre-fetched data ObjectRef for cross-step overlap. + # Persisted across reset() calls so the ref submitted at the end of + # step N is consumed at the beginning of step N+1. + if not hasattr(self, "prefetched_samples_ref"): + self.prefetched_samples_ref: ray.ObjectRef | None = None def submit_generate_tasks(self, samples: list[list[Sample]]) -> None: max_aborted_count = getattr(self.args, "partial_rollout_max_aborted_count", None) for group in samples: task = asyncio.create_task( - # submit a group of samples as a single task. generate_and_rm_group( self.args, group, @@ -260,7 +265,16 @@ async def generate( _t_mm_encode: float | None = None if sample.multimodal_inputs: - encoded_mm, _t_mm_encode = await _encode_multimodal_inputs(sample.multimodal_inputs) + # Use pre-encoded data from group-level de-dup if available; otherwise encode inline. + pre_encoded = getattr(sample, "_pre_encoded_mm", None) + if pre_encoded is not None: + encoded_mm = pre_encoded + _t_mm_encode = getattr(sample, "_pre_encoded_mm_elapsed", 0.0) + del sample._pre_encoded_mm + if hasattr(sample, "_pre_encoded_mm_elapsed"): + del sample._pre_encoded_mm_elapsed + else: + encoded_mm, _t_mm_encode = await _encode_multimodal_inputs(sample.multimodal_inputs) payload.update(encoded_mm) # Use existing tokens for multi-turn or tokenize the new prompt @@ -486,6 +500,17 @@ async def generate_and_rm_group( if sample.session_id is None: sample.session_id = str(uuid.uuid4()) + # Group-level multimodal encoding de-duplication: when samples in the same + # group share the same multimodal_inputs object (e.g. after shallow-copy in + # data_source), encode once and attach the result to every sample so that + # generate() picks up the pre-encoded data instead of re-encoding per sample. + first_mm = getattr(group[0], "multimodal_inputs", None) + if first_mm is not None and all(getattr(s, "multimodal_inputs", None) is first_mm for s in group[1:]): + encoded_mm, t_enc = await _encode_multimodal_inputs(first_mm) + for sample in group: + sample._pre_encoded_mm = encoded_mm + sample._pre_encoded_mm_elapsed = t_enc + tasks = [] for idx, sample in enumerate(group): current_sampling_params = sampling_params.copy() @@ -616,6 +641,9 @@ async def generate_rollout_async( state = GenerateState(args) + # Start SGLang profiling if enabled + await start_sglang_profile(args, rollout_id) + # instantiate data filters dynamic_filter = ( load_function(args.dynamic_sampling_filter_path) if args.dynamic_sampling_filter_path is not None else None @@ -636,10 +664,21 @@ async def generate_rollout_async( total_transfer_samples = 0 get_samples_times: list[float] = [] + loop = asyncio.get_running_loop() + while len(data) < target_data_size: while state.remaining_batch_size < target_data_size: _t_get_samples = monotonic() - samples = ray.get(data_source.get_samples.remote(args.over_sampling_batch_size, args.fully_async)) + + if state.prefetched_samples_ref is not None: + ref = state.prefetched_samples_ref + state.prefetched_samples_ref = None + logger.info(f"Rollout step {rollout_id}: using pre-fetched data from previous step") + else: + ref = data_source.get_samples.remote(args.over_sampling_batch_size, args.fully_async) + + samples = await loop.run_in_executor(None, ray.get, ref) + get_samples_times.append(monotonic() - _t_get_samples) num_old_samples = len(samples) - args.over_sampling_batch_size logger.info( @@ -761,11 +800,19 @@ async def generate_rollout_async( f"Total yielded: {total_transfer_samples - num_old_samples}/{target_data_size - num_old_samples} for step: {rollout_id}" ) + if not args.fully_async: + state.prefetched_samples_ref = data_source.get_samples.remote(args.over_sampling_batch_size, args.fully_async) + logger.info(f"Rollout step {rollout_id}: pre-submitted data fetch for next step") + logger.info(f"Generator exhausted. Waiting for {len(transfer_tasks)} transfer tasks to complete...") # Wait for all transfer tasks to complete if transfer_tasks: await asyncio.gather(*transfer_tasks) pbar.close() + + # Stop SGLang profiling if enabled (no-op if num_steps was set — SGLang auto-stops) + await stop_sglang_profile(args, rollout_id) + sample = data[-1][0][0] if isinstance(data[-1][0], list) else data[-1][0] logger.info( f"Finish rollout: {[str(sample.prompt) + sample.response]}, label: {str(sample.label)[:100]}, reward: {sample.reward}", diff --git a/relax/models/__init__.py b/relax/models/__init__.py new file mode 100644 index 00000000..1642b11a --- /dev/null +++ b/relax/models/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2026 Relax Authors. All Rights Reserved. + +try: + from megatron.bridge.models.qwen_omni import ( # type: ignore[attr-defined] # noqa: F401 + Qwen3OmniModelProvider, + Qwen3OmniMoEBridge, + Qwen3OmniMoeModel, + ) +except (ImportError, AttributeError): + from relax.models.qwen_omni.modeling_qwen3_omni.model import Qwen3OmniMoeModel # noqa: F811 + from relax.models.qwen_omni.qwen3_omni_bridge import Qwen3OmniMoEBridge # noqa: F811 + from relax.models.qwen_omni.qwen3_omni_provider import Qwen3OmniModelProvider # noqa: F811 + + +__all__ = [ + "Qwen3OmniMoEBridge", + "Qwen3OmniMoeModel", + "Qwen3OmniModelProvider", +] diff --git a/relax/models/qwen_omni/__init__.py b/relax/models/qwen_omni/__init__.py new file mode 100644 index 00000000..9f386360 --- /dev/null +++ b/relax/models/qwen_omni/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2026 Relax Authors. All Rights Reserved. diff --git a/relax/models/qwen_omni/modeling_qwen3_omni/__init__.py b/relax/models/qwen_omni/modeling_qwen3_omni/__init__.py new file mode 100644 index 00000000..f8c8c3e9 --- /dev/null +++ b/relax/models/qwen_omni/modeling_qwen3_omni/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2026 Relax Authors. All Rights Reserved. + +"""Qwen3 Omni model providers and configurations.""" + +# Core model components +# Bridges for HuggingFace to Megatron conversion +from relax.models.qwen_omni.modeling_qwen3_omni.model import Qwen3OmniMoeModel # noqa: F401 +from relax.models.qwen_omni.qwen3_omni_bridge import Qwen3OmniMoEBridge + +# Dense and MoE model providers +from relax.models.qwen_omni.qwen3_omni_provider import Qwen3OmniModelProvider + + +__all__ = [ + "Qwen3OmniMoeModel", + "Qwen3OmniMoEBridge", + "Qwen3OmniModelProvider", +] diff --git a/relax/models/qwen_omni/modeling_qwen3_omni/model.py b/relax/models/qwen_omni/modeling_qwen3_omni/model.py new file mode 100644 index 00000000..4fda12bc --- /dev/null +++ b/relax/models/qwen_omni/modeling_qwen3_omni/model.py @@ -0,0 +1,411 @@ +# Copyright (c) 2026 Relax Authors. All Rights Reserved. + +import torch +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.utils import split_deepstack_embs +from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync +from megatron.core import InferenceParams, mpu, tensor_parallel +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec +from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import ( + Qwen3OmniMoeThinkerConfig as Qwen3OmniMoeThinkerConfigHF, +) +from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( + Qwen3OmniMoeAudioEncoder as Qwen3OmniMoeAudioEncoderHF, +) +from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( + Qwen3OmniMoeVisionEncoder as Qwen3OmniMoeVisionEncoderHF, +) + +from relax.models.qwen_omni.modeling_qwen3_omni.text_model import Qwen3OmniGPTModel +from relax.models.qwen_omni.modeling_qwen3_omni.transformer_config import Qwen3OmniTransformerConfig +from relax.models.qwen_omni.modeling_qwen3_omni.utils import get_rope_index + + +class Qwen3OmniMoeModel(MegatronModule): + """Qwen3 Omni MoE Thinker Model for multimodal understanding. + + This model supports audio, image, and video inputs in addition to text. + It processes multimodal inputs through separate encoders and combines them + for the language model. + + This is a standalone implementation that does not inherit from other models + to maintain independence from version-specific implementations. + """ + + def __init__( + self, + language_transformer_config: Qwen3OmniTransformerConfig, + language_transformer_layer_spec: ModuleSpec, + audio_transformer_config: Qwen3OmniMoeThinkerConfigHF, + vision_transformer_config: Qwen3OmniMoeThinkerConfigHF, + parallel_output: bool = True, + pre_process: bool = True, + post_process: bool = True, + add_encoder: bool = True, + add_decoder: bool = True, + use_audio_in_video: bool = False, + pg_collection=None, + ): + super().__init__(config=language_transformer_config) + + self.pre_process = pre_process + self.post_process = post_process + self.pg_collection = pg_collection + self.add_encoder = add_encoder + self.add_decoder = add_decoder + + self.encoder_hidden_state = None + self.vision_model = None + self.language_model = None + self.image_token_id = language_transformer_config.image_token_id + self.video_token_id = language_transformer_config.video_token_id + self.vision_start_token_id = language_transformer_config.vision_start_token_id + + # This attribute is needed to check if an all-reduce is required + # on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`. + self.share_embeddings_and_output_weights = False + + self.position_id_per_seconds = language_transformer_config.position_id_per_seconds + self.audio_token_id = language_transformer_config.audio_token_id + self.audio_start_token_id = language_transformer_config.audio_start_token_id + self.use_audio_in_video = use_audio_in_video + self.audio_model = None + + if self.pre_process: + # Initialize audio and vision models with random weights from config + self.audio_model = Qwen3OmniMoeAudioEncoderHF._from_config(audio_transformer_config) + self.vision_model = Qwen3OmniMoeVisionEncoderHF._from_config(vision_transformer_config) + # Ensure HF encoder params are marked for TP grad sync and future assignments are hooked. + hook_hf_module_setattr_for_tp_grad_sync(self.audio_model) + hook_hf_module_setattr_for_tp_grad_sync(self.vision_model) + # Move to device if available + if torch.cuda.is_available(): + self.audio_model = self.audio_model.to("cuda") + self.vision_model = self.vision_model.to("cuda") + + self.language_model = Qwen3OmniGPTModel( + config=language_transformer_config, + transformer_layer_spec=language_transformer_layer_spec, + vocab_size=language_transformer_config.vocab_size, + max_sequence_length=language_transformer_config.language_max_sequence_length, + parallel_output=parallel_output, + position_embedding_type="mrope", + rotary_percent=language_transformer_config.rotary_percent, + pre_process=self.pre_process, + post_process=self.post_process, + rotary_base=language_transformer_config.rotary_base, + fp16_lm_cross_entropy=language_transformer_config.fp16_lm_cross_entropy, + share_embeddings_and_output_weights=language_transformer_config.share_embeddings_and_output_weights, + scatter_embedding_sequence_parallel=False, + pg_collection=pg_collection, + ) + self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights + + def set_input_tensor(self, input_tensor) -> None: + """Set input tensor to be used instead of forward()'s input. + + When the pipeline parallel size > 1, the input tensor is received from + the previous pipeline stage and must be provided to the model via this method. + + Args: + input_tensor (list or torch.Tensor): Input tensor(s) from the previous pipeline stage. + """ + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + assert len(input_tensor) == 1, "input_tensor should only be length 1 for Qwen3OmniMoeModel" + + if self.pre_process: + self.encoder_hidden_state = input_tensor[0] + else: + self.language_model.set_input_tensor(input_tensor[0]) + + def freeze( + self, + freeze_language_model: bool, + freeze_vision_model: bool, + freeze_vision_projection: bool, + freeze_audio_model: bool = False, + ): + """Freeze model modules. + + Make specific modules non-trainable by setting requires_grad to False. + + Args: + freeze_language_model (bool): Freeze the language model module. + freeze_vision_model (bool): Freeze the vision model module. + freeze_vision_projection (bool): Freeze the vision projection modules. + freeze_audio_model (bool): Freeze the audio model module. + """ + if freeze_language_model and self.language_model is not None: + for param in self.language_model.parameters(): + param.requires_grad = False + + if freeze_vision_model and self.vision_model is not None: + for param in self.vision_model.parameters(): + param.requires_grad = False + + if freeze_audio_model and self.audio_model is not None: + self.audio_model._freeze_parameters() + + def forward( + self, + input_ids: torch.Tensor, + input_features: torch.Tensor = None, + position_ids: torch.Tensor = None, # can set at dataset + attention_mask: torch.Tensor = None, + feature_attention_mask: torch.Tensor = None, + labels: torch.Tensor = None, + loss_mask: torch.Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + pixel_values: torch.Tensor = None, + pixel_values_videos: torch.Tensor = None, + image_grid_thw: torch.Tensor = None, + video_grid_thw: torch.Tensor = None, + image_input_mask: torch.Tensor = None, + video_second_per_grid=None, + ) -> torch.Tensor: + """Forward function of the Qwen3 Omni model. + + Args: + input_ids (torch.Tensor): input text ids [batch, text_seq_len]. + input_features (torch.Tensor): audio features. + position_ids (torch.Tensor): input text position ids [batch, text_seq_len]. + attention_mask (torch.Tensor): attention mask for the language model. + feature_attention_mask (torch.Tensor): attention mask for audio features. + labels (torch.Tensor): Optional target text labels [batch, combined_seq_len]. + loss_mask (torch.Tensor): Loss mask. + inference_params (InferenceParams): Inference-time parameters including KV cache. + packed_seq_params (PackedSeqParams): Packed sequence parameters. + extra_block_kwargs (dict): Extra block kwargs. + pixel_values (torch.Tensor): Image pixel values. + pixel_values_videos (torch.Tensor): Video pixel values. + image_grid_thw (torch.Tensor): Image grid dimensions. + video_grid_thw (torch.Tensor): Video grid dimensions. + image_input_mask (torch.Tensor): Image input mask. + video_second_per_grid (torch.Tensor): Seconds per video grid. + + Returns: + output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits. + """ + assert inference_params is None, "not support inference" + + video_start_index = 0 + vision_grid_thw = None + vision_data = None + image_mask = None + video_mask = None + deepstack_feature_lists = None + # position ids is computed within the model + position_ids = None + audio_feature_lengths = None + + if feature_attention_mask is not None: + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + + if self.pre_process: + # ========================= + # image / Video + # ========================= + if image_grid_thw is not None or video_grid_thw is not None: + if image_grid_thw is not None: + image_mask = image_input_mask + if image_mask is None: + image_mask = (input_ids == self.image_token_id).contiguous() + vision_grid_thw = image_grid_thw + vision_data = pixel_values + video_start_index = image_mask.sum().item() + else: + video_start_index = 0 + + # Handle videos - concatenate if both present + if video_grid_thw is not None: + video_mask = (input_ids == self.video_token_id).contiguous() + if vision_grid_thw is not None: + # Both images and videos present - concatenate + vision_grid_thw = torch.cat([vision_grid_thw, video_grid_thw], dim=0) + vision_data = torch.cat([vision_data, pixel_values_videos], dim=0) + else: + # Only videos present + vision_grid_thw = video_grid_thw + vision_data = pixel_values_videos + + vision_embeds = None + if vision_grid_thw is not None and vision_grid_thw.shape[0] > 0: + vision_outputs = self.vision_model( + hidden_states=vision_data, + grid_thw=vision_grid_thw, + ) + + import transformers + from packaging import version + + if version.parse(transformers.__version__) >= version.parse("5.0.0"): + vision_embeds = vision_outputs.pooler_output + deepstack_feature_lists = vision_outputs.deepstack_features + else: + vision_embeds, deepstack_feature_lists = vision_outputs + + combined_embeddings = self.language_model.embedding( + input_ids=input_ids, + position_ids=None, # NOTE: disable + ).clone() # [text_seq_len, b, h_language] + + if vision_embeds is not None: + if video_start_index == 0: + image_embeds = None + video_embeds = vision_embeds + elif video_start_index == vision_embeds.shape[0]: + image_embeds = vision_embeds + video_embeds = None + elif 0 < video_start_index < vision_embeds.shape[0]: + image_embeds = vision_embeds[:video_start_index] + video_embeds = vision_embeds[video_start_index:] + else: + raise ValueError( + f"Expect video token start index in range [0, {vision_embeds.shape[0]}], but got " + f"{video_start_index}" + ) + + if image_embeds is not None: + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + combined_embeddings[image_mask] = image_embeds + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + + if video_embeds is not None: + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + combined_embeddings[video_mask] = video_embeds + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + + # Create visual_pos_masks for deepstack processing + if image_embeds is not None and video_embeds is not None: + visual_pos_masks = image_mask | video_mask + elif image_embeds is not None: + visual_pos_masks = image_mask + elif video_embeds is not None: + visual_pos_masks = video_mask + else: + visual_pos_masks = None + else: + visual_pos_masks = None + + # ========================= + # Audio + # ========================= + if input_features is not None: + audio_mask = (input_ids == self.audio_token_id).contiguous() + if feature_attention_mask is not None: + input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) + + feature_lens = ( + audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + ) + + # dtype from fp32 to bf16 + audio_outputs = self.audio_model( + input_features.to(next(self.audio_model.parameters()).dtype), + feature_lens=feature_lens, + ) + audio_embeds = audio_outputs.last_hidden_state # [num_audio_tokens, hidden] + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + + combined_embeddings[audio_mask] = audio_embeds + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + + if self.config.sequence_parallel: + combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(combined_embeddings) + combined_embeddings = combined_embeddings.contiguous() + else: + combined_embeddings = None + visual_pos_masks = None + + cu_seqlens_padded = None + if packed_seq_params is not None: + if packed_seq_params.cu_seqlens_q_padded is not None: + cu_seqlens_padded = packed_seq_params.cu_seqlens_q_padded + else: + cu_seqlens_padded = packed_seq_params.cu_seqlens_q + + hf_attention_mask = None + if position_ids is None: + input_ids_for_rope_index = input_ids + if cu_seqlens_padded is not None: + + def thd_to_bshd(packed_values: torch.Tensor, cu_seqlens: torch.Tensor): + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + max_seq_len = seqlens.max() + bs = len(cu_seqlens) - 1 + results = packed_values.new_zeros(size=(bs, max_seq_len, *packed_values.shape[2:])) + for i, seqlen in enumerate(seqlens): + results[i, :seqlen] = packed_values[0, cu_seqlens[i] : cu_seqlens[i] + seqlen] + return results + + def bshd_to_thd(unpacked_values: torch.Tensor, cu_seqlens: torch.Tensor): + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + total_len = cu_seqlens[-1] + results = unpacked_values.new_zeros(size=(1, total_len, *unpacked_values.shape[2:])) + for i, seqlen in enumerate(seqlens): + results[0, cu_seqlens[i] : cu_seqlens[i] + seqlen] = unpacked_values[i, :seqlen] + return results + + input_ids_for_rope_index = thd_to_bshd(input_ids, cu_seqlens_padded) + + # ========================= + # RoPE index (audio-aware) + # ========================= + position_ids, _ = get_rope_index( + spatial_merge_size=self.config.spatial_merge_size, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, + audio_token_id=self.audio_token_id, + vision_start_token_id=self.vision_start_token_id, + audio_start_token_id=self.audio_start_token_id, + input_ids=input_ids_for_rope_index, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + audio_seqlens=audio_feature_lengths, + attention_mask=hf_attention_mask, + use_audio_in_video=self.use_audio_in_video, + second_per_grids=video_second_per_grid, + position_id_per_seconds=self.position_id_per_seconds, + ) + if cu_seqlens_padded is not None: + position_ids = bshd_to_thd(position_ids.permute(1, 2, 0), cu_seqlens_padded).permute(2, 0, 1) + + deepstack_visual_embeds = deepstack_feature_lists + + # Split visual_pos_masks and deepstack_visual_embeds for sequence parallel / CP + if self.config.sequence_parallel and visual_pos_masks is not None and deepstack_visual_embeds is not None: + if self.pg_collection is not None: + tp_size = self.pg_collection.tp.size() + tp_rank = self.pg_collection.tp.rank() + else: + tp_size = mpu.get_tensor_model_parallel_world_size() + tp_rank = mpu.get_tensor_model_parallel_rank() + visual_pos_masks, deepstack_visual_embeds = split_deepstack_embs( + visual_pos_masks, + deepstack_visual_embeds, + tp_size=tp_size, + tp_rank=tp_rank, + cp_size=1, + cp_rank=0, + sequence_parallel=True, + ) + + output = self.language_model( + input_ids=None, + position_ids=position_ids, # None in encoder + attention_mask=attention_mask, # None in encoder + decoder_input=combined_embeddings, # only not None in the first decoder PP stage + labels=labels, # only not None in the last decoder PP stage + loss_mask=loss_mask, + inference_params=inference_params, # currently always None + packed_seq_params=packed_seq_params, # currently always None + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + **(extra_block_kwargs or {}), + ) + + return output diff --git a/relax/models/qwen_omni/modeling_qwen3_omni/rope.py b/relax/models/qwen_omni/modeling_qwen3_omni/rope.py new file mode 100644 index 00000000..b0719f68 --- /dev/null +++ b/relax/models/qwen_omni/modeling_qwen3_omni/rope.py @@ -0,0 +1,43 @@ +# Copyright (c) 2026 Relax Authors. All Rights Reserved. + + +from typing import List + +import torch +from torch import Tensor +from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeThinkerTextRotaryEmbedding + + +class Qwen3OmniMoeThinkerTextRotaryEmbedding(Qwen3OmniMoeThinkerTextRotaryEmbedding): + """Qwen3-Omni MoE text rotary position embedding.""" + + def forward( + self, position_ids: torch.Tensor, mrope_section: List[int], packed_seq_params=None, **kwargs + ) -> Tensor: + """Forward pass of multimodal RoPE embedding. + + Args: + position_ids (torch.Tensor): A postion_id tensor with shape [3, batchsize, seqlens] + mrope_section (list[int]): Multimodal rope section is for channel dimension of temporal, + height and width in rope calculation. + + Returns: + Tensor: Raw frequency embeddings for Megatron Core (shape: [seq_length, bs, 1, dim]). + Megatron Core will compute cos/sin internally and apply attention_scaling. + """ + # Use fp32 for position indices to avoid precision loss when inv_freq is bf16. + seq = position_ids.to(device=self.inv_freq.device, dtype=torch.float32) + + # if self.seq_len_interpolation_factor is not None: + # seq *= 1 / self.seq_len_interpolation_factor + + # shape (3, bs, dim, 1) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, seq.shape[1], -1, 1) + # shape (3, bs, 1, seq_length) + seq_expanded = seq[:, :, None, :].float() + # shape (3, bs, seq_length, dim) + freqs = (inv_freq_expanded @ seq_expanded).transpose(2, 3) + freqs = self.apply_interleaved_mrope(freqs, mrope_section) + emb = torch.cat((freqs, freqs), dim=-1) + emb = emb[..., None, :].transpose(0, 1).contiguous() + return emb diff --git a/relax/models/qwen_omni/modeling_qwen3_omni/text_model.py b/relax/models/qwen_omni/modeling_qwen3_omni/text_model.py new file mode 100644 index 00000000..72c020ca --- /dev/null +++ b/relax/models/qwen_omni/modeling_qwen3_omni/text_model.py @@ -0,0 +1,76 @@ +# Copyright (c) 2026 Relax Authors. All Rights Reserved. + + +from typing import Literal, Optional + +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.text_model import Qwen3VLGPTModel +from megatron.bridge.models.transformer_config import TransformerConfig +from megatron.core.transformer.spec_utils import ModuleSpec + +from relax.models.qwen_omni.modeling_qwen3_omni.rope import Qwen3OmniMoeThinkerTextRotaryEmbedding +from relax.models.qwen_omni.modeling_qwen3_omni.transformer_block import Qwen3OmniTransformerBlock + + +class Qwen3OmniGPTModel(Qwen3VLGPTModel): + """Qwen3-Omni GPT model with vision-language capabilities.""" + + def __init__( + self, + config: TransformerConfig, + transformer_layer_spec: ModuleSpec, + vocab_size: int, + max_sequence_length: int, + pre_process: bool = True, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + share_embeddings_and_output_weights: bool = False, + position_embedding_type: Literal["learned_absolute", "rope", "mrope", "none"] = "learned_absolute", + rotary_percent: float = 1.0, + rotary_base: int = 10000, + rope_scaling: bool = False, + rope_scaling_factor: float = 8.0, + scatter_embedding_sequence_parallel: bool = True, + seq_len_interpolation_factor: Optional[float] = None, + mtp_block_spec: Optional[ModuleSpec] = None, + vp_stage: Optional[int] = None, + pg_collection=None, + ) -> None: + super().__init__( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=fp16_lm_cross_entropy, + parallel_output=parallel_output, + share_embeddings_and_output_weights=share_embeddings_and_output_weights, + position_embedding_type=position_embedding_type, + rotary_percent=rotary_percent, + rotary_base=rotary_base, + rope_scaling=rope_scaling, + rope_scaling_factor=rope_scaling_factor, + scatter_embedding_sequence_parallel=scatter_embedding_sequence_parallel, + seq_len_interpolation_factor=seq_len_interpolation_factor, + mtp_block_spec=mtp_block_spec, + vp_stage=vp_stage, + pg_collection=pg_collection, + ) + + self.rotary_pos_emb = Qwen3OmniMoeThinkerTextRotaryEmbedding(config.hf_text_config) + + self.mrope_section = self.config.mrope_section + assert self.mrope_section is not None, ( + "mrope require mrope_section setting, but we got None from TransformerConfig" + ) + + # rebuild the transformer block + self.decoder = Qwen3OmniTransformerBlock( + config=self.config, + spec=transformer_layer_spec, + pre_process=self.pre_process, + post_process=self.post_process, + vp_stage=vp_stage, + pg_collection=pg_collection, + ) diff --git a/relax/models/qwen_omni/modeling_qwen3_omni/transformer_block.py b/relax/models/qwen_omni/modeling_qwen3_omni/transformer_block.py new file mode 100644 index 00000000..6a1004b1 --- /dev/null +++ b/relax/models/qwen_omni/modeling_qwen3_omni/transformer_block.py @@ -0,0 +1,27 @@ +# Copyright (c) 2026 Relax Authors. All Rights Reserved. + + +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.transformer_block import Qwen3VLTransformerBlock + + +try: + import transformer_engine.pytorch as te # noqa: F401 # pylint: disable=unused-import + + HAVE_TE = True +except ImportError: + HAVE_TE = False + +te_checkpoint = None +if HAVE_TE: + pass + + +class Qwen3OmniTransformerBlock(Qwen3VLTransformerBlock): + """Qwen3 Omni Transformer Block extending Qwen3VL functionality. + + This block extends the Qwen3VL transformer block with Omni-specific + features for handling multimodal inputs including audio, images, and + videos. + """ + + pass diff --git a/relax/models/qwen_omni/modeling_qwen3_omni/transformer_config.py b/relax/models/qwen_omni/modeling_qwen3_omni/transformer_config.py new file mode 100644 index 00000000..d0cbf75d --- /dev/null +++ b/relax/models/qwen_omni/modeling_qwen3_omni/transformer_config.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026 Relax Authors. All Rights Reserved. + + +from dataclasses import dataclass, field +from typing import List, Optional + +from megatron.core.transformer.transformer_config import TransformerConfig +from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import Qwen3OmniMoeTextConfig + + +@dataclass +class Qwen3OmniTransformerConfig(TransformerConfig): + """Configuration for Qwen3-VL transformer with vision and language + components.""" + + vocab_size: int = 64000 + language_max_sequence_length: int = 4096 + + patch_size: int = 14 + temporal_patch_size: int = 2 + in_channels: int = 3 + spatial_merge_size: int = 2 + num_position_embeddings: int = 2304 + out_hidden_size: int = 2304 + + apply_rotary_pos_emb_in_fp32: bool = False + deepstack_visual_indexes: List[int] = field(default_factory=lambda: [8, 16, 24]) + fp16_lm_cross_entropy: bool = False + share_embeddings_and_output_weights: bool = False + rotary_percent: float = 1.0 + rotary_base: float = 10000 + + # Multimodal rope section for [temporal, height, width] dimensions + mrope_section: List[int] = field(default_factory=lambda: [24, 20, 20]) + apply_rope_fusion: bool = False + + image_token_id: int = 151655 + video_token_id: int = 151656 + vision_start_token_id: int = 151652 + hf_text_config: Optional[Qwen3OmniMoeTextConfig] = None diff --git a/relax/models/qwen_omni/modeling_qwen3_omni/utils.py b/relax/models/qwen_omni/modeling_qwen3_omni/utils.py new file mode 100644 index 00000000..9445fb9b --- /dev/null +++ b/relax/models/qwen_omni/modeling_qwen3_omni/utils.py @@ -0,0 +1,356 @@ +# Copyright (c) 2026 Relax Authors. All Rights Reserved. + +from typing import Optional + +import torch +from megatron.core.packed_seq_params import PackedSeqParams + + +def _get_feat_extract_output_lengths(input_lengths): + """Computes the output length of the convolutional layers and the output + length of the audio encoder.""" + + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + return output_lengths + + +def get_llm_pos_ids_for_vision( + self, + start_idx: int, + vision_idx: int, + spatial_merge_size: int, + t_index: list[torch.Tensor], + grid_hs: list[torch.Tensor], + grid_ws: list[torch.Tensor], +): + """Generate LLM position IDs for vision tokens. + + Computes position embeddings for vision tokens (images/videos) by creating + 3D position indices (temporal, height, width) based on spatial merge size. + + Args: + self: Instance reference. + start_idx: Starting position index offset. + vision_idx: Index of the vision sample. + spatial_merge_size: Size of spatial merge for grid downsampling. + t_index: List of temporal indices. + grid_hs: List of grid heights. + grid_ws: List of grid widths. + + Returns: + torch.Tensor: Position IDs of shape [3, num_tokens] with temporal, height, width indices. + """ + llm_pos_ids_list = [] + llm_grid_h = grid_hs[vision_idx] // spatial_merge_size + llm_grid_w = grid_ws[vision_idx] // spatial_merge_size + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten().float() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten().float() + t_index = torch.Tensor(t_index).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten().float() + _llm_pos_ids = torch.stack([t_index, h_index, w_index]) + llm_pos_ids_list.append(_llm_pos_ids + start_idx) + llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) + return llm_pos_ids + + +def get_rope_index( + spatial_merge_size: int, + image_token_id: int, + video_token_id: int, + audio_token_id: int, + vision_start_token_id: int, + audio_start_token_id: int, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + audio_seqlens: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + second_per_grids: Optional[torch.Tensor] = None, + position_id_per_seconds: int = 1, + packed_seq_params: Optional[PackedSeqParams] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Generate RoPE position indices for multimodal inputs. + + Computes rotary position embeddings (RoPE) indices for a sequence containing + mixed modalities (text, images, videos, audio). Handles temporal, spatial, + and audio-specific position encoding. + + Args: + spatial_merge_size: Size of spatial merge for grid downsampling. + image_token_id: Token ID for image markers. + video_token_id: Token ID for video markers. + audio_token_id: Token ID for audio markers. + vision_start_token_id: Token ID marking start of vision content. + audio_start_token_id: Token ID marking start of audio content. + input_ids: Input token IDs of shape [batch_size, seq_len]. + image_grid_thw: Image grid dimensions [num_images, 3] with (T, H, W). + video_grid_thw: Video grid dimensions [num_videos, 3] with (T, H, W). + audio_seqlens: Audio sequence lengths [num_audios]. + attention_mask: Attention mask indicating valid tokens. + use_audio_in_video: Whether audio is embedded within video tokens. + second_per_grids: Seconds per video grid frame. + position_id_per_seconds: Position ID increment per second. + packed_seq_params: Packed sequence parameters for variable-length sequences. + + Returns: + tuple: (position_ids, mrope_position_deltas) where: + - position_ids: Shape [3, batch_size, seq_len] with temporal, height, width indices. + - mrope_position_deltas: Shape [batch_size, 1] with position delta adjustments. + """ + # VL timestamp split logic (unchanged) + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + video_grid_thw[:, 0] = 1 + + if packed_seq_params is not None and attention_mask is None and input_ids is not None: + # Build an attention mask from packed sequence metadata when one is not provided. + # cu_seqlens_q entries are cumulative lengths; their diffs give per-sample lengths. + cu_seqlens = packed_seq_params.cu_seqlens_q + if cu_seqlens is not None and cu_seqlens.numel() >= 2: + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + attention_mask = torch.zeros_like(input_ids, dtype=input_ids.dtype) + max_len = attention_mask.shape[1] + for i, seq_len in enumerate(seq_lens.tolist()): + valid = min(int(seq_len), max_len) + attention_mask[i, :valid] = 1 + else: + # Fallback to a dense mask if packed metadata is missing. + attention_mask = torch.ones_like(input_ids) + + mrope_position_deltas = [] + if input_ids is not None and ( + image_grid_thw is not None or video_grid_thw is not None or audio_seqlens is not None + ): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=torch.float, + device=input_ids.device, + ) + image_index, video_index, audio_index = 0, 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + audio_nums = torch.sum(input_ids == audio_start_token_id) + image_nums = (vision_tokens == image_token_id).sum() + video_nums = ( + (vision_tokens == audio_start_token_id).sum() + if use_audio_in_video + else (vision_tokens == video_token_id).sum() + ) + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums + multimodal_nums = image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums + + for _ in range(multimodal_nums): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + if (image_token_id in input_tokens or video_token_id in input_tokens) and ( + remain_videos > 0 or remain_images > 0 + ): + ed_vision_start = input_tokens.index(vision_start_token_id, st) + else: + ed_vision_start = len(input_tokens) + 1 + if audio_token_id in input_tokens and remain_audios > 0: + ed_audio_start = input_tokens.index(audio_start_token_id, st) + else: + ed_audio_start = len(input_tokens) + 1 + min_ed = min(ed_vision_start, ed_audio_start) + + # ---------- text ---------- + text_len = min_ed - st + if text_len > 0: + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + st_idx += text_len + + # ---------- BOS ---------- + # Audio in Video + if min_ed == ed_vision_start and ed_vision_start + 1 == ed_audio_start: + bos_len, eos_len = 2, 2 + else: + bos_len, eos_len = 1, 1 + + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + st_idx += bos_len + + # Audio Only + if min_ed == ed_audio_start: + audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_index]) + llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx + llm_pos_ids_list.append(llm_pos_ids) + + st += text_len + bos_len + audio_len + eos_len + audio_index += 1 + remain_audios -= 1 + + # Image Only + elif min_ed == ed_vision_start and input_ids[ed_vision_start + 1] == image_token_id: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + + t_index = (torch.arange(t) * 1 * position_id_per_seconds).float() + + llm_pos_ids_list_temp = [] + llm_grid_h = h // spatial_merge_size + llm_grid_w = w // spatial_merge_size + h_index = ( + torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten().float() + ) + w_index = ( + torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten().float() + ) + t_index = torch.Tensor(t_index).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten().float() + _llm_pos_ids = torch.stack([t_index, h_index, w_index]) + llm_pos_ids_list_temp.append(_llm_pos_ids + st_idx) + llm_pos_ids = torch.cat(llm_pos_ids_list_temp, dim=1) + + llm_pos_ids_list.append(llm_pos_ids) + + image_len = image_grid_thw[image_index].prod() // (spatial_merge_size**2) + st += int(text_len + bos_len + image_len + eos_len) + image_index += 1 + remain_images -= 1 + + # Video Only + elif min_ed == ed_vision_start and input_ids[ed_vision_start + 1] == video_token_id: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + t_index = ( + torch.arange(t) * second_per_grids[video_index].cpu().float() * position_id_per_seconds + ).float() + + llm_pos_ids_list_temp = [] + llm_grid_h = h // spatial_merge_size + llm_grid_w = w // spatial_merge_size + h_index = ( + torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten().float() + ) + w_index = ( + torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten().float() + ) + t_index = torch.Tensor(t_index).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten().float() + _llm_pos_ids = torch.stack([t_index, h_index, w_index]) + llm_pos_ids_list_temp.append(_llm_pos_ids + st_idx) + llm_pos_ids = torch.cat(llm_pos_ids_list_temp, dim=1) + + llm_pos_ids_list.append(llm_pos_ids) + + video_len = video_grid_thw[video_index].prod() // (spatial_merge_size**2) + st += int(text_len + bos_len + video_len + eos_len) + video_index += 1 + remain_videos -= 1 + + # Audio in Video + elif min_ed == ed_vision_start and ed_vision_start + 1 == ed_audio_start: + audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_index]) + audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx + + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + + t_index = ( + torch.arange(t) * second_per_grids[video_index].cpu().float() * position_id_per_seconds + ).float() + + llm_pos_ids_list_temp = [] + llm_grid_h = h // spatial_merge_size + llm_grid_w = w // spatial_merge_size + h_index = ( + torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten().float() + ) + w_index = ( + torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten().float() + ) + t_index = torch.Tensor(t_index).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten().float() + _llm_pos_ids = torch.stack([t_index, h_index, w_index]) + llm_pos_ids_list_temp.append(_llm_pos_ids + st_idx) + llm_pos_ids = torch.cat(llm_pos_ids_list_temp, dim=1) + + video_llm_pos_ids = llm_pos_ids + + video_data_index, audio_data_index = 0, 0 + while ( + video_data_index < video_llm_pos_ids.shape[-1] + and audio_data_index < audio_llm_pos_ids.shape[-1] + ): + if video_llm_pos_ids[0][video_data_index] <= audio_llm_pos_ids[0][audio_data_index]: + llm_pos_ids_list.append(video_llm_pos_ids[:, video_data_index : video_data_index + 1]) + video_data_index += 1 + else: + llm_pos_ids_list.append(audio_llm_pos_ids[:, audio_data_index : audio_data_index + 1]) + audio_data_index += 1 + if video_data_index < video_llm_pos_ids.shape[-1]: + llm_pos_ids_list.append(video_llm_pos_ids[:, video_data_index : video_llm_pos_ids.shape[-1]]) + if audio_data_index < audio_llm_pos_ids.shape[-1]: + llm_pos_ids_list.append(audio_llm_pos_ids[:, audio_data_index : audio_llm_pos_ids.shape[-1]]) + video_len = video_grid_thw[video_index].prod() // (spatial_merge_size**2) + + st += int(text_len + bos_len + audio_len + video_len + eos_len) + audio_index += 1 + video_index += 1 + remain_videos -= 1 + remain_audios -= 1 + else: + raise (RuntimeError("unexpected error")) + + # ---------- EOS ---------- + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + # tail text + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat([item.float() for item in llm_pos_ids_list], dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(input_ids)) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + # fallback (pure text) + # position_ids = attention_mask.float().cumsum(-1) - 1 + # position_ids.masked_fill_(attention_mask == 0, 1) + # position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + # max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + # mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True) + + if attention_mask is not None: + position_ids = attention_mask.float().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True) + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas diff --git a/relax/models/qwen_omni/qwen3_omni_bridge.py b/relax/models/qwen_omni/qwen3_omni_bridge.py new file mode 100644 index 00000000..c3e330c2 --- /dev/null +++ b/relax/models/qwen_omni/qwen3_omni_bridge.py @@ -0,0 +1,245 @@ +# Copyright (c) 2026 Relax Authors. All Rights Reserved. + + +import torch +import torch.nn.functional as F +from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.conversion.param_mapping import AutoMapping, GatedMLPMapping, QKVMapping, ReplicatedMapping +from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM +from transformers import Qwen3OmniMoeForConditionalGeneration + +from relax.models.qwen_omni.modeling_qwen3_omni.model import Qwen3OmniMoeModel +from relax.models.qwen_omni.qwen3_omni_provider import Qwen3OmniModelProvider + + +@MegatronModelBridge.register_bridge(source=Qwen3OmniMoeForConditionalGeneration, target=Qwen3OmniMoeModel) +class Qwen3OmniMoEBridge(MegatronModelBridge): + """Megatron Bridge for Qwen3-VL MoE (Mixture of Experts) Conditional + Generation. + + This bridge handles the conversion between HuggingFace Qwen3VLMoEForConditionalGeneration + and Megatron-Core Qwen3VL MoE model formats, including weight mappings and + configuration translation for vision-language MoE models. + + The weight mappings handle: + - Vision model weights (same as dense model) + - Language model MoE layers with expert routing + - Shared embeddings and output layers + - QK layernorm specific to Qwen3 architecture + + This bridge works with any Qwen3VL MoE model size and automatically extracts + the MoE configuration from the HuggingFace model. + + Example: + >>> from megatron.bridge import AutoBridge + >>> bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen3-VL-30B-A3B-Instruct") + >>> provider = bridge.to_megatron_provider() + """ + + # copied from https://github.com/fzyzcjy/Megatron-Bridge/blob/6b1b80cdd3f5387e378545399287bf4a21a56fe0/src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py#L54 + def __init__(self): + super().__init__() + self.hf_weights_cache = {} + + def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> Qwen3OmniModelProvider: + """Create a Qwen3OmniModelProvider from a HuggingFace pretrained MoE + model. + + Args: + hf_pretrained: HuggingFace pretrained VLM MoE model + + Returns: + Qwen3OmniModelProvider configured with the HF MoE model's parameters + """ + # to check + # hf_pretrained.config + hf_config = hf_pretrained.config.thinker_config + text_config = hf_config.text_config + + # Get the model dtype from text config + model_dtype = self.dtype_from_hf(hf_config, default=torch.float32) + + # Set vision config dtype to match the language model dtype + # This ensures vision model parameters are initialized in the same dtype + audio_config = hf_config.audio_config + audio_config.torch_dtype = model_dtype + vision_config = hf_config.vision_config + vision_config.torch_dtype = model_dtype + + head_dim = getattr(text_config, "head_dim", text_config.hidden_size // text_config.num_attention_heads) + provider = Qwen3OmniModelProvider( + num_layers=text_config.num_hidden_layers, + hidden_size=text_config.hidden_size, + ffn_hidden_size=text_config.intermediate_size, # Dense FFN size (for non-MoE layers if any) + moe_ffn_hidden_size=text_config.moe_intermediate_size, # Expert FFN size + num_attention_heads=text_config.num_attention_heads, + num_query_groups=text_config.num_key_value_heads, # GQA configuration + head_dim=head_dim, + kv_channels=head_dim, # Must explicitly set kv_channels for MCore TransformerConfig + init_method_std=text_config.initializer_range, + layernorm_epsilon=text_config.rms_norm_eps, + gated_linear_unit=True, # Qwen3 MoE uses gated linear units + make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(text_config.vocab_size), + rotary_base=getattr(text_config, "rope_theta", 1000000.0), # Default Qwen3 rope theta + share_embeddings_and_output_weights=getattr(text_config, "tie_word_embeddings", False), + vocab_size=text_config.vocab_size, + seq_length=text_config.max_position_embeddings, + fp16=(model_dtype == torch.float16), + bf16=(model_dtype == torch.bfloat16), + params_dtype=model_dtype, + # Qwen3 specific parameters — match Qwen3VLMoEBridge settings + normalization="RMSNorm", # Qwen3 uses RMSNorm (no bias in layernorms) + activation_func=F.silu, # Qwen3 uses SwiGLU (silu + gated_linear_unit) + add_qkv_bias=text_config.attention_bias, # Qwen3 can have bias in QKV + add_bias_linear=False, # Qwen3 has no bias in linear layers (o_proj, MLP, router) + hidden_dropout=0.0, # Qwen3 uses no hidden dropout + qk_layernorm=True, # Qwen3 uses QK layernorm + # MoE specific parameters + num_moe_experts=text_config.num_experts, + moe_router_topk=text_config.num_experts_per_tok, + moe_grouped_gemm=True, + moe_router_load_balancing_type="aux_loss", + moe_aux_loss_coeff=1e-3, + decoder_sparse_step=getattr(text_config, "decoder_sparse_step", 1), # Default to every layer being MoE + mlp_only_layers=getattr(text_config, "mlp_only_layers", []), # Default to all layers using MoE + # Vision configuration + audio_config=audio_config, + vision_config=vision_config, + # Store the original HF text config for RoPE initialization + hf_text_config=text_config, + # Vision-Language token IDs + bos_token_id=getattr(text_config, "bos_token_id", 151643), + eos_token_id=getattr(text_config, "eos_token_id", 151645), + vision_start_token_id=getattr(hf_config, "vision_start_token_id", 151652), + vision_end_token_id=getattr(hf_config, "vision_end_token_id", 151653), + image_token_id=getattr(hf_config, "image_token_id", 151655), + video_token_id=getattr(hf_config, "video_token_id", 151656), + # audio + audio_token_id=hf_config.audio_token_id, + audio_start_token_id=hf_config.audio_start_token_id, + audio_end_token_id=hf_config.audio_end_token_id, + # MRoPE configuration for multimodal position embeddings + mrope_section=getattr(text_config, "rope_scaling", {}).get("mrope_section", [24, 20, 20]), + position_id_per_seconds=hf_config.position_id_per_seconds, + spatial_merge_size=vision_config.spatial_merge_size, + ) + + return provider + + def mapping_registry(self) -> MegatronMappingRegistry: + """Return MegatronMappingRegistry containing parameter mappings for MoE + models. + + The MoE mappings include: + 1. Standard language model mappings (embeddings, layer norms, output) + 2. Vision model mappings (same as dense model) + 3. QKV mappings with QK layernorm + 4. MoE-specific mappings: + - Router weights for expert selection + - Expert MLPs (multiple experts per layer) + - Pre-MLP layernorm + 5. Deepstack visual merger mappings + + Returns: + MegatronMappingRegistry with all MoE parameter mappings + """ + # Language model direct mappings (same as dense model) + # NOTE: Megatron side (left) uses param names from Qwen3OmniMoeModel (no "thinker." prefix), + # HF side (right) uses param names from Qwen3OmniMoeForConditionalGeneration (with "thinker." prefix). + param_mappings = { + # Embeddings and output layers + "language_model.embedding.word_embeddings.weight": "thinker.model.embed_tokens.weight", + "language_model.output_layer.weight": "thinker.lm_head.weight", + "language_model.decoder.final_layernorm.weight": "thinker.model.norm.weight", + # Layer normalization for attention (TE format - fused into linear) + "language_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "thinker.model.layers.*.input_layernorm.weight", + # MoE-specific: pre-MLP layernorm + "language_model.decoder.layers.*.pre_mlp_layernorm.weight": "thinker.model.layers.*.post_attention_layernorm.weight", + # Dense MLP layer norm (for non-MoE layers, i.e. mlp_only_layers) + "language_model.decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "thinker.model.layers.*.post_attention_layernorm.weight", + # Attention output projection + "language_model.decoder.layers.*.self_attention.linear_proj.weight": "thinker.model.layers.*.self_attn.o_proj.weight", + # QK layernorm weights (Qwen3 specific) + "language_model.decoder.layers.*.self_attention.q_layernorm.weight": "thinker.model.layers.*.self_attn.q_norm.weight", + "language_model.decoder.layers.*.self_attention.k_layernorm.weight": "thinker.model.layers.*.self_attn.k_norm.weight", + # MoE router weights + "language_model.decoder.layers.*.mlp.router.weight": "thinker.model.layers.*.mlp.gate.weight", + # MoE router expert bias + "language_model.decoder.layers.*.mlp.router.expert_bias": "thinker.model.layers.*.mlp.gate.e_score_correction_bias", + # Dense MLP down projection (for non-MoE layers, i.e. mlp_only_layers) + "language_model.decoder.layers.*.mlp.linear_fc2.weight": "thinker.model.layers.*.mlp.down_proj.weight", + # Shared expert down projection + "language_model.decoder.layers.*.mlp.shared_experts.linear_fc2.weight": "thinker.model.layers.*.mlp.shared_expert.down_proj.weight", + # Shared expert gate weight + "language_model.decoder.layers.*.mlp.shared_experts.gate_weight": "thinker.model.layers.*.mlp.shared_expert_gate.weight", + } + + mapping_list = [] + + # Convert simple 1:1 mappings to AutoMapping objects + for megatron_param, hf_param in param_mappings.items(): + mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) + + # Add special mappings that require parameter transformation + mapping_list.extend( + [ + # Audio and vision model weights are replicated directly (HF encoders) + ReplicatedMapping( + megatron_param="audio_model.**", + hf_param="thinker.audio_tower.**", + ), + ReplicatedMapping( + megatron_param="vision_model.**", + hf_param="thinker.visual.**", + ), + # QKV mapping: Combine separate Q, K, V matrices + QKVMapping( + megatron_param="language_model.decoder.layers.*.self_attention.linear_qkv.weight", + q="thinker.model.layers.*.self_attn.q_proj.weight", + k="thinker.model.layers.*.self_attn.k_proj.weight", + v="thinker.model.layers.*.self_attn.v_proj.weight", + ), + # QKV bias mapping (if attention_bias is True) + QKVMapping( + megatron_param="language_model.decoder.layers.*.self_attention.linear_qkv.bias", + q="thinker.model.layers.*.self_attn.q_proj.bias", + k="thinker.model.layers.*.self_attn.k_proj.bias", + v="thinker.model.layers.*.self_attn.v_proj.bias", + ), + # Expert mappings for TEGroupedMLP + GatedMLPMapping( + megatron_param="language_model.decoder.layers.*.mlp.experts.linear_fc1.weight*", + gate="thinker.model.layers.*.mlp.experts.*.gate_proj.weight", + up="thinker.model.layers.*.mlp.experts.*.up_proj.weight", + ), + AutoMapping( + megatron_param="language_model.decoder.layers.*.mlp.experts.linear_fc2.weight*", + hf_param="thinker.model.layers.*.mlp.experts.*.down_proj.weight", + ), + # Expert mappings for SequentialMLP (used by quantization) + GatedMLPMapping( + megatron_param="language_model.decoder.layers.*.mlp.experts.local_experts.*.linear_fc1.weight", + gate="thinker.model.layers.*.mlp.experts.*.gate_proj.weight", + up="thinker.model.layers.*.mlp.experts.*.up_proj.weight", + ), + AutoMapping( + megatron_param="language_model.decoder.layers.*.mlp.experts.local_experts.*.linear_fc2.weight", + hf_param="thinker.model.layers.*.mlp.experts.*.down_proj.weight", + ), + # Dense MLP gate+up (for non-MoE layers, i.e. mlp_only_layers) + GatedMLPMapping( + megatron_param="language_model.decoder.layers.*.mlp.linear_fc1.weight", + gate="thinker.model.layers.*.mlp.gate_proj.weight", + up="thinker.model.layers.*.mlp.up_proj.weight", + ), + # Shared expert gate+up + GatedMLPMapping( + megatron_param="language_model.decoder.layers.*.mlp.shared_experts.linear_fc1.weight", + gate="thinker.model.layers.*.mlp.shared_expert.gate_proj.weight", + up="thinker.model.layers.*.mlp.shared_expert.up_proj.weight", + ), + ] + ) + + return MegatronMappingRegistry(*mapping_list) diff --git a/relax/models/qwen_omni/qwen3_omni_provider.py b/relax/models/qwen_omni/qwen3_omni_provider.py new file mode 100644 index 00000000..cf7d333e --- /dev/null +++ b/relax/models/qwen_omni/qwen3_omni_provider.py @@ -0,0 +1,263 @@ +# Copyright (c) 2026 Relax Authors. All Rights Reserved. + +"""Qwen3 VL MoE Model Provider configurations for Megatron-Core. + +This module provides configuration classes for Qwen3-VL MoE (Mixture of Experts) multimodal models, +compatible with HuggingFace's Qwen3-VL-MoE model configurations. +Reference: https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct +""" + +from dataclasses import dataclass, field +from typing import List, Optional + +from megatron.bridge.models.conversion.transformers_compat import rope_theta_from_hf +from megatron.bridge.models.qwen_vl.qwen3_vl_provider import Qwen3VLMoEModelProvider +from megatron.core.models.gpt import GPTModel as MCoreGPTModel +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import ( + Qwen3OmniMoeAudioEncoderConfig, + Qwen3OmniMoeTextConfig, + Qwen3OmniMoeVisionEncoderConfig, +) + +from relax.models.qwen_omni.modeling_qwen3_omni.model import Qwen3OmniMoeModel + + +@dataclass +class Qwen3OmniModelProvider(Qwen3VLMoEModelProvider): + """Base model provider for Qwen 3 VL MoE Models. Inherits language model + MoE configuration from Qwen3MoEModelProvider. + + Key MoE Parameters (inherited from Qwen3MoEModelProvider): + - num_moe_experts: Number of total experts (default 128) + - moe_router_topk: Number of experts selected per token (default 8) + - moe_router_load_balancing_type: Load balancing strategy (default "aux_loss") + - moe_aux_loss_coeff: Auxiliary loss coefficient (default 1e-3) + - moe_grouped_gemm: Use grouped GEMM for efficiency (default True) + + Note: num_query_groups in parent class corresponds to num_key_value_heads in HF config. + """ + + # Vision configuration using the transformers Qwen3OmniMoeVisionEncoderConfig + # Default configuration matches the standard Qwen3VL vision encoder + # thinker_config: Qwen3OmniMoeThinkerConfig = field(default_factory=lambda: Qwen3OmniMoeThinkerConfig()) + # talker_config: Qwen3OmniMoeTalkerConfig = field(default_factory=lambda: Qwen3OmniMoeTalkerConfig()) + # code2wav_config: Qwen3OmniMoeCode2WavConfig = field(default_factory=lambda: Qwen3OmniMoeCode2WavConfig()) + + audio_config: Qwen3OmniMoeAudioEncoderConfig = field(default_factory=lambda: Qwen3OmniMoeAudioEncoderConfig()) + vision_config: Qwen3OmniMoeVisionEncoderConfig = field(default_factory=lambda: Qwen3OmniMoeVisionEncoderConfig()) + hf_text_config: Optional[Qwen3OmniMoeTextConfig] = None + + pretrained_model_name: str = "Qwen/Qwen3-Omni-30B-A3B-Instruct" + + audio_token_id: int = 151675 + audio_start_token_id: int = 151669 + audio_end_token_id: int = 151670 + use_audio_in_video: bool = False + + # Vision-specific token IDs matching Qwen3VL MoE configuration + # Based on HuggingFace Qwen3-VL-MoE configs + # Token ID for image placeholder in text + image_token_id: int = 151655 + # Token ID for video placeholder in text + video_token_id: int = 151656 + # Token ID marking start of vision content + vision_start_token_id: int = 151652 + # Token ID marking end of vision content + vision_end_token_id: int = 151653 + # BOS token ID for Qwen3-VL models + bos_token_id: int = 151643 + # EOS token ID for Qwen3-VL models + eos_token_id: int = 151645 + + position_id_per_seconds: int = 0 + + head_dim: int = 128 + qk_layernorm: bool = True + attention_softmax_in_fp32: bool = True + attention_dropout: float = 0.0 + + # Override position embedding for multimodal rope + position_embedding_type: str = "mrope" + + # Multimodal rope section for [temporal, height, width] dimensions + # Based on HuggingFace Qwen3-VL config: mrope_section: [24, 20, 20] + mrope_section: List[int] = field(default_factory=lambda: [24, 20, 20]) + + # RoPE theta value specific to Qwen3-VL models + # From HuggingFace config: rope_theta: 5000000 + rotary_base: float = 5000000.0 + spatial_merge_size: int = 2 + temporal_patch_size: int = 2 + patch_size: int = 16 + + # Override to disable scattering embeddings for vision insertion + scatter_embedding_sequence_parallel: bool = False + + # Router configuration + moe_router_pre_softmax: bool = False # Qwen3 specific + moe_router_dtype: str = "fp32" # Use FP32 for router computations + moe_router_score_function: str = "softmax" # Softmax scoring + moe_router_bias_update_rate: float = 0.001 # Router bias update rate + + # MoE optimization settings + moe_permute_fusion: bool = True # Fuse permutation operations + moe_token_dispatcher_type: str = "alltoall" # All-to-all communication + + # Dense layers configuration (some layers may not use MoE) + # Empty list means all layers use MoE, otherwise specify layer indices + mlp_only_layers: List[int] = field(default_factory=list) + + # Decoder sparse step (frequency of MoE layers) + decoder_sparse_step: int = 1 # Every layer is MoE by default + + # Freeze options for fine-tuning scenarios + # Whether to freeze language model weights + freeze_language_model: bool = False + # Whether to freeze vision encoder weights + freeze_vision_model: bool = False + # Whether to freeze vision-to-language projection weights + freeze_vision_projection: bool = False + # Whether to freeze audio encoder weights + freeze_audio_model: bool = False + language_max_sequence_length: int = 2048 + + # QK layernorm is already True in Qwen3MoEModelProvider, no need to redefine + + # These are typically set in the base class but documented here for clarity + persist_layer_norm: bool = True # Persist layer norm for efficiency + bias_activation_fusion: bool = True # Fuse bias and activation + bias_dropout_fusion: bool = True # Fuse bias and dropout + masked_softmax_fusion: bool = False # Don't fuse masked softmax (Qwen specific) + deallocate_pipeline_outputs: bool = True # Deallocate pipeline outputs to save memory + async_tensor_model_parallel_allreduce: bool = True # Async tensor parallel + distribute_saved_activations: bool = False # Don't distribute saved activations + cp_comm_type: str = "p2p" # Point-to-point communication for context parallel + + def _process_thinker_config(self): + self.thinker_config.head_dim = self.thinker_config.text_config.head_dim + self.thinker_config.hidden_size = self.thinker_config.text_config.hidden_size + self.thinker_config.language_max_sequence_length = getattr( + self.thinker_config.text_config, "language_max_sequence_length", 2048 + ) + + # self.thinker_config.patch_size = self.thinker_config.text_config.patch_size + # self.thinker_config.temporal_patch_size = self.thinker_config.text_config.temporal_patch_size + # self.thinker_config.in_channels = self.thinker_config.text_config.in_channels + # self.thinker_config.spatial_merge_size = self.thinker_config.text_config.spatial_merge_size + # self.thinker_config.num_position_embeddings = self.thinker_config.text_config.num_position_embeddings + # self.thinker_config.out_hidden_size = self.thinker_config.text_config.out_hidden_size + # self.thinker_config.apply_rotary_pos_emb_in_fp32 = self.thinker_config.text_config.apply_rotary_pos_emb_in_fp32 + # self.thinker_config.deepstack_visual_indexes = self.thinker_config.text_config.deepstack_visual_indexes + + self.thinker_config.rotary_percent = 1.0 + self.thinker_config.apply_rope_fusion = False + self.thinker_config.position_embedding_type = "mrope" + self.thinker_config.mrope_section = self.thinker_config.text_config.rope_scaling.get( + "mrope_section", [24, 20, 20] + ) + self.thinker_config.rotary_base = rope_theta_from_hf(self.thinker_config.text_config) + + # self.thinker_config.audio_token_id = self.thinker_config.text_config.audio_token_id + # self.thinker_config.audio_start_token_id = self.thinker_config.text_config.audio_start_token_id + # self.thinker_config.audio_end_token_id = self.thinker_config.text_config.audio_end_token_id + + # self.thinker_config.image_token_id = self.thinker_config.text_config.image_token_id + # self.thinker_config.video_token_id = self.thinker_config.text_config.video_token_id + # self.thinker_config.vision_start_token_id = self.thinker_config.text_config.vision_start_token_id + # self.thinker_config.vision_end_token_id = self.thinker_config.text_config.vision_end_token_id + + self.thinker_config.bos_token_id = getattr(self.thinker_config.text_config, "bos_token_id", 151643) + self.thinker_config.eos_token_id = getattr(self.thinker_config.text_config, "eos_token_id", 151645) + + self.thinker_config.qk_layernorm = True + self.thinker_config.attention_softmax_in_fp32 = True + self.thinker_config.attention_dropout = 0.0 + + self.thinker_config.moe_router_pre_softmax = False + self.thinker_config.moe_router_dtype = "fp32" + self.thinker_config.moe_router_score_function = "softmax" + self.thinker_config.moe_router_bias_update_rate = 0.001 + + self.thinker_config.moe_permute_fusion = True + self.thinker_config.moe_token_dispatcher_type = "alltoall" + + self.thinker_config.mlp_only_layers = self.thinker_config.text_config.mlp_only_layers + self.thinker_config.decoder_sparse_step = self.thinker_config.text_config.decoder_sparse_step + + # to check freeze + # self.thinker_config.freeze_language_model = self.thinker_config.text_config.freeze_language_model + # self.thinker_config.freeze_vision_model = self.thinker_config.text_config.freeze_vision_model + # self.thinker_config.freeze_vision_projection = self.thinker_config.text_config.freeze_vision_projection + self.thinker_config.language_max_sequence_length = 2048 + + self.thinker_config.persist_layer_norm = True + self.thinker_config.bias_activation_fusion = True + self.thinker_config.bias_dropout_fusion = True + self.thinker_config.masked_softmax_fusion = False + self.thinker_config.deallocate_pipeline_outputs = True + self.thinker_config.async_tensor_model_parallel_allreduce = True + self.thinker_config.distribute_saved_activations = False + self.thinker_config.cp_comm_type = "p2p" + + def finalize(self) -> None: + if self.tensor_model_parallel_size > 1: + self.sequence_parallel = True + + super().finalize() + + def provide(self, pre_process=None, post_process=None, vp_stage=None): + """Provide a Qwen3VL MoE model instance with vision and language + components.""" + # self._process_thinker_config() + language_transformer_config = self + + # Create vision transformer config - placeholder for future use + # vision_transformer_config = deepcopy(self) + audio_config_hf = self.audio_config + vision_config_hf = self.vision_config + + language_transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=self.num_moe_experts, + moe_grouped_gemm=True, + qk_layernorm=self.qk_layernorm, + # fp8=False, + # normalization="RMSNorm", + ) + + # reuse Qwen3OmniMoeModel for MoE model but replace the language model with MoE language model + model = Qwen3OmniMoeModel( + language_transformer_config=language_transformer_config, + language_transformer_layer_spec=language_transformer_layer_spec, + audio_transformer_config=audio_config_hf, + vision_transformer_config=vision_config_hf, + pre_process=pre_process, + post_process=post_process, + use_audio_in_video=self.use_audio_in_video, + pg_collection=getattr(self, "_pg_collection", None), + ) + + # Apply freeze options if any are enabled for fine-tuning + if self.freeze_language_model or self.freeze_vision_model or self.freeze_vision_projection: + model.freeze( + freeze_language_model=self.freeze_language_model, + freeze_vision_model=self.freeze_vision_model, + freeze_vision_projection=self.freeze_vision_projection, + freeze_audio_model=self.freeze_audio_model, + ) + + return model + + def provide_language_model(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel: + """Provide just the language MoE model component without vision. + + Args: + pre_process: Whether this is the first stage in pipeline parallelism + post_process: Whether this is the last stage in pipeline parallelism + vp_stage: Virtual pipeline stage number + + Returns: + MCoreGPTModel instance (MoE language model only) + """ + # Use parent class to create standard MoE language model + return super().provide(pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) diff --git a/relax/utils/arguments.py b/relax/utils/arguments.py index 911787be..92fbd5f9 100644 --- a/relax/utils/arguments.py +++ b/relax/utils/arguments.py @@ -10,6 +10,7 @@ from relax.backends.sglang.arguments import sglang_parse_args from relax.backends.sglang.arguments import validate_args as sglang_validate_args +from relax.utils import device as device_utils from relax.utils.logging_utils import get_logger from relax.utils.training.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list @@ -62,7 +63,7 @@ def add_serve_arguments(parser): parser.add_argument( "--checkpoint-engine-backend", type=str, - default="nccl", + default=device_utils.get_dist_backend(), help=("Backend for checkpoint engine."), ) parser.add_argument( @@ -184,7 +185,7 @@ def add_cluster_arguments(parser): ), ) - reset_arg(parser, "--distributed-backend", type=str, default="nccl") + reset_arg(parser, "--distributed-backend", type=str, default=device_utils.get_dist_backend()) reset_arg(parser, "--distributed-timeout-minutes", type=int, default=30) return parser @@ -681,6 +682,33 @@ def add_data_arguments(parser): default=10000, help="Buffer size for streaming dataset.", ) + parser.add_argument( + "--prefetch-chunk-size", + type=int, + default=32, + help="Number of samples to dispatch to the thread-pool in each prefetch round. " + "Larger values increase throughput but also memory pressure. Only effective when " + "--use-streaming-dataset is set and the dataset contains multimodal data.", + ) + parser.add_argument( + "--prefetch-max-cached", + type=int, + default=256, + help="Maximum number of pre-loaded samples kept in the prefetch cache. " + "When the cache is full the background prefetch thread pauses until consumers " + "free space. Set to 0 to disable prefetching. Only effective when " + "--use-streaming-dataset is set and the dataset contains multimodal data.", + ) + parser.add_argument( + "--prefetch-num-workers", + type=int, + default=1, + help="Number of parallel worker threads inside the prefetch buffer for " + "I/O-bound media decoding (video/image). Set to 1 to serialise all " + "decoding (safest for FFmpeg which is not fully thread-safe). " + "Higher values increase parallelism but may trigger EAGAIN errors " + "on some platforms. Only effective when prefetching is enabled.", + ) # TODO: maybe add an num_epoch and calculate the num_rollout from buffer parser.add_argument( "--num-rollout", @@ -815,7 +843,6 @@ def add_data_arguments(parser): "for true parallelism without GIL contention." ), ) - parser.add_argument("--metadata-key", type=str, default="metadata", help="JSON dataset key") parser.add_argument( "--tool-key", @@ -1530,12 +1557,15 @@ def add_debug_arguments(parser): parser.add_argument( "--memory-snapshot-dir", type=str, - default=".", + default=None, + help=("Directory for memory snapshot dumps. Defaults to traces//memory_snapshot."), ) parser.add_argument( "--memory-snapshot-num-steps", type=int, default=None, + help="Number of rollout steps after which to dump the memory snapshot. " + "For example, --memory-snapshot-num-steps 3 dumps after step 2 (0-indexed).", ) parser.add_argument( "--profile-target", @@ -1905,6 +1935,16 @@ def add_autoscaler_arguments(parser): default=None, help="Path to the YAML config for custom function arguments.", ) + parser.add_argument( + "--normalize-bbox", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Convert model-output bbox coordinates from normalized [0, 1000] to absolute pixels. " + "Required for Qwen-VL/Qwen2-VL/Qwen3-VL (default True). " + "Set --no-normalize-bbox for Qwen2.5-VL which outputs absolute pixel coordinates." + ), + ) reset_arg(parser, "--padded-vocab-size", type=int, default=None) return parser @@ -2018,6 +2058,10 @@ def _resolve_eval_datasets(args) -> list[EvalDatasetConfig]: def slime_validate_args(args): + # Backward compatibility: old scripts may pass --enable-gloo-process-groups + if not hasattr(args, "use_gloo_process_groups"): + args.use_gloo_process_groups = getattr(args, "enable_gloo_process_groups", False) + args.eval_datasets = _resolve_eval_datasets(args) if args.max_staleness < 0: diff --git a/relax/utils/checkpoint_write_patch.py b/relax/utils/checkpoint_write_patch.py index c573a7a9..feb161c3 100644 --- a/relax/utils/checkpoint_write_patch.py +++ b/relax/utils/checkpoint_write_patch.py @@ -36,6 +36,7 @@ import torch +from relax.utils import device as device_utils from relax.utils.logging_utils import get_logger @@ -206,8 +207,13 @@ def patch_checkpoint_write(): This function is idempotent — calling it multiple times is safe. """ + from megatron.core.dist_checkpointing.strategies.filesystem_async import FileSystemWriterAsync + global _patched - if _patched: + # NOTE(wuhuan): the latest Megatron-LM of 20260506 use write_preloaded_data_multithread instead of + # write_preloaded_data_multiproc, which has solved this issue. + can_patch = hasattr(FileSystemWriterAsync, "write_preloaded_data_multiproc") + if _patched or not can_patch: return _patch_write_preloaded_data_multiproc() @@ -245,10 +251,10 @@ def _patched_write_preloaded_data_multiproc( # cause SIGSEGV. Use threaded parallel writes instead — all tensors # are already on CPU so the I/O releases the GIL and threads achieve # real parallelism without duplicating the CUDA context. - cuda_initialised = torch.cuda.is_available() and torch.cuda.is_initialized() + cuda_initialised = device_utils.is_available() and device_utils.is_initialized() if cuda_initialised: _logger.debug( - f"rank: {rank}, CUDA initialised – using threaded parallel " + f"rank: {rank}, device initialised – using threaded parallel " f"(no-fork) checkpoint write for {len(write_buckets)} buckets" ) write_results_or_exc = _write_buckets_threaded(transform_list, use_msc, write_buckets) @@ -285,7 +291,7 @@ def _patched_schedule_async_call(self, async_req): if async_req.async_fn is None: return # nothing to do - cuda_initialised = torch.cuda.is_available() and torch.cuda.is_initialized() + cuda_initialised = device_utils.is_available() and device_utils.is_initialized() if not cuda_initialised: # CUDA not initialised — safe to use the original fork path. return _original_schedule(self, async_req) @@ -301,7 +307,7 @@ def _patched_schedule_async_call(self, async_req): rank = torch.distributed.get_rank() start_sync = time() - torch.cuda.synchronize() + device_utils.synchronize() end_sync = time() _logger.debug(f"rank: {rank}, takes {end_sync - start_sync} to finish D2H ") diff --git a/relax/utils/data/data.py b/relax/utils/data/data.py index 517a3fea..bc254b7f 100644 --- a/relax/utils/data/data.py +++ b/relax/utils/data/data.py @@ -1,17 +1,13 @@ # Copyright (c) 2026 Relax Authors. All Rights Reserved. import random -import re - -import ray from relax.utils.data.data_utils import ( BaseDataset, filter_long_prompts, read_file, ) -from relax.utils.timer import Timer -from relax.utils.types import MultimodalTypes, Sample +from relax.utils.types import Sample __all__ = ["Dataset", "BaseDataset"] @@ -22,105 +18,6 @@ logger = get_logger(__name__) -def filter_long_prompt(origin_samples: list[Sample], tokenizer, processor, max_length: int | None) -> list[Sample]: - if max_length is None: - return origin_samples - - if not isinstance(origin_samples[0].prompt, str): - logger.warning( - "Skipping max_length check for list prompt. Set apply_chat_template=True to enable length filtering." - ) - return origin_samples - - if processor: - filtered_samples = [] - for sample in origin_samples: - from relax.utils.data.processing_utils import process_vision_info - - multimodal_inputs = process_vision_info(sample.prompt, processor) - processor_output = processor(text=sample.prompt, **multimodal_inputs) - input_ids = processor_output["input_ids"][0] - if len(input_ids) <= max_length: - filtered_samples.append(sample) - else: - prompts = [sample.prompt for sample in origin_samples] - input_ids_list = tokenizer(prompts, add_special_tokens=False)["input_ids"] - filtered_samples = [ - sample - for sample, input_ids in zip(origin_samples, input_ids_list, strict=True) - if len(input_ids) <= max_length - ] - - logger.info(f"Filtered {len(origin_samples) - len(filtered_samples)} samples longer than max_length={max_length}.") - - return filtered_samples - - -def _build_messages(data: dict, prompt_key: str, as_conversation: bool, multimodal_keys: dict = None): - prompt = data.get(prompt_key) - - if isinstance(prompt, str): - # If prompt is a string and we don't apply chat template, return the prompt as is. - if not as_conversation: - return prompt - else: - prompt = [{"role": "user", "content": prompt}] - - if multimodal_keys: - # Build mapping: placeholder -> (MultimodalType, content_list) - multimodals = {} - for type_name, data_key in multimodal_keys.items(): - mt = MultimodalTypes.get(type_name) - if mt: - multimodal_data = data.get(data_key) - if multimodal_data is not None: - multimodals[mt.placeholder] = (mt, list(multimodal_data)) - - pattern = "(" + "|".join(re.escape(p) for p in multimodals.keys()) + ")" - - for message in prompt: - if isinstance(message["content"], str): - content_list = [] - for segment in re.split(pattern, message["content"]): - if not segment: - continue - if segment in multimodals: - mt, content = multimodals[segment] - assert len(content) > 0, ( - f"Not enough {mt.name} data: more '{mt.placeholder}' placeholders in prompt " - f"than {mt.name}s provided in data" - ) - content_list.append({"type": mt.name, mt.name: content.pop(0)}) - else: - content_list.append({"type": "text", "text": segment}) - message["content"] = content_list - - elif isinstance(message["content"], list): - # TODO: handle more general cases. where message['content'] is a dict and contains multiple types of content. - # e.g. - # "content": [ - # { - # "type": "image", - # "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", - # }, - # {"type": "text", "text": "Describe this image."}, - # ], - logger.warning("message['content'] is a list of dicts, no processing will be done.") - continue - else: - raise ValueError( - f"Unsupported content type: {type(message['content'])}, expected str or list of dicts" - ) - - for placeholder, (mt, remaining) in multimodals.items(): - assert len(remaining) == 0, ( - f"Multimodal data count mismatch: {len(remaining)} more {mt.name}(s)" - f"than '{placeholder}' placeholders in prompt" - ) - - return prompt - - class Dataset(BaseDataset): """Eager-loading dataset that loads all data into memory at initialization. @@ -216,17 +113,3 @@ def get_minimum_num_micro_batch_size(total_lengths, max_tokens_per_gpu): batches.append(length) return len(batches) - - -def process_rollout_data(args, rollout_data_ref, dp_rank, dp_size): - assert len(rollout_data_ref) == dp_size - rollout_data = ray.get(rollout_data_ref[dp_rank].inner) - - partition = rollout_data.pop("partition") - total_lengths = rollout_data["total_lengths"] - - # save the seqlen of the whole rollout batch - Timer().seq_lens = total_lengths - rollout_data["total_lengths"] = [total_lengths[i] for i in partition] - - return rollout_data diff --git a/relax/utils/data/data_utils.py b/relax/utils/data/data_utils.py index b20170f2..4236a8a9 100644 --- a/relax/utils/data/data_utils.py +++ b/relax/utils/data/data_utils.py @@ -128,7 +128,7 @@ def build_messages( if multimodals: pattern = "(" + "|".join(re.escape(p) for p in multimodals.keys()) + ")" - + built_prompt = [] for message in prompt: if isinstance(message["content"], str): content_list = [] @@ -146,16 +146,24 @@ def build_messages( content_list.append({"type": mt.name, mt.name: content.pop(0)}) else: content_list.append({"type": "text", "text": segment}) - message["content"] = content_list + built_message = dict(message) + built_message["content"] = content_list + built_prompt.append(built_message) elif isinstance(message["content"], list): - # Already processed, skip - logger.warning("message['content'] is a list of dicts, no processing will be done.") - continue + # Pre-structured content: count multimodal items so the + # remain_data check below doesn't false-positive. + for item in message["content"]: + item_type = item.get("type") + if item_type in remain_data: + remain_data[item_type] -= 1 + built_prompt.append(message) else: raise ValueError( f"Unsupported content type: {type(message['content'])}, expected str or list of dicts" ) + prompt = built_prompt + if any(v > 0 for v in remain_data.values()): raise RuntimeError( f"placeholder lost! The number of remain mutimodal data is {remain_data}. Please check your dataset prompt." @@ -453,6 +461,7 @@ def resolve_path_plan(path: Any) -> tuple[list[str], Optional[slice]]: def _build_reader_for_path(path: str): + path, row_slice = parse_generalized_path(path) if not os.path.exists(path): raise FileNotFoundError(f"Prompt dataset path '{path}' does not exist.") @@ -470,7 +479,10 @@ def jsonl_reader(p): logger.warning(f"JSON decode error at line {line_num}: {e}") continue - return jsonl_reader(path) + reader = jsonl_reader(path) + if row_slice is not None: + reader = itertools.islice(reader, row_slice.start, row_slice.stop, row_slice.step) + return reader if path.endswith(".parquet"): if pq is None: @@ -486,7 +498,10 @@ def parquet_reader(p): for i in range(pf.metadata.num_row_groups): yield from pf.read_row_group(i).to_pylist() - return parquet_reader(path) + reader = parquet_reader(path) + if row_slice is not None: + reader = itertools.islice(reader, row_slice.start, row_slice.stop, row_slice.step) + return reader raise ValueError(f"Unsupported file format: {path}. Supported formats are .jsonl and .parquet.") diff --git a/relax/utils/data/processing_utils.py b/relax/utils/data/processing_utils.py index c1df07c0..dff25495 100644 --- a/relax/utils/data/processing_utils.py +++ b/relax/utils/data/processing_utils.py @@ -3,6 +3,8 @@ import asyncio import base64 import io +import json +import os import tempfile from concurrent.futures import ThreadPoolExecutor @@ -31,7 +33,19 @@ def load_tokenizer(name_or_path: str, **kwargs): - return AutoTokenizer.from_pretrained(name_or_path, **kwargs) + tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs) + # Multimodal models like Qwen3-Omni ship the chat template in a standalone + # chat_template.json (loaded by AutoProcessor) rather than tokenizer_config.json, + # so AutoTokenizer leaves chat_template unset. Backfill from the sidecar file. + if getattr(tokenizer, "chat_template", None) is None and os.path.isdir(name_or_path): + chat_template_path = os.path.join(name_or_path, "chat_template.json") + if os.path.isfile(chat_template_path): + with open(chat_template_path) as f: + chat_template = json.load(f).get("chat_template") + if chat_template: + tokenizer.chat_template = chat_template + logger.info(f"Loaded chat_template from {chat_template_path}") + return tokenizer def build_processor_kwargs(multimodal_inputs: dict | None = None) -> dict: diff --git a/relax/utils/data/stream_dataloader.py b/relax/utils/data/stream_dataloader.py index 9f62d365..8972e3e2 100644 --- a/relax/utils/data/stream_dataloader.py +++ b/relax/utils/data/stream_dataloader.py @@ -13,6 +13,8 @@ from transfer_queue.dataloader.streaming_dataloader import StreamingDataLoader from transfer_queue.dataloader.streaming_dataset import StreamingDataset +from relax.utils import device as device_utils + logger = logging.getLogger(__name__) @@ -310,9 +312,9 @@ def get_data_from_transfer_queue( # will receive the real data via broadcast. rollout_data = [None, None] - # Use an explicit CUDA device so the communication backend (e.g. NCCL) - # can bind to a known CUDA context. - cuda_dev = torch.device(f"cuda:{torch.cuda.current_device()}") + # Use an explicit device so the communication backend (e.g. NCCL) + # can bind to a known device context. + cuda_dev = device_utils.make_current_torch_device() # --- Extract rollout_routed_experts BEFORE broadcast_object_list --- # broadcast_object_list uses pickle for the entire payload. When @@ -426,12 +428,12 @@ def get_data_from_transfer_queue( def post_process_rollout_data(args, rollout_data): # move tokens/loss_masks to GPU in-place as a list of tensors (downstream # code in this module expects lists of sequence tensors for packing) - from relax.backends.megatron.cp_utils import slice_log_prob_with_cp + from relax.backends.megatron.cp_utils import maybe_padded_total_lengths, slice_log_prob_with_cp - cuda_dev = torch.device(f"cuda:{torch.cuda.current_device()}") - rollout_data["tokens"] = [torch.tensor(t, dtype=torch.long, device=cuda_dev) for t in rollout_data["tokens"]] + cuda_dev = device_utils.make_current_torch_device() + rollout_data["tokens"] = [torch.as_tensor(t, dtype=torch.long, device=cuda_dev) for t in rollout_data["tokens"]] rollout_data["loss_masks"] = [ - torch.tensor(t, dtype=torch.int, device=cuda_dev) for t in rollout_data["loss_masks"] + torch.as_tensor(t, dtype=torch.int, device=cuda_dev) for t in rollout_data["loss_masks"] ] if "multimodal_train_inputs" in rollout_data: # Move multimodal training tensors to GPU in advance. @@ -459,17 +461,24 @@ def _to_cuda(v): rollout_data["max_seq_lens"] = [max_seq_len] * len(rollout_data["tokens"]) + padded_total_lengths = maybe_padded_total_lengths( + rollout_data["total_lengths"], + args.qkv_format, + "multimodal_train_inputs" in rollout_data, + ) + for key in ["rollout_log_probs", "teacher_log_probs"]: if key not in rollout_data: continue rollout_data[key] = [ - torch.tensor( + torch.as_tensor( slice_log_prob_with_cp( log_prob, total_length, response_length, args.qkv_format, rollout_data["max_seq_lens"][i] if args.qkv_format == "bshd" else None, + padded_total_length=padded_total_lengths[i] if padded_total_lengths is not None else None, ), device=cuda_dev, dtype=torch.float32, @@ -517,6 +526,7 @@ def _to_cuda(v): response_length, args.qkv_format, rollout_data["max_seq_lens"][i] if args.qkv_format == "bshd" else None, + padded_total_length=padded_total_lengths[i] if padded_total_lengths is not None else None, ) topk_tensors.append(topk_tensor) @@ -526,6 +536,6 @@ def _to_cuda(v): from tensordict.tensorclass import NonTensorData rollout_data["rollout_routed_experts"] = [ - torch.tensor(r.data if isinstance(r, NonTensorData) else r, dtype=torch.long, device=cuda_dev) + torch.as_tensor(r.data if isinstance(r, NonTensorData) else r, dtype=torch.long, device=cuda_dev) for r in rollout_data["rollout_routed_experts"] ] diff --git a/relax/utils/data/streaming_dataset.py b/relax/utils/data/streaming_dataset.py index 97c04a5a..35d4ea20 100644 --- a/relax/utils/data/streaming_dataset.py +++ b/relax/utils/data/streaming_dataset.py @@ -28,8 +28,11 @@ import json import os import random +import threading +import time from bisect import bisect_right from collections import OrderedDict +from concurrent.futures import ThreadPoolExecutor from typing import Any, Iterator, Optional @@ -57,6 +60,7 @@ "CompositeStreamingReader", "SampleBuffer", "IndexManager", + "PrefetchBuffer", ] @@ -441,6 +445,241 @@ def load_state(self, state: dict) -> None: self.reset(position=position, epoch_id=epoch_id) +class PrefetchBuffer: + """Background prefetch buffer for multimodal data loading. + + Run a background thread that pre-loads and pre-processes samples + (including heavy video/image I/O) **in the exact order** they will be consumed. + + Key design: + - ``set_index_order(indices)`` is called once (at ``shuffle`` time) with + the **entire** upcoming index sequence. The background thread starts + fetching immediately, well before ``get_batch`` is called. + - ``get(idx)`` pops from the cache (near-zero latency on hit) or falls + back to a synchronous single-sample fetch on miss. + - The cache is bounded by ``max_cached``; when full the prefetch thread + pauses until consumers free space via ``get()`` calls. + - A ``ThreadPoolExecutor`` is used inside the prefetch thread to + parallelize video/image decoding across multiple files within a chunk, + since PyAV/FFmpeg releases the GIL during C-level decoding. + + Lifecycle:: + + buf = PrefetchBuffer(process_fn, chunk_size=16, max_cached=256, num_workers=4) + buf.set_index_order([3, 7, 1, 5, ...]) # triggers background loading + sample = buf.get(3) # instant cache hit + """ + + def __init__( + self, + process_fn, + chunk_size: int = 32, + max_cached: int = 256, + num_workers: int = 4, + ): + """Initialize the prefetch buffer. + + Args: + process_fn: ``fn(idx: int) -> Optional[Sample]`` — load and + process a single sample by index. + chunk_size: Number of indices to submit to the thread-pool at + a time inside the prefetch loop. + max_cached: Maximum number of samples to keep in the cache + before the prefetch thread pauses. + num_workers: Number of parallel workers in the internal + ``ThreadPoolExecutor`` for I/O-bound decoding. + """ + self._process_fn = process_fn + self._chunk_size = chunk_size + self._max_cached = max_cached + self._num_workers = num_workers + + # Thread-safe cache: idx -> Optional[Sample] + self._cache: dict[int, Optional[Sample]] = {} + self._lock = threading.Lock() + + # Ordered index sequence set by set_index_order + self._indices: list[int] = [] + self._pos: int = 0 + + # Flow control: cleared when cache is full, set when space is freed + self._space_available = threading.Event() + self._space_available.set() + + # Stats + self._prefetch_hits = 0 + self._prefetch_misses = 0 + + # Thread lifecycle + self._stop = threading.Event() + self._thread: Optional[threading.Thread] = None + + logger.info( + f"PrefetchBuffer created: max_cached={max_cached}, chunk_size={chunk_size}, num_workers={num_workers}" + ) + + # -- Public API -------------------------------------------------------- + + def set_index_order(self, indices: list[int]) -> None: + """Reset the cache and start prefetching in *indices* order. + + Called at the beginning of each epoch (from + ``StreamingDataset.shuffle``) with the full upcoming index sequence. + The prefetch thread starts loading immediately so that later ``get()`` + calls hit the cache. + """ + # Stop any running prefetch thread + self._stop.set() + if self._thread is not None and self._thread.is_alive(): + self._thread.join(timeout=10) + if self._thread.is_alive(): + logger.warning("Previous prefetch thread did not stop within 10s; it will exit on its own stop-event") + + with self._lock: + self._cache.clear() + self._indices = list(indices) + self._pos = 0 + + # Create a fresh stop-event for the new thread so the old thread + # (if still draining) keeps seeing its own set() signal and exits. + self._stop = threading.Event() + self._space_available.set() + stop_event = self._stop + self._thread = threading.Thread(target=self._run, args=(stop_event,), daemon=True, name="prefetch-worker") + self._thread.start() + logger.info(f"PrefetchBuffer: started prefetching {len(indices)} samples") + + def get(self, idx: int) -> Optional[Sample]: + """Return the sample for *idx*. + + Pops from the prefetch cache on hit. On miss, performs a blocking + single-index fetch via ``process_fn``. + """ + with self._lock: + if idx in self._cache: + sample = self._cache.pop(idx) + self._prefetch_hits += 1 + # Signal prefetch thread that space is available + self._space_available.set() + return sample + + # Cache miss — synchronous fallback + self._prefetch_misses += 1 + try: + return self._process_fn(idx) + except Exception: + logger.exception(f"Prefetch fallback failed for index {idx}") + return None + + def stop(self) -> None: + """Signal the prefetch thread to stop.""" + self._stop.set() + # Unblock if waiting on space + self._space_available.set() + if self._thread is not None and self._thread.is_alive(): + self._thread.join(timeout=15) + if self._thread.is_alive(): + logger.warning("Prefetch thread did not terminate within 15s") + + def clear(self) -> None: + """Clear the cache and reset position (without stopping the thread).""" + with self._lock: + self._cache.clear() + self._space_available.set() + + @property + def hit_rate(self) -> float: + """Return prefetch cache hit rate.""" + total = self._prefetch_hits + self._prefetch_misses + return self._prefetch_hits / total if total > 0 else 0.0 + + @property + def cache_size(self) -> int: + """Return current cache size.""" + with self._lock: + return len(self._cache) + + # -- Background thread ------------------------------------------------- + + def _run(self, stop_event: threading.Event) -> None: + """Background prefetch loop. + + Iterates through ``self._indices`` in order, loading chunks in parallel + via a ``ThreadPoolExecutor``. Pauses when the cache is full and + resumes when ``get()`` frees space. + + Args: + stop_event: Thread-local stop signal. Each thread receives its + own ``Event`` so that ``set_index_order`` can replace + ``self._stop`` for the next thread without accidentally + un-stopping this one. + """ + _MAX_SUBMIT_RETRIES = 3 + consecutive_failures = 0 + + with ThreadPoolExecutor(max_workers=self._num_workers, thread_name_prefix="pf") as pool: + while not stop_event.is_set(): + # 1. Get next chunk of indices + with self._lock: + if self._pos >= len(self._indices): + break # All indices have been dispatched + chunk = self._indices[self._pos : self._pos + self._chunk_size] + self._pos += len(chunk) + + # 2. Filter out indices already in cache + with self._lock: + to_fetch = [i for i in chunk if i not in self._cache] + if not to_fetch: + continue + + # 3. Wait until cache has room for this chunk + while not stop_event.is_set(): + with self._lock: + if len(self._cache) + len(to_fetch) <= self._max_cached: + break + self._space_available.clear() + # Wait for consumers to pop entries + if not self._space_available.wait(timeout=0.1): + continue + + if stop_event.is_set(): + return + + # 4. Parallel-fetch all samples in the chunk + try: + futures = {idx: pool.submit(self._process_fn, idx) for idx in to_fetch} + results = {} + for idx, fut in futures.items(): + try: + results[idx] = fut.result(timeout=120) + except Exception: + logger.warning(f"Prefetch failed for index {idx}", exc_info=True) + results[idx] = None + consecutive_failures = 0 + except Exception: + consecutive_failures += 1 + logger.exception( + f"Prefetch chunk submission failed (attempt {consecutive_failures}/{_MAX_SUBMIT_RETRIES})" + ) + if consecutive_failures >= _MAX_SUBMIT_RETRIES: + logger.error("Prefetch thread aborting after too many consecutive submission failures") + break + time.sleep(0.5) + with self._lock: + self._pos -= len(chunk) + continue + + # 5. Store results in cache + with self._lock: + for idx, sample in results.items(): + self._cache[idx] = sample + + logger.info( + f"Prefetch thread finished. Hit rate: {self.hit_rate:.1%} " + f"(hits={self._prefetch_hits}, misses={self._prefetch_misses})" + ) + + class StreamingDataset(BaseDataset): """Memory-efficient streaming dataset with on-demand loading. @@ -449,6 +688,7 @@ class StreamingDataset(BaseDataset): Features: - Lazy loading: Only loads data when accessed - LRU caching: Caches recently accessed samples + - Background prefetching: Pre-loads multimodal data in background thread - Shuffle support: Epoch-based reproducible shuffling - Filter support: Length filtering done at access time @@ -481,7 +721,9 @@ def __init__( apply_chat_template_kwargs: Optional[dict] = None, use_audio_in_video: bool = False, buffer_size: int = 10000, - prefetch_size: int = 100, + prefetch_chunk_size: int = 32, + prefetch_max_cached: int = 256, + prefetch_num_workers: int = 1, multimodal_config: MultimodalConfig = None, ): """Initialize the streaming dataset. @@ -501,8 +743,15 @@ def __init__( apply_chat_template: Whether to apply chat template apply_chat_template_kwargs: Additional kwargs for chat template use_audio_in_video: Whether to extract audio from video files for multimodal processing - buffer_size: Maximum samples to cache - prefetch_size: Number of samples to prefetch (not implemented yet) + buffer_size: Maximum samples to cache in LRU buffer + prefetch_chunk_size: Number of samples dispatched to the thread-pool + in each prefetch round + prefetch_max_cached: Maximum number of pre-loaded samples in the + prefetch cache. Set to 0 to disable prefetching. + prefetch_num_workers: Number of parallel worker threads inside the + prefetch buffer for I/O-bound media decoding. Set to 1 to + serialise decoding (avoids FFmpeg thread-safety issues). + multimodal_config: MultimodalConfig for multimodal processing """ # Initialize base class super().__init__( @@ -534,6 +783,41 @@ def __init__( self._filter_count = 0 self._total_processed = 0 + # Prefetch buffer for overlapping multimodal I/O with compute. + # Only enabled when multimodal_keys are set and prefetch_max_cached > 0. + self._prefetch_buffer: Optional[PrefetchBuffer] = None + if multimodal_keys and prefetch_max_cached > 0: + self._prefetch_buffer = PrefetchBuffer( + process_fn=self._prefetch_process_single, + chunk_size=prefetch_chunk_size, + max_cached=prefetch_max_cached, + num_workers=prefetch_num_workers, + ) + logger.info( + f"StreamingDataset: prefetch enabled with " + f"chunk_size={prefetch_chunk_size}, max_cached={prefetch_max_cached}, " + f"num_workers={prefetch_num_workers}" + ) + self._prefetch_hits_log_counter = 0 + + def _prefetch_process_single(self, idx: int) -> Optional[Sample]: + """Process a single index for prefetching. + + Called by the PrefetchBuffer's worker threads. Each call loads + a single sample (including heavy video/image I/O) and returns it. + + NOTE: We intentionally do NOT access ``self.buffer`` (SampleBuffer) + here because SampleBuffer is not thread-safe and these calls run + in parallel worker threads. + """ + try: + raw_data = self.reader[idx] + sample = self._process_raw_data(raw_data) + return sample + except Exception as e: + logger.warning(f"Prefetch: error processing index {idx}: {e}") + return None + def __len__(self) -> int: """Return total number of samples in the dataset.""" return len(self.reader) @@ -541,11 +825,24 @@ def __len__(self) -> int: def shuffle(self, epoch_id: int) -> None: """Shuffle the dataset for a new epoch. + When prefetch is enabled, passes the **remaining** shuffled index + sequence (from the current position onward) to the + ``PrefetchBuffer`` so the background thread starts loading + immediately — well before ``get_batch`` is called. + Args: epoch_id: Epoch identifier """ self.index_manager.shuffle(epoch_id) self.epoch_id = epoch_id + # Trigger prefetch with the remaining upcoming index order + if self._prefetch_buffer is not None and self.index_manager.indices is not None: + remaining = self.index_manager.indices[self.index_manager.position :] + self._prefetch_buffer.set_index_order(list(remaining)) + logger.info( + f"Prefetch: triggered for epoch {epoch_id}, " + f"{len(remaining)} indices remaining (position={self.index_manager.position})" + ) def _process_raw_data(self, data: dict) -> Optional[Sample]: """Process raw data into a Sample. @@ -580,40 +877,97 @@ def get_batch(self, n: int) -> tuple[list[Sample], bool]: Automatically skips filtered samples and handles epoch boundaries. + When prefetch is enabled, indices are consumed **one at a time** so + that ``IndexManager.position`` stays exactly in sync with the index + sequence given to ``PrefetchBuffer.set_index_order()``. + + Without prefetch, indices are fetched in small batches for + efficiency (acceptable since there is no ordering contract to + honour with a background thread). + Args: n: Number of samples to get Returns: (samples, crossed_epoch): List of samples and whether an epoch boundary was crossed """ - samples = [] + if self._prefetch_buffer is not None: + return self._get_batch_prefetch(n) + return self._get_batch_no_prefetch(n) + + def _get_batch_prefetch(self, n: int) -> tuple[list[Sample], bool]: + """Prefetch-aware path: consume indices one-by-one to stay aligned.""" + samples: list[Sample] = [] crossed_epoch = False - max_attempts = n * 10 # Prevent infinite loop if too many filtered + max_attempts = n * 10 + + for _ in range(max_attempts): + if len(samples) >= n: + break + + indices, epoch_crossed = self.index_manager.get_next_indices(1) + if epoch_crossed and not crossed_epoch: + crossed_epoch = True + # IndexManager already shuffled the new epoch internally. + # Re-trigger prefetch immediately for the remaining indices + # so subsequent get() calls hit the cache instead of falling + # back to synchronous loading. + remaining = self.index_manager.indices[self.index_manager.position :] + self._prefetch_buffer.set_index_order(list(remaining)) + logger.info( + f"Prefetch: epoch crossing detected, re-triggered with " + f"{len(remaining)} indices (epoch={self.index_manager.current_epoch})" + ) + idx = indices[0] + + sample = self._prefetch_buffer.get(idx) + + if sample is None: + # Prefetch returned None — either the sample was filtered + # out during prefetch or prefetch failed; skip it. + continue + + samples.append(sample) + + if len(samples) < n: + logger.warning( + f"Could only get {len(samples)}/{n} samples after {max_attempts} attempts. " + f"Filter rate: {self._filter_count}/{self._total_processed}" + ) + + if self._prefetch_hits_log_counter % 10 == 0: + logger.info( + f"Prefetch stats: hit_rate={self._prefetch_buffer.hit_rate:.1%}, " + f"cache_size={self._prefetch_buffer.cache_size}" + ) + self._prefetch_hits_log_counter += 1 + + return samples, crossed_epoch + + def _get_batch_no_prefetch(self, n: int) -> tuple[list[Sample], bool]: + """Non-prefetch path: fetch indices in small batches for efficiency.""" + samples: list[Sample] = [] + crossed_epoch = False + max_attempts = n * 10 attempts = 0 while len(samples) < n and attempts < max_attempts: - # Calculate how many more we need (with some buffer for filtered samples) need = n - len(samples) - fetch_size = min(need * 2, 100) # Fetch extra to account for filtering + fetch_size = min(need * 2, 100) indices, epoch_crossed = self.index_manager.get_next_indices(fetch_size) crossed_epoch = crossed_epoch or epoch_crossed for idx in indices: if len(samples) >= n: - # Put back unused indices break attempts += 1 - # Check cache first sample = self.buffer.get(idx) - if sample is None: - # Load and process raw_data = self.reader[idx] sample = self._process_raw_data(raw_data) - if sample is not None: self.buffer.put(idx, sample) diff --git a/relax/utils/device.py b/relax/utils/device.py new file mode 100644 index 00000000..b53c3490 --- /dev/null +++ b/relax/utils/device.py @@ -0,0 +1,446 @@ +# Copyright (c) 2026 Relax Authors. All Rights Reserved. +# +# Multi-hardware backend abstraction layer. +# +# Inspired by verl (https://github.com/verl-project/verl) device.py +# and slime (https://github.com/THUDM/slime) plugin architecture. +# +# This module provides a unified device abstraction that allows Relax to run +# on multiple hardware backends (NVIDIA CUDA, Ascend NPU, AMD ROCm, Kunlunxin XPU, +# PPU, etc.) with minimal code changes throughout the framework. +# +# Usage: +# from relax.utils.device import get_device_name, get_torch_device, ... +# +# The module auto-detects the available accelerator at import time and exposes +# a consistent API regardless of the underlying hardware. + +import os +from enum import Enum +from functools import lru_cache +from typing import Optional + +import torch + +from relax.utils.logging_utils import get_logger + + +logger = get_logger(__name__) + + +# --------------------------------------------------------------------------- +# Accelerator type enum +# --------------------------------------------------------------------------- +class AcceleratorType(str, Enum): + """Supported hardware accelerator types.""" + + CUDA = "cuda" # NVIDIA GPU + NPU = "npu" # Ascend NPU (Huawei) + XPU = "xpu" # Intel / Kunlunxin XPU + PPU = "ppu" # PPU (Enflame / custom) + ROCM = "rocm" # AMD ROCm (uses 'cuda' device in PyTorch but HIP backend) + CPU = "cpu" # CPU fallback + + +# --------------------------------------------------------------------------- +# Detection helpers (cached — hardware won't change at runtime) +# --------------------------------------------------------------------------- +@lru_cache(maxsize=1) +def _detect_accelerator() -> AcceleratorType: + """Detect the available hardware accelerator. + + Detection order follows specificity: NPU > XPU > PPU > CUDA/ROCm > CPU. + Environment variable ``RELAX_DEVICE_TYPE`` can override auto-detection. + """ + # Allow explicit override via environment variable + override = os.environ.get("RELAX_DEVICE_TYPE", "").lower().strip() + if override: + for accel in AcceleratorType: + if override == accel.value: + logger.info(f"Device type overridden by RELAX_DEVICE_TYPE={override}") + return accel + logger.warning(f"Unknown RELAX_DEVICE_TYPE='{override}', falling back to auto-detection") + + # Ascend NPU + if _is_npu_available(): + return AcceleratorType.NPU + + # Kunlunxin / Intel XPU + if _is_xpu_available(): + return AcceleratorType.XPU + + # PPU (Enflame) + if _is_ppu_available(): + return AcceleratorType.PPU + + # NVIDIA CUDA or AMD ROCm (both expose torch.cuda) + if torch.cuda.is_available(): + if _is_rocm(): + return AcceleratorType.ROCM + return AcceleratorType.CUDA + + return AcceleratorType.CPU + + +def _is_npu_available() -> bool: + """Check if Ascend NPU is available.""" + try: + if not hasattr(torch, "npu"): + return False + return torch.npu.is_available() + except (ImportError, AttributeError): + return False + + +def _is_xpu_available() -> bool: + """Check if XPU (Intel / Kunlunxin) is available.""" + try: + if not hasattr(torch, "xpu"): + return False + return torch.xpu.is_available() + except (ImportError, AttributeError): + return False + + +def _is_ppu_available() -> bool: + """Check if PPU is available.""" + try: + if not hasattr(torch, "ppu"): + return False + return torch.ppu.is_available() + except (ImportError, AttributeError): + return False + + +def _is_rocm() -> bool: + """Check if the current CUDA build is actually AMD ROCm/HIP.""" + return getattr(torch.version, "hip", None) is not None + + +# --------------------------------------------------------------------------- +# Public API — device info +# --------------------------------------------------------------------------- +def get_accelerator_type() -> AcceleratorType: + """Return the detected :class:`AcceleratorType`.""" + return _detect_accelerator() + + +def get_device_name() -> str: + """Return the PyTorch device type string (``'cuda'``, ``'npu'``, ``'xpu'``, + etc.). + + For ROCm, returns ``'cuda'`` because PyTorch ROCm uses the CUDA device + namespace. + """ + accel = _detect_accelerator() + if accel == AcceleratorType.ROCM: + return "cuda" # ROCm uses torch.cuda namespace + if accel == AcceleratorType.CPU: + return "cpu" + return accel.value + + +def get_torch_device_module(): + """Return the ``torch.`` module (e.g. ``torch.cuda``, + ``torch.npu``). + + This is the namespace that provides ``current_device()``, + ``synchronize()``, ``empty_cache()``, etc. + """ + name = get_device_name() + try: + return getattr(torch, name) + except AttributeError: + logger.warning(f"torch.{name} not found, falling back to torch.cuda") + return torch.cuda + + +# --------------------------------------------------------------------------- +# Public API — distributed backend +# --------------------------------------------------------------------------- + +# Mapping from accelerator type to the default collective communication backend +_DIST_BACKEND_MAP = { + AcceleratorType.CUDA: "nccl", + AcceleratorType.ROCM: "nccl", # ROCm uses RCCL which is NCCL-compatible + AcceleratorType.NPU: "hccl", + AcceleratorType.XPU: "xccl", + AcceleratorType.PPU: "eccl", + AcceleratorType.CPU: "gloo", +} + + +def get_dist_backend() -> str: + """Return the default distributed communication backend name. + + Returns ``'nccl'`` for NVIDIA/AMD, ``'hccl'`` for Ascend NPU, etc. + """ + return _DIST_BACKEND_MAP.get(_detect_accelerator(), "nccl") + + +# --------------------------------------------------------------------------- +# Public API — environment variables +# --------------------------------------------------------------------------- + +# Mapping from accelerator type to the visible-devices environment variable +_VISIBLE_DEVICES_ENV_MAP = { + AcceleratorType.CUDA: "CUDA_VISIBLE_DEVICES", + AcceleratorType.ROCM: "CUDA_VISIBLE_DEVICES", # ROCm also uses this (or HIP_VISIBLE_DEVICES) + AcceleratorType.NPU: "ASCEND_RT_VISIBLE_DEVICES", + AcceleratorType.XPU: "XPU_VISIBLE_DEVICES", + AcceleratorType.PPU: "PPU_VISIBLE_DEVICES", + AcceleratorType.CPU: "", +} + + +def get_visible_devices_env_var() -> str: + """Return the environment variable name for controlling visible devices. + + E.g. ``'CUDA_VISIBLE_DEVICES'`` for NVIDIA, ``'ASCEND_RT_VISIBLE_DEVICES'`` + for Ascend NPU. + """ + return _VISIBLE_DEVICES_ENV_MAP.get(_detect_accelerator(), "CUDA_VISIBLE_DEVICES") + + +def get_visible_devices() -> Optional[str]: + """Return the value of the visible-devices environment variable, or + None.""" + env_var = get_visible_devices_env_var() + if not env_var: + return None + return os.environ.get(env_var) + + +# --------------------------------------------------------------------------- +# Public API — Ray resource name +# --------------------------------------------------------------------------- + +_RAY_RESOURCE_MAP = { + AcceleratorType.CUDA: "GPU", + AcceleratorType.ROCM: "GPU", + AcceleratorType.NPU: "NPU", + AcceleratorType.XPU: "XPU", + AcceleratorType.PPU: "PPU", + AcceleratorType.CPU: "CPU", +} + + +def get_ray_accelerator_name() -> str: + """Return the Ray resource name for the current accelerator. + + E.g. ``'GPU'`` for NVIDIA/AMD, ``'NPU'`` for Ascend. + """ + return _RAY_RESOURCE_MAP.get(_detect_accelerator(), "GPU") + + +# --------------------------------------------------------------------------- +# Public API — device operations (thin wrappers) +# --------------------------------------------------------------------------- +def current_device() -> int: + """Return the index of the current device.""" + mod = get_torch_device_module() + return mod.current_device() + + +def set_device(device) -> None: + """Set the current device. + + Args: + device: Device index (int) or device string (e.g. ``'cuda:0'``). + """ + mod = get_torch_device_module() + mod.set_device(device) + + +def device_count() -> int: + """Return the number of available accelerator devices.""" + mod = get_torch_device_module() + return mod.device_count() + + +def synchronize(device=None) -> None: + """Synchronize the current (or specified) device.""" + accel = _detect_accelerator() + if accel == AcceleratorType.CPU: + return # no-op for CPU + mod = get_torch_device_module() + if device is not None: + mod.synchronize(device) + else: + mod.synchronize() + + +def empty_cache() -> None: + """Release all unoccupied cached memory.""" + accel = _detect_accelerator() + if accel == AcceleratorType.CPU: + return + mod = get_torch_device_module() + mod.empty_cache() + + +def memory_allocated(device=None) -> int: + """Return the current GPU memory occupied by tensors in bytes.""" + mod = get_torch_device_module() + if device is not None: + return mod.memory_allocated(device) + return mod.memory_allocated() + + +def memory_reserved(device=None) -> int: + """Return the current GPU memory managed by the caching allocator in + bytes.""" + mod = get_torch_device_module() + if device is not None: + return mod.memory_reserved(device) + return mod.memory_reserved() + + +def mem_get_info(device=None): + """Return ``(free, total)`` memory in bytes for the given device.""" + mod = get_torch_device_module() + if device is not None: + return mod.mem_get_info(device) + return mod.mem_get_info() + + +def get_device_properties(device=None): + """Return device properties for the given device.""" + mod = get_torch_device_module() + if device is not None: + return mod.get_device_properties(device) + return mod.get_device_properties(mod.current_device()) + + +def current_stream(device=None): + """Return the currently selected stream for the given device.""" + mod = get_torch_device_module() + if device is not None: + return mod.current_stream(device) + return mod.current_stream() + + +def Stream(device=None, **kwargs): + """Create a new stream on the given device.""" + mod = get_torch_device_module() + if device is not None: + return mod.Stream(device=device, **kwargs) + return mod.Stream(**kwargs) + + +def Event(**kwargs): + """Create a new event.""" + mod = get_torch_device_module() + return mod.Event(**kwargs) + + +def stream_context(stream): + """Return a context manager that sets the given stream as the current + stream. + + Equivalent to ``torch.cuda.stream(s)`` but dispatches to the correct device + backend (e.g. ``torch.npu.stream(s)`` on Ascend NPU). + """ + mod = get_torch_device_module() + return mod.stream(stream) + + +def is_initialized() -> bool: + """Return True if the device backend has been initialized. + + Equivalent to ``torch.cuda.is_initialized()`` but dispatches to the correct + device backend. + """ + mod = get_torch_device_module() + if hasattr(mod, "is_initialized"): + return mod.is_initialized() + # Fallback: if the backend doesn't expose is_initialized, check if + # any device is available (conservative — assumes initialized if available). + return is_available() + + +# --------------------------------------------------------------------------- +# Public API — device string helpers +# --------------------------------------------------------------------------- +def make_device_string(index: Optional[int] = None) -> str: + """Build a device string like ``'cuda:0'`` or ``'npu:2'``. + + Args: + index: Device index. If None, uses :func:`current_device`. + """ + name = get_device_name() + if name == "cpu": + return "cpu" + if index is None: + index = current_device() + return f"{name}:{index}" + + +def make_current_torch_device() -> torch.device: + """Return a ``torch.device`` for the current accelerator and device + index.""" + return torch.device(make_device_string()) + + +# --------------------------------------------------------------------------- +# Public API — NUMA affinity +# --------------------------------------------------------------------------- +def set_numa_affinity(local_rank: int) -> None: + """Set NUMA affinity for the given local rank. + + On NVIDIA GPUs, uses pynvml. On other backends, this is a no-op with a + warning. + """ + accel = _detect_accelerator() + if accel in (AcceleratorType.CUDA,): + try: + import pynvml + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(local_rank) + pynvml.nvmlDeviceSetCpuAffinity(handle) + logger.info(f"Set NUMA affinity for GPU {local_rank}") + pynvml.nvmlShutdown() + except ImportError: + logger.info("pynvml not available, skipping NUMA affinity setup") + except Exception as e: + logger.info(f"Failed to set NUMA affinity: {e}") + elif accel == AcceleratorType.ROCM: + logger.info("ROCm/HIP environment detected, skipping NUMA affinity setup") + elif accel == AcceleratorType.NPU: + logger.info("Ascend NPU environment, skipping NUMA affinity setup (not yet supported)") + else: + logger.info(f"NUMA affinity not supported for {accel.value}, skipping") + + +# --------------------------------------------------------------------------- +# Public API — expandable segments (CUDA-specific, no-op on others) +# --------------------------------------------------------------------------- +def set_expandable_segments(enable: bool) -> None: + """Configure CUDA memory allocator expandable segments. + + Only effective on NVIDIA CUDA. No-op on other backends. + """ + if _detect_accelerator() == AcceleratorType.CUDA: + try: + torch.cuda.memory._set_allocator_settings(f"expandable_segments:{enable}") + except Exception as e: + logger.warning(f"Failed to set expandable_segments: {e}") + + +# --------------------------------------------------------------------------- +# Public API — availability check +# --------------------------------------------------------------------------- +def is_available() -> bool: + """Return True if any accelerator device is available (not CPU-only).""" + return _detect_accelerator() != AcceleratorType.CPU + + +# --------------------------------------------------------------------------- +# Convenience: boolean flags (for backward compatibility / quick checks) +# --------------------------------------------------------------------------- +is_cuda_available: bool = torch.cuda.is_available() +is_npu_available: bool = _is_npu_available() +is_xpu_available: bool = _is_xpu_available() +is_ppu_available: bool = _is_ppu_available() +is_rocm: bool = _is_rocm() diff --git a/relax/utils/logging_utils.py b/relax/utils/logging_utils.py index 3195d2d9..aa354966 100644 --- a/relax/utils/logging_utils.py +++ b/relax/utils/logging_utils.py @@ -107,6 +107,10 @@ def configure_logger(prefix: str = "") -> None: handler.setLevel(LOG_LEVEL) handler.setFormatter(get_formatter(prefix)) root_logger.addHandler(handler) + + # Silence noisy third-party DEBUG loggers (PIL dumps PNG chunk metadata per image) + for noisy in ("PIL",): + logging.getLogger(noisy).setLevel(logging.WARNING) except Exception: # Silently ignore configuration errors to prevent breaking the application pass diff --git a/relax/utils/memory_utils.py b/relax/utils/memory_utils.py index 1c73176b..86112881 100644 --- a/relax/utils/memory_utils.py +++ b/relax/utils/memory_utils.py @@ -5,6 +5,7 @@ import torch import torch.distributed as dist +from relax.utils import device as device_utils from relax.utils.logging_utils import get_logger @@ -12,23 +13,23 @@ def clear_memory(clear_host_memory: bool = False): - torch.cuda.synchronize() + device_utils.synchronize() gc.collect() - torch.cuda.empty_cache() + device_utils.empty_cache() if clear_host_memory: torch._C._host_emptyCache() def available_memory(): - device = torch.cuda.current_device() - free, total = torch.cuda.mem_get_info(device) + dev = device_utils.current_device() + free, total = device_utils.mem_get_info(dev) return { - "gpu": str(device), + "device": str(dev), "total_GB": _byte_to_gb(total), "free_GB": _byte_to_gb(free), "used_GB": _byte_to_gb(total - free), - "allocated_GB": _byte_to_gb(torch.cuda.memory_allocated(device)), - "reserved_GB": _byte_to_gb(torch.cuda.memory_reserved(device)), + "allocated_GB": _byte_to_gb(device_utils.memory_allocated(dev)), + "reserved_GB": _byte_to_gb(device_utils.memory_reserved(dev)), } diff --git a/relax/utils/metrics/metric_utils.py b/relax/utils/metrics/metric_utils.py index 60cc131e..42be887f 100644 --- a/relax/utils/metrics/metric_utils.py +++ b/relax/utils/metrics/metric_utils.py @@ -4,6 +4,8 @@ import numpy as np +from relax.utils.types import Sample + logger = logging.getLogger(__name__) @@ -65,6 +67,48 @@ def compute_statistics(values: list[float]) -> dict[str, float]: } +def is_rollout_numeric_metric_value(value) -> bool: + return isinstance(value, (int, float, np.integer, np.floating)) + + +def append_rollout_numeric_metric_values(metric_values: dict[str, list[float]], *, key: str, value) -> None: + if isinstance(value, (list, tuple)): + flattened = [float(item) for item in value if is_rollout_numeric_metric_value(item)] + if flattened: + metric_values.setdefault(key, []).extend(flattened) + return + if is_rollout_numeric_metric_value(value): + metric_values.setdefault(key, []).append(float(value)) + + +def finalize_rollout_explicit_metric_values(metric_values: dict[str, list[float]]) -> dict[str, float]: + log_dict: dict[str, float] = {} + for metric_name, values in metric_values.items(): + if values: + log_dict |= dict_add_prefix(compute_statistics(values), f"{metric_name}/") + return log_dict + + +def compute_rollout_explicit_reward_metrics(args, samples: list[Sample]) -> dict[str, float]: + reward_metric_values: dict[str, list[float]] = {} + primary_reward_key = getattr(args, "reward_key", None) + for sample in samples: + reward = sample.reward + if not isinstance(reward, dict): + continue + for key, value in reward.items(): + if ( + not isinstance(key, str) + or not key + or key == primary_reward_key + or key == "raw_reward" + or key.startswith("_") + ): + continue + append_rollout_numeric_metric_values(reward_metric_values, key=key, value=value) + return finalize_rollout_explicit_metric_values(reward_metric_values) + + def compression_ratio( data: str | bytes, *, diff --git a/relax/utils/misc.py b/relax/utils/misc.py index 0aaa5ed3..089326f7 100644 --- a/relax/utils/misc.py +++ b/relax/utils/misc.py @@ -1,7 +1,9 @@ # Copyright (c) 2026 Relax Authors. All Rights Reserved. import importlib +import random import subprocess +import time from argparse import Namespace from collections import defaultdict from collections.abc import Callable, Iterable @@ -82,10 +84,21 @@ def get_current_node_ip(): def get_free_port(start_port=10000, consecutive=1): # find the port where port, port + 1, port + 2, ... port + consecutive - 1 are all available - port = start_port - while not all(is_port_available(port + i) for i in range(consecutive)): - port += 1 - return port + search_window = 1000 + max_start_port = min(65535 - consecutive + 1, start_port + search_window - consecutive) + if consecutive < 1 or start_port > max_start_port: + raise ValueError(f"Invalid port search range: {start_port=}, {consecutive=}") + + rng = random.Random(time.time_ns()) + ports = list(range(start_port, max_start_port + 1)) + rng.shuffle(ports) + for port in ports: + if all(is_port_available(port + i) for i in range(consecutive)): + return port + raise RuntimeError( + f"No free port available in [{start_port}, {start_port + search_window - 1}] " + f"with {consecutive} consecutive ports" + ) def should_run_periodic_action( diff --git a/relax/utils/profile_utils.py b/relax/utils/profile_utils.py index 72ece272..c3a15f71 100644 --- a/relax/utils/profile_utils.py +++ b/relax/utils/profile_utils.py @@ -1,11 +1,14 @@ # Copyright (c) 2026 Relax Authors. All Rights Reserved. +import asyncio +import os import time import traceback from pathlib import Path import torch +from relax.utils import device as device_utils from relax.utils.logging_utils import get_logger from relax.utils.memory_utils import print_memory @@ -13,6 +16,18 @@ logger = get_logger(__name__) +def _get_rank_tag() -> str: + """Build a rank tag string like ``rank0_dp0_tp0_pp0`` from Megatron mpu.""" + global_rank = torch.distributed.get_rank() + from megatron.core import mpu + + dp = mpu.get_data_parallel_rank(with_context_parallel=True) + tp = mpu.get_tensor_model_parallel_rank() + pp = mpu.get_pipeline_model_parallel_rank() + + return f"rank{global_rank}_dp{dp}_tp{tp}_pp{pp}" + + class TrainProfiler: def __init__(self, args): self.args = args @@ -21,7 +36,7 @@ def __init__(self, args): if args.use_pytorch_profiler and ("train_overall" in args.profile_target): self._torch_profiler_overall = _create_torch_profiler(args, name="train_overall") - logger.info(f"PyTorch profiler for overall training is enabled, dump dir: {args.tensorboard_dir}") + logger.info(f"PyTorch profiler for overall training is enabled, dump dir: {_get_train_trace_dir(args)}") if args.record_memory_history and ("train_overall" in args.profile_target): self._memory_profiler_overall = _BaseMemoryProfiler.create(args) @@ -62,18 +77,42 @@ def _profile_simple_loop(iterator, args, name): torch_profiler.step() +def _get_trace_base_dir(args): + """Return the base directory for all profiler outputs. + + Uses ``./traces/`` as the default location. Falls back + to ``./traces/`` when ``--tb-experiment-name`` is not set. + """ + task_name = getattr(args, "tb_experiment_name", None) + if task_name is None: + from datetime import datetime + + task_name = datetime.now().strftime("%Y%m%d_%H%M%S") + return os.path.join("traces", task_name) + + +def _get_train_trace_dir(args): + """Return the output directory for training profiler traces. + + Uses ``./traces//train_trace`` as the default location. + """ + return os.path.join(_get_trace_base_dir(args), "train_trace") + + def _create_torch_profiler(args, name): + trace_dir = _get_train_trace_dir(args) + worker_name = f"{name}_{_get_rank_tag()}" return torch.profiler.profile( schedule=torch.profiler.schedule( # TODO the train_actor and train_log_probs ones may need to have different args to control step wait=max(args.profile_step_start - 1, 0), warmup=1 if args.profile_step_start > 0 else 0, - active=args.profile_step_end - args.profile_step_start, + active=args.profile_step_end - args.profile_step_start + 1, # end is inclusive repeat=1, ), on_trace_ready=torch.profiler.tensorboard_trace_handler( - args.tensorboard_dir, - worker_name=f"{name}_rank_{torch.distributed.get_rank()}", + trace_dir, + worker_name=worker_name, use_gzip=True, ), record_shapes=True, @@ -93,9 +132,13 @@ def create(args): return c(args) def __init__(self, args): + snapshot_dir = getattr(args, "memory_snapshot_dir", None) + if snapshot_dir is None: + snapshot_dir = os.path.join(_get_trace_base_dir(args), "memory_snapshot") + os.makedirs(snapshot_dir, exist_ok=True) + rank_tag = _get_rank_tag() self._path_dump = ( - Path(args.memory_snapshot_dir) - / f"memory_snapshot_time{time.time()}_rank{torch.distributed.get_rank()}_{args.memory_snapshot_path}" + Path(snapshot_dir) / f"memory_snapshot_time{time.time()}_{rank_tag}_{args.memory_snapshot_path}" ) def start(self): @@ -109,7 +152,16 @@ class _TorchMemoryProfiler(_BaseMemoryProfiler): def start(self): logger.info("Attach OOM dump memory history.") - torch.cuda.memory._record_memory_history( + # Memory snapshot APIs are currently CUDA-specific. + # On non-CUDA backends, log a warning and skip. + device_mod = device_utils.get_torch_device_module() + if not hasattr(device_mod, "memory"): + logger.warning( + f"Memory snapshot profiling is not supported on {device_utils.get_device_name()} backend, skipping." + ) + return + + device_mod.memory._record_memory_history( max_entries=1000000, # record stack information for the trace events # trace_alloc_record_context=True, @@ -121,15 +173,22 @@ def oom_observer(device, alloc, device_alloc, device_free): f"Observe OOM, will dump snapshot to {self._path_dump}. ({device=} {alloc=} {device_alloc=} {device_free=}; stacktrace is as follows)" ) traceback.print_stack() - torch.cuda.memory._dump_snapshot(self._path_dump) + device_mod.memory._dump_snapshot(self._path_dump) print_memory("when oom") - torch._C._cuda_attach_out_of_memory_observer(oom_observer) + if hasattr(torch._C, "_cuda_attach_out_of_memory_observer"): + torch._C._cuda_attach_out_of_memory_observer(oom_observer) def stop(self): logger.info(f"Dump memory snapshot to: {self._path_dump}") - torch.cuda.memory._dump_snapshot(self._path_dump) - torch.cuda.memory._record_memory_history(enabled=None) + device_mod = device_utils.get_torch_device_module() + if not hasattr(device_mod, "memory"): + logger.warning( + f"Memory snapshot profiling is not supported on {device_utils.get_device_name()} backend, skipping." + ) + return + device_mod.memory._dump_snapshot(self._path_dump) + device_mod.memory._record_memory_history(enabled=None) class _MemrayMemoryProfiler(_BaseMemoryProfiler): @@ -150,3 +209,141 @@ def start(self): def stop(self): logger.info(f"Memray tracker stopped and dump snapshot to: {self._path_dump}") self._tracker.__exit__(None, None, None) + + +# --------------------------------------------------------------------------- +# SGLang profiling orchestration +# +# These helpers coordinate profiling across all SGLang engines by discovering +# worker URLs from the router and issuing HTTP start/stop requests. +# --------------------------------------------------------------------------- + + +def _get_sglang_trace_dir(args) -> str: + """Return the base output directory for SGLang profiler traces. + + Uses the user-specified ``--sglang-profile-output-dir`` if set, otherwise + falls back to ``./traces//sglang_trace``. + """ + base_dir = getattr(args, "sglang_profile_output_dir", None) + if base_dir is None: + base_dir = os.path.join(_get_trace_base_dir(args), "sglang_trace") + return base_dir + + +def _should_profile_sglang(args, rollout_id: int) -> bool: + """Determine whether SGLang profiling should be active for the given + rollout step. + + Resolution order: + 1. ``--sglang-profile`` must be enabled (master switch). + 2. ``--sglang-profile-steps`` (explicit list) takes precedence if set. + 3. ``--sglang-profile-step-start`` / ``--sglang-profile-step-end`` (range) + is checked next. Both bounds are *inclusive* and use absolute rollout IDs. + 4. If neither is set, every step is profiled. + """ + if not getattr(args, "sglang_profile", False): + return False + + profile_steps = getattr(args, "sglang_profile_steps", None) + if profile_steps is not None: + return rollout_id in profile_steps + + step_start = getattr(args, "sglang_profile_step_start", None) + step_end = getattr(args, "sglang_profile_step_end", None) + if step_start is not None or step_end is not None: + lo = step_start if step_start is not None else 0 + hi = step_end if step_end is not None else float("inf") + return lo <= rollout_id <= hi + + # No filter specified — profile every step. + return True + + +async def _get_sglang_worker_urls(args) -> list[str]: + """Discover SGLang worker URLs from the router.""" + import sglang_router + from packaging.version import parse + + from relax.utils.http_utils import get + + if parse(sglang_router.__version__) <= parse("0.2.1") or getattr(args, "use_slime_router", False): + response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers") + return response["urls"] + else: + response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers") + return [worker["url"] for worker in response["workers"]] + + +async def start_sglang_profile(args, rollout_id: int) -> None: + """Start torch profiling on all SGLang engines if ``--sglang-profile`` is + enabled. + + Profile traces are organized as:: + + traces//sglang_trace/rollout_/ + + When ``--sglang-profile-output-dir`` is explicitly set, that path is used + as the base instead. + """ + if not _should_profile_sglang(args, rollout_id): + return + + from relax.utils.http_utils import post + + # Build per-step output directory: /rollout_ + base_dir = _get_sglang_trace_dir(args) + step_dir = os.path.join(base_dir, f"rollout_{rollout_id}") + os.makedirs(step_dir, exist_ok=True) + + num_steps = getattr(args, "sglang_profile_num_steps", None) + if num_steps is not None and num_steps < 0: + num_steps = None + + urls = await _get_sglang_worker_urls(args) + base_payload = { + "output_dir": step_dir, + "num_steps": num_steps, + "activities": getattr(args, "sglang_profile_activities", None), + "profile_by_stage": getattr(args, "sglang_profile_by_stage", False), + "with_stack": getattr(args, "sglang_profile_with_stack", False), + "record_shapes": getattr(args, "sglang_profile_record_shapes", False), + } + + logger.info( + f"Starting SGLang profiling on {len(urls)} engines for rollout step {rollout_id}, " + f"output_dir={step_dir}, num_steps={num_steps}" + ) + tasks = [] + for i, url in enumerate(urls): + payload = {**base_payload, "profile_prefix": f"engine{i}"} + tasks.append(post(f"{url}/start_profile", payload)) + results = await asyncio.gather(*tasks, return_exceptions=True) + for url, result in zip(urls, results, strict=False): + if isinstance(result, BaseException): + logger.warning(f"Failed to start profile on {url}: {result}") + else: + logger.info(f"Started profiling on {url}") + + +async def stop_sglang_profile(args, rollout_id: int) -> None: + """Stop torch profiling on all SGLang engines if ``--sglang-profile`` is + enabled.""" + if not _should_profile_sglang(args, rollout_id): + return + + # If num_steps was set, SGLang auto-stops — skip explicit stop. + if getattr(args, "sglang_profile_num_steps", -1) > 0: + return + + from relax.utils.http_utils import post + + urls = await _get_sglang_worker_urls(args) + logger.info(f"Stopping SGLang profiling on {len(urls)} engines for rollout step {rollout_id}") + tasks = [post(f"{url}/stop_profile", {}) for url in urls] + results = await asyncio.gather(*tasks, return_exceptions=True) + for url, result in zip(urls, results, strict=False): + if isinstance(result, BaseException): + logger.warning(f"Failed to stop profile on {url}: {result}") + else: + logger.info(f"Stopped profiling on {url}") diff --git a/relax/utils/reloadable_process_group.py b/relax/utils/reloadable_process_group.py index 6016ea61..ebd2f2be 100644 --- a/relax/utils/reloadable_process_group.py +++ b/relax/utils/reloadable_process_group.py @@ -1,10 +1,17 @@ +# Copyright (c) 2026 Relax Authors. All Rights Reserved. + import os +import time +import socket +from collections.abc import Callable from contextlib import contextmanager from datetime import timedelta +from typing import Any import torch import torch.distributed as dist +from relax.utils import device as device_utils from relax.utils.logging_utils import get_logger from relax.utils.memory_utils import available_memory, clear_memory, print_memory @@ -68,10 +75,16 @@ def get_new_comm_function(func): """Wrap communication functions with memory check.""" def new_function(*args, **kwargs): - args = tuple([arg.group if isinstance(arg, ReloadableProcessGroup) else arg for arg in args]) + original_args = args + original_kwargs = kwargs + args = tuple(arg.group if isinstance(arg, ReloadableProcessGroup) else arg for arg in args) kwargs = {k: (v.group if isinstance(v, ReloadableProcessGroup) else v) for k, v in kwargs.items()} with _wrap_low_level_call(): - return func(*args, **kwargs) + try: + return func(*args, **kwargs) + except Exception as exc: + _log_distributed_exception(func.__name__, original_args, original_kwargs, exc) + raise return new_function @@ -149,13 +162,15 @@ def __getattr__(self, name): return getattr(self.group, name) @staticmethod - def destroy_process_groups(): + def destroy_process_groups(post_destroy_delay: float = 2.0): pid = os.getpid() + destroyed_count = 0 for reloadable_group in ReloadableProcessGroup.GROUPS.get(pid, []): if reloadable_group.group is None: continue try: dist.destroy_process_group(reloadable_group.group) + destroyed_count += 1 except ValueError as e: logger.warning( f"Process group already invalid/destroyed; skipping cleanup. Exception: {e}", @@ -165,21 +180,52 @@ def destroy_process_groups(): del reloadable_group.group reloadable_group.group = None + if destroyed_count > 0 and post_destroy_delay > 0: + # Wait for OS to release NCCL socket ports (TCP TIME_WAIT), + # preventing "Address already in use" on subsequent reload. + logger.info( + f"Destroyed {destroyed_count} process groups, waiting {post_destroy_delay}s " + "for NCCL socket port release" + ) + time.sleep(post_destroy_delay) + @staticmethod - def reload_process_groups(timeout_minutes: int = 30): + def reload_process_groups(timeout_minutes: int = 30, max_retries: int = 3, retry_delay: float = 5.0): pid = os.getpid() reloadable_groups = ReloadableProcessGroup.GROUPS.get(pid, []) logger.info(f"Reloading {len(reloadable_groups)} process groups in pid {pid}") old_new_group = old_new_group_dict.get(pid) - for reloadable_group in reloadable_groups: + for idx, reloadable_group in enumerate(reloadable_groups): if reloadable_group.group is not None: continue - group = old_new_group( - ranks=reloadable_group.group_info["ranks"], - backend="nccl", - timeout=timedelta(minutes=timeout_minutes), - ) - reloadable_group.group = group + last_error = None + for attempt in range(1, max_retries + 1): + try: + group = old_new_group( + ranks=reloadable_group.group_info["ranks"], + backend=device_utils.get_dist_backend(), + timeout=timedelta(minutes=timeout_minutes), + ) + reloadable_group.group = group + if attempt > 1: + logger.info(f"Process group {idx} reloaded successfully on attempt {attempt}") + last_error = None + break + except Exception as e: + last_error = e + logger.warning( + f"Failed to reload process group {idx} (attempt {attempt}/{max_retries}): {e}", + exc_info=(attempt == max_retries), + ) + if attempt < max_retries: + sleep_time = retry_delay * attempt + logger.info(f"Retrying in {sleep_time}s...") + time.sleep(sleep_time) + if last_error is not None: + raise RuntimeError( + f"Failed to reload process group {idx} after {max_retries} attempts " + f"(ranks={reloadable_group.group_info['ranks']})" + ) from last_error def rank(self) -> int: return self.group.rank() @@ -203,7 +249,11 @@ def _fwd(self, method, *args, **kwargs): if inner is None: raise RuntimeError("ReloadableProcessGroup: inner PG is None, call reload() first.") with _wrap_low_level_call(): - return getattr(inner, method)(*args, **kwargs) + try: + return getattr(inner, method)(*args, **kwargs) + except Exception as exc: + _log_distributed_exception(method, (self, *args), kwargs, exc) + raise def _fwd_query(self, method, *args, **kwargs): """Forward non-communication calls without memory check.""" @@ -293,14 +343,16 @@ def bound_device_id(self, dev): self.group.bound_device_id = dev -def destroy_process_groups(): +def destroy_process_groups(post_destroy_delay: float = 2.0): """Destroy all reloadable process groups.""" - ReloadableProcessGroup.destroy_process_groups() + ReloadableProcessGroup.destroy_process_groups(post_destroy_delay=post_destroy_delay) -def reload_process_groups(timeout_minutes: int = 30): +def reload_process_groups(timeout_minutes: int = 30, max_retries: int = 3, retry_delay: float = 5.0): """Reload all reloadable process groups.""" - ReloadableProcessGroup.reload_process_groups(timeout_minutes=timeout_minutes) + ReloadableProcessGroup.reload_process_groups( + timeout_minutes=timeout_minutes, max_retries=max_retries, retry_delay=retry_delay + ) @contextmanager diff --git a/relax/utils/training/routing_replay.py b/relax/utils/training/routing_replay.py index 096f4e74..95610316 100644 --- a/relax/utils/training/routing_replay.py +++ b/relax/utils/training/routing_replay.py @@ -2,6 +2,8 @@ import torch +from relax.utils import device as device_utils + ROUTING_REPLAY = None @@ -29,12 +31,12 @@ def record(self, top_indices): def pop_forward(self): top_indices = self.top_indices_list[self.forward_index] self.forward_index += 1 - return top_indices.to(torch.cuda.current_device()) + return top_indices.to(device_utils.make_current_torch_device()) def pop_backward(self): top_indices = self.top_indices_list[self.backward_index] self.backward_index += 1 - return top_indices.to(torch.cuda.current_device()) + return top_indices.to(device_utils.make_current_torch_device()) def clear(self): self.forward_index = 0 diff --git a/relax/utils/training/tensor_backper.py b/relax/utils/training/tensor_backper.py index 955975ab..e9a0145f 100644 --- a/relax/utils/training/tensor_backper.py +++ b/relax/utils/training/tensor_backper.py @@ -4,6 +4,8 @@ import torch +from relax.utils import device as device_utils + _SourceGetter = Callable[[], Iterable[tuple[str, torch.Tensor]]] @@ -59,7 +61,7 @@ def backup(self, tag: str) -> None: if name not in backup_dict: backup_dict[name] = torch.empty_like(param, device=torch.device("cpu"), pin_memory=True) backup_dict[name].copy_(param.detach(), non_blocking=True) - torch.cuda.synchronize() + device_utils.synchronize() @torch.no_grad() def copy(self, *, src_tag: str, dst_tag: str): @@ -72,7 +74,7 @@ def restore(self, tag: str) -> None: for name, param in self._source_getter(): assert name in backup_dict param.copy_(backup_dict[name], non_blocking=True) - torch.cuda.synchronize() + device_utils.synchronize() class _TensorBackuperNoop(TensorBackuper): @@ -95,12 +97,12 @@ def get(self, tag: str): def backup(self, tag: str) -> None: assert tag == self._single_tag self._backup_hash_dict = _compute_hash_dict(dict(self._source_getter())) - torch.cuda.synchronize() + device_utils.synchronize() def restore(self, tag: str) -> None: assert tag == self._single_tag assert _compute_hash_dict(dict(self._source_getter())) == self._backup_hash_dict - torch.cuda.synchronize() + device_utils.synchronize() def _compute_hash_dict(tensors: dict[str, torch.Tensor]): diff --git a/relax/utils/utils.py b/relax/utils/utils.py index b5cd0034..56ae986c 100644 --- a/relax/utils/utils.py +++ b/relax/utils/utils.py @@ -61,8 +61,12 @@ def convert_samples_to_train_data(args: Any, samples: list[Sample] | list[list[S train_data["loss_masks"] = loss_masks # overwriting the raw reward - if samples[0].metadata and "raw_reward" in samples[0].metadata: - train_data["raw_reward"] = [sample.metadata["raw_reward"] for sample in samples] + # populate this field for a subset of samples (e.g. SWE but not code). + if any(sample.metadata and "raw_reward" in sample.metadata for sample in samples): + train_data["raw_reward"] = [ + sample.metadata["raw_reward"] if sample.metadata and "raw_reward" in sample.metadata else sample.reward + for sample in samples + ] # For rollout buffer if samples[0].metadata and "round_number" in samples[0].metadata: @@ -78,7 +82,7 @@ def convert_samples_to_train_data(args: Any, samples: list[Sample] | list[list[S if samples[0].train_metadata is not None: train_data["metadata"] = [sample.train_metadata for sample in samples] - if samples[0].multimodal_train_inputs is not None: + if any(sample.multimodal_train_inputs is not None for sample in samples): train_data["multimodal_train_inputs"] = [sample.multimodal_train_inputs for sample in samples] if samples[0].teacher_log_probs is not None: diff --git a/scripts/models/llama3-8B.sh b/scripts/models/llama3-8B.sh deleted file mode 100644 index 8ca24fb1..00000000 --- a/scripts/models/llama3-8B.sh +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) 2026 Relax Authors. All Rights Reserved. - -MODEL_ARGS=( - --swiglu - --num-layers 32 - --hidden-size 4096 - --ffn-hidden-size 14336 - --num-attention-heads 32 - --group-query-attention - --num-query-groups 8 - --use-rotary-position-embeddings - --disable-bias-linear - --normalization "RMSNorm" - --norm-epsilon 1e-5 - --rotary-base "${MODEL_ARGS_ROTARY_BASE:-500000}" - --vocab-size 128256 - --kv-channels 128 - --untie-embeddings-and-output-weights - --seq-length "${MODEL_ARGS_SEQ_LENGTH:-8192}" -) diff --git a/scripts/models/qwen36-35B-A3B.sh b/scripts/models/qwen36-35B-A3B.sh new file mode 100644 index 00000000..84eb9046 --- /dev/null +++ b/scripts/models/qwen36-35B-A3B.sh @@ -0,0 +1,59 @@ +# Copyright (c) 2026 Relax Authors. All Rights Reserved. + +NLAYERS=40 +FIRST_K_DENSE_REPLACE=0 + +arr=() +for ((i=0; i", "", "", ""] -MODEL_VOCAB_SIZE = 256 -HIDDEN_SIZE = 128 -INTERMEDIATE_SIZE = 256 -NUM_HIDDEN_LAYERS = 2 -NUM_ATTENTION_HEADS = 4 -NUM_KEY_VALUE_HEADS = 2 -MAX_POSITION_EMBEDDINGS = 128 -PROMPT_LENGTH = 8 -RESPONSE_LENGTH = 4 -NUM_SAMPLES = 4 - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Create local tiny Qwen2 assets for Relax ROCm smoke tests.") - parser.add_argument("--output-dir", type=Path, required=True, help="Directory to write the generated assets into.") - parser.add_argument( - "--force", - action="store_true", - help="Overwrite existing assets in the output directory.", - ) - return parser.parse_args() - - -def build_vocab() -> dict[str, int]: - base_tokens = [f"tok_{i}" for i in range(MODEL_VOCAB_SIZE - len(SPECIAL_TOKENS))] - vocab_tokens = SPECIAL_TOKENS + base_tokens - return {token: idx for idx, token in enumerate(vocab_tokens)} - - -def create_tokenizer(tokenizer_dir: Path, *, force: bool) -> None: - if tokenizer_dir.exists() and not force: - return - - tokenizer_dir.mkdir(parents=True, exist_ok=True) - vocab = build_vocab() - - tokenizer = Tokenizer(WordLevel(vocab=vocab, unk_token="")) - tokenizer.pre_tokenizer = WhitespaceSplit() - - fast_tokenizer = PreTrainedTokenizerFast( - tokenizer_object=tokenizer, - unk_token="", - pad_token="", - bos_token="", - eos_token="", - ) - fast_tokenizer.model_max_length = MAX_POSITION_EMBEDDINGS - fast_tokenizer.save_pretrained(tokenizer_dir) - - generation_config = { - "bos_token_id": vocab[""], - "eos_token_id": vocab[""], - "pad_token_id": vocab[""], - "transformers_version": "auto", - } - (tokenizer_dir / "generation_config.json").write_text(json.dumps(generation_config, indent=2), encoding="utf-8") - - -def create_model(model_dir: Path, *, force: bool) -> None: - if model_dir.exists() and (model_dir / "config.json").exists() and not force: - return - - model_dir.mkdir(parents=True, exist_ok=True) - vocab = build_vocab() - config = Qwen2Config( - vocab_size=len(vocab), - hidden_size=HIDDEN_SIZE, - intermediate_size=INTERMEDIATE_SIZE, - num_hidden_layers=NUM_HIDDEN_LAYERS, - num_attention_heads=NUM_ATTENTION_HEADS, - num_key_value_heads=NUM_KEY_VALUE_HEADS, - max_position_embeddings=MAX_POSITION_EMBEDDINGS, - rms_norm_eps=1e-6, - rope_theta=10000.0, - tie_word_embeddings=False, - attention_bias=False, - bos_token_id=vocab[""], - eos_token_id=vocab[""], - pad_token_id=vocab[""], - torch_dtype="bfloat16", - ) - model = Qwen2ForCausalLM(config) - model.save_pretrained(model_dir, safe_serialization=True) - - -def build_sample(index: int) -> Sample: - prompt_tokens = [10 + index, 20 + index, 30 + index, 40 + index, 50 + index, 60 + index, 70 + index, 80 + index] - response_tokens = [100 + index, 110 + index, 120 + index, 130 + index] - tokens = prompt_tokens + response_tokens - reward = [1.0, 0.0, 0.8, -0.2][index] - - return Sample( - group_index=index // 2, - index=index, - prompt=f"prompt {index}", - response=f"response {index}", - tokens=tokens, - rollout_tokens=tokens.copy(), - response_length=RESPONSE_LENGTH, - reward=reward, - loss_mask=[1] * RESPONSE_LENGTH, - status=Sample.Status.COMPLETED, - label=f"label {index}", - metadata={"raw_reward": reward}, - train_metadata={"source": "tiny_qwen2_smoke"}, - ) - - -def create_debug_rollout(debug_dir: Path, *, force: bool) -> None: - target_path = debug_dir / "0.pt" - if target_path.exists() and not force: - return - - debug_dir.mkdir(parents=True, exist_ok=True) - samples = [build_sample(i).to_dict() for i in range(NUM_SAMPLES)] - torch.save({"rollout_id": 0, "samples": samples}, target_path) - - -def main() -> None: - args = parse_args() - output_dir = args.output_dir.resolve() - tokenizer_dir = output_dir / "hf_model" - debug_dir = output_dir / "debug_rollout" - - create_tokenizer(tokenizer_dir, force=args.force) - create_model(tokenizer_dir, force=args.force) - create_debug_rollout(debug_dir, force=args.force) - - manifest = { - "hf_model": str(tokenizer_dir), - "debug_rollout": str(debug_dir), - "num_samples": NUM_SAMPLES, - "prompt_length": PROMPT_LENGTH, - "response_length": RESPONSE_LENGTH, - "vocab_size": MODEL_VOCAB_SIZE, - } - (output_dir / "manifest.json").write_text(json.dumps(manifest, indent=2), encoding="utf-8") - print(json.dumps(manifest, indent=2)) - - -if __name__ == "__main__": - main() diff --git a/scripts/training/multimodal/run-qwen3-30B-A3B-omni-16xgpu-async.sh b/scripts/training/multimodal/run-qwen3-30B-A3B-omni-16xgpu-async.sh index 5940af28..c3a98abb 100644 --- a/scripts/training/multimodal/run-qwen3-30B-A3B-omni-16xgpu-async.sh +++ b/scripts/training/multimodal/run-qwen3-30B-A3B-omni-16xgpu-async.sh @@ -93,7 +93,7 @@ GRPO_ARGS=( --kl-loss-coef 0.001 --kl-loss-type low_var_kl --entropy-coef 0.00 - --eps-clip 3.0 + --eps-clip 0.2 --eps-clip-high 0.28 --use-tis ) diff --git a/scripts/training/multimodal/run-qwen3-30B-A3B-omni-16xgpu-video.sh b/scripts/training/multimodal/run-qwen3-30B-A3B-omni-16xgpu-video.sh index e67da4ac..019524b1 100644 --- a/scripts/training/multimodal/run-qwen3-30B-A3B-omni-16xgpu-video.sh +++ b/scripts/training/multimodal/run-qwen3-30B-A3B-omni-16xgpu-video.sh @@ -86,7 +86,7 @@ GRPO_ARGS=( --kl-loss-coef 0.001 --kl-loss-type low_var_kl --entropy-coef 0.00 - --eps-clip 3.0 + --eps-clip 0.2 --eps-clip-high 0.28 --use-tis ) diff --git a/scripts/training/multimodal/run-qwen3-30B-A3B-omni-16xgpu.sh b/scripts/training/multimodal/run-qwen3-30B-A3B-omni-16xgpu.sh index 68a5bd79..906469fe 100644 --- a/scripts/training/multimodal/run-qwen3-30B-A3B-omni-16xgpu.sh +++ b/scripts/training/multimodal/run-qwen3-30B-A3B-omni-16xgpu.sh @@ -50,6 +50,7 @@ ROLLOUT_ARGS=( # --rollout-max-prompt-len 2048 --rollout-temperature 0.8 --global-batch-size 512 + --use-streaming-dataset --balance-data --use-fault-tolerance --system-prompt "${SYSTEM_PROMPT}" @@ -88,7 +89,7 @@ GRPO_ARGS=( --kl-loss-coef 0.001 --kl-loss-type low_var_kl --entropy-coef 0.00 - --eps-clip 3.0 + --eps-clip 0.2 --eps-clip-high 0.28 --use-tis ) diff --git a/scripts/training/multimodal/run-qwen3-vl-4B-2xgpu.sh b/scripts/training/multimodal/run-qwen3-vl-4B-2xgpu.sh index 5a1c7b4d..5bd28d38 100644 --- a/scripts/training/multimodal/run-qwen3-vl-4B-2xgpu.sh +++ b/scripts/training/multimodal/run-qwen3-vl-4B-2xgpu.sh @@ -13,16 +13,9 @@ set -o pipefail now=$(date "+%Y-%m-%d-%H:%M:%S") echo "当前时间: $now" -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/../../entrypoint/device_env.sh" - -if [ -z "$(relax_visible_devices)" ]; then - SELECTED_GPUS="$(relax_select_top_gpus_by_free_mem 2)" - if [ -n "${SELECTED_GPUS}" ]; then - relax_export_visible_devices "${SELECTED_GPUS}" - fi -fi +export CUDA_VISIBLE_DEVICES=$(nvidia-smi --query-gpu=index,memory.free --format=csv,noheader,nounits | sort -t, -k2 -rn | head -n 2 | cut -d, -f1 | paste -sd ',') +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" # Auto-source local environment when not launched via an external entrypoint if [ -z "${RELAX_ENTRYPOINT_MODE:-}" ]; then source "${SCRIPT_DIR}/../../entrypoint/local.sh" @@ -100,6 +93,9 @@ OPTIMIZER_ARGS=( --adam-beta1 0.9 --adam-beta2 0.98 --clip-grad 1.0 + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer --no-rope-fusion ) diff --git a/scripts/training/multimodal/run-qwen3-vl-4B-8xgpu.sh b/scripts/training/multimodal/run-qwen3-vl-4B-8xgpu.sh index 4bf13f04..b0462bdd 100644 --- a/scripts/training/multimodal/run-qwen3-vl-4B-8xgpu.sh +++ b/scripts/training/multimodal/run-qwen3-vl-4B-8xgpu.sh @@ -53,10 +53,10 @@ ROLLOUT_ARGS=( ) PERF_ARGS=( - --tensor-model-parallel-size 4 + --tensor-model-parallel-size 2 --sequence-parallel --pipeline-model-parallel-size 1 - --context-parallel-size 1 + --context-parallel-size 4 --expert-model-parallel-size 1 --expert-tensor-parallel-size 1 @@ -64,7 +64,9 @@ PERF_ARGS=( --recompute-method uniform --recompute-num-layers 1 - #--micro-batch-size 16 # avoid OOM + --calculate-per-token-loss + # --micro-batch-size 16 + # --qkv-format bshd --use-dynamic-batch-size --max-tokens-per-gpu 9216 diff --git a/scripts/training/multimodal/run-qwen35-35B-A3B-8xgpu.sh b/scripts/training/multimodal/run-qwen35-35B-A3B-8xgpu.sh index e8e3b9df..be3592a6 100644 --- a/scripts/training/multimodal/run-qwen35-35B-A3B-8xgpu.sh +++ b/scripts/training/multimodal/run-qwen35-35B-A3B-8xgpu.sh @@ -46,6 +46,7 @@ ROLLOUT_ARGS=( --rollout-max-prompt-len 2048 --rollout-temperature 1 --global-batch-size 256 + --use-streaming-dataset --balance-data --use-fault-tolerance --system-prompt "${SYSTEM_PROMPT}" diff --git a/scripts/training/multimodal/run-qwen35-9B-8xgpu-openr1mm-async.sh b/scripts/training/multimodal/run-qwen35-9B-8xgpu-openr1mm-async.sh index 72c837bf..0561fa9d 100644 --- a/scripts/training/multimodal/run-qwen35-9B-8xgpu-openr1mm-async.sh +++ b/scripts/training/multimodal/run-qwen35-9B-8xgpu-openr1mm-async.sh @@ -63,10 +63,10 @@ ROLLOUT_ARGS=( ) PERF_ARGS=( - --tensor-model-parallel-size 4 + --tensor-model-parallel-size 2 --sequence-parallel --pipeline-model-parallel-size 1 - --context-parallel-size 1 + --context-parallel-size 2 --expert-model-parallel-size 1 --expert-tensor-parallel-size 1 @@ -74,12 +74,11 @@ PERF_ARGS=( --recompute-method uniform --recompute-num-layers 1 - # qwen3.5 only - --qkv-format bshd - --micro-batch-size 1 - #--micro-batch-size 16 # avoid OOM - # --use-dynamic-batch-size - --max-tokens-per-gpu 9216 + --calculate-per-token-loss + # --micro-batch-size 16 + # --qkv-format bshd + --use-dynamic-batch-size + --max-tokens-per-gpu 4096 --no-rope-fusion ) @@ -140,7 +139,7 @@ if [ ${MODE} = "async" ]; then --max-staleness 2 \ --num-data-storage-units 1 \ --num-iters-per-train-update 8 \ - --ref-actor-config '{"tensor_model_parallel_size": 1, "max_tokens_per_gpu": 16384, "sequence_parallel": false, "only_load_weight": true}' \ + --ref-actor-config '{"context_parallel_size": 1, "tensor_model_parallel_size": 1, "max_tokens_per_gpu": 16384, "sequence_parallel": false, "only_load_weight": true}' \ --fully-async \ --use-health-check \ "${MODEL_ARGS[@]}" \ diff --git a/scripts/training/multimodal/run-qwen35-9B-8xgpu-video.sh b/scripts/training/multimodal/run-qwen35-9B-8xgpu-video.sh index 505255e7..da2bcbd5 100755 --- a/scripts/training/multimodal/run-qwen35-9B-8xgpu-video.sh +++ b/scripts/training/multimodal/run-qwen35-9B-8xgpu-video.sh @@ -86,7 +86,7 @@ GRPO_ARGS=( --kl-loss-coef 0.001 --kl-loss-type low_var_kl --entropy-coef 0.00 - --eps-clip 3.0 + --eps-clip 0.2 --eps-clip-high 0.28 --use-tis ) @@ -101,7 +101,7 @@ OPTIMIZER_ARGS=( ) SGLANG_ARGS=( - --rollout-num-gpus-per-engine 2 + --rollout-num-gpus-per-engine 1 --sglang-mem-fraction-static 0.8 ) diff --git a/scripts/training/multimodal/run-qwen36-35B-A3B-8xgpu.sh b/scripts/training/multimodal/run-qwen36-35B-A3B-8xgpu.sh new file mode 100644 index 00000000..696c2a4c --- /dev/null +++ b/scripts/training/multimodal/run-qwen36-35B-A3B-8xgpu.sh @@ -0,0 +1,147 @@ +#!/bin/bash +# Copyright (c) 2026 Relax Authors. All Rights Reserved. +# +# Qwen3.6-35B-A3B 8xGPU colocate training script. +# +# Usage: +# bash scripts/training/multimodal/run-qwen36-35B-A3B-8xgpu.sh + +set -ex +set -o pipefail + +now=$(date "+%Y-%m-%d-%H:%M:%S") +echo "当前时间: $now" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +# Auto-source local environment when not launched via an external entrypoint +if [ -z "${RELAX_ENTRYPOINT_MODE:-}" ]; then + source "${SCRIPT_DIR}/../../entrypoint/local.sh" +fi +source "${MODEL_CONFIG_DIR}/qwen36-35B-A3B.sh" + +PROJECT_NAME="${PROJECT_NAME:=Relax/dev/openr1mm}" +EXP_DIR="${MODEL_DIR:=${SCRIPT_DIR}/../../../../exps}" +NUM_ROLLOUT="${NUM_ROLLOUT:=200}" + +CKPT_ARGS=( + --hf-checkpoint ${EXP_DIR}/Qwen3.6-35B-A3B + --ref-load ${EXP_DIR}/Qwen3.6-35B-A3B + --megatron-to-hf-mode bridge + + --load ${EXP_DIR}/save/Qwen3.6-35B_mcore_8xgpu/ + --save ${EXP_DIR}/save/Qwen3.6-35B_mcore_8xgpu/ + --max-actor-ckpt-to-keep 1 + --save-interval 100 +) + +PROMPT_SET="${EXP_DIR}/multimodal-open-r1-8k-verified/data/train-00000-of-00001_converted_noextract.parquet" +SYSTEM_PROMPT="A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here " + +ROLLOUT_ARGS=( + --prompt-data ${PROMPT_SET} + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type openr1mm + --num-rollout ${NUM_ROLLOUT} + --rollout-batch-size 32 + --n-samples-per-prompt 8 + --rollout-max-response-len 2048 + --rollout-max-prompt-len 2048 + --rollout-temperature 1 + --global-batch-size 256 + --use-streaming-dataset + --balance-data + --use-fault-tolerance + --system-prompt "${SYSTEM_PROMPT}" + --multimodal-keys '{"image":"image"}' + --no-rope-fusion +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 2 + --context-parallel-size 1 + --expert-model-parallel-size 4 + --expert-tensor-parallel-size 1 + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + # --qkv-format bshd + # --micro-batch-size 1 # avoid OOM + --use-dynamic-batch-size + --max-tokens-per-gpu 6144 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 + --use-tis +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer + + # NOTE(wuhuan): to avoid algorithm performance degradation + --moe-router-load-balancing-type "none" + --moe-aux-loss-coeff 0.0 +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 2 + --sglang-mem-fraction-static 0.8 +) + +WANDB_ARGS=( + --use-clearml + --use-metrics-service + --tb-project-name ${PROJECT_NAME} + --tb-experiment-name qwen36-35B-A3B-${now} +) + +MISC_ARGS=( + # default dropout in megatron is 0.1 + --attention-dropout 0.0 + --hidden-dropout 0.0 + # should be good for model performance + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + # need to comment this when using model with MLA + --attention-backend flash +) + +mkdir -p log + +ray job submit ${RAY_NO_WAIT:+--no-wait} --address="http://127.0.0.1:8265" \ + ${WORKING_DIR:+--working-dir "${WORKING_DIR}"} \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 -m relax.entrypoints.train \ + --resource '{"actor": [1, 8], "rollout": [1, 8]}'\ + --max-staleness 0 \ + --num-data-storage-units 1 \ + --colocate \ + "${MODEL_ARGS[@]}" \ + "${CKPT_ARGS[@]}" \ + "${ROLLOUT_ARGS[@]}" \ + "${OPTIMIZER_ARGS[@]}" \ + "${GRPO_ARGS[@]}" \ + "${WANDB_ARGS[@]}" \ + "${PERF_ARGS[@]}" \ + "${SGLANG_ARGS[@]}" \ + "${MISC_ARGS[@]}" 2>&1 | tee log/qwen36-35B-A3B-GRPO-gpu8-${now}.log diff --git a/scripts/training/text/run-llama3-8b-2xgpu-debug.sh b/scripts/training/text/run-llama3-8b-2xgpu-debug.sh deleted file mode 100644 index 4ce8cbb4..00000000 --- a/scripts/training/text/run-llama3-8b-2xgpu-debug.sh +++ /dev/null @@ -1,90 +0,0 @@ -#!/bin/bash - -# Copyright (c) 2026 Relax Authors. All Rights Reserved. -# -# Real-model 2xGPU training smoke for Meta-Llama-3-8B-Instruct on ROCm/CUDA hosts. - -set -ex -set -o pipefail - -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/../../entrypoint/device_env.sh" - -if [ -z "$(relax_visible_devices)" ]; then - SELECTED_GPUS="$(relax_select_top_gpus_by_free_mem 2)" - if [ -n "${SELECTED_GPUS}" ]; then - relax_export_visible_devices "${SELECTED_GPUS}" - fi -fi - -export RAY_TMPDIR="${RAY_TMPDIR:=/tmp/ray-relax-llama3-smoke-$$}" -export RELAX_SERVE_PORT="${RELAX_SERVE_PORT:=18080}" - -if [ -z "${RELAX_ENTRYPOINT_MODE:-}" ]; then - source "${SCRIPT_DIR}/../../entrypoint/local.sh" -fi -source "${MODEL_CONFIG_DIR}/llama3-8B.sh" - -DATE="$(date +%Y%m%d_%H%M%S)" -ASSET_DIR="${RELAX_SMOKE_ASSET_DIR:=/tmp/relax-real-llama3-smoke}" -WORK_DIR="${RELAX_SMOKE_WORK_DIR:=${ASSET_DIR}/run-${DATE}}" -REAL_HF_MODEL_DIR="${REAL_HF_MODEL_DIR:=/mnt/dcgpuval/models/meta-llama/Meta-Llama-3-8B-Instruct}" - -mkdir -p "${WORK_DIR}" log - -if [ ! -f "${REAL_HF_MODEL_DIR}/config.json" ]; then - echo "REAL_HF_MODEL_DIR does not point to a HuggingFace checkpoint: ${REAL_HF_MODEL_DIR}" >&2 - exit 1 -fi - -python3 "${SCRIPT_DIR}/../../tools/create_tiny_qwen2_smoke_assets.py" --output-dir "${ASSET_DIR}" -DEBUG_ROLLOUT_PATH="${ASSET_DIR}/debug_rollout/{rollout_id}.pt" - -TRAIN_ARGS=( - --debug-train-only - --resource '{"actor": [1, 2]}' - --num-gpus-per-node 2 - --actor-num-gpus-per-node 2 - --rollout-batch-size 2 - --n-samples-per-prompt 2 - --global-batch-size 4 - --micro-batch-size 1 - --num-rollout 1 - --load-debug-rollout-data "${DEBUG_ROLLOUT_PATH}" - --save-debug-train-data "${WORK_DIR}/train_dump/{rollout_id}_{rank}.pt" - --save "${WORK_DIR}/ckpt" - --save-interval 1 - --async-save - --hf-checkpoint "${REAL_HF_MODEL_DIR}" - --ref-load "${REAL_HF_MODEL_DIR}" - --megatron-to-hf-mode bridge - --load "${REAL_HF_MODEL_DIR}" - --model-name llama - --transformer-impl local - --advantage-estimator grpo - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - --weight-decay 0.0 - --adam-beta1 0.9 - --adam-beta2 0.95 - --tensor-model-parallel-size 1 - --pipeline-model-parallel-size 1 - --context-parallel-size 1 - --expert-model-parallel-size 1 - --expert-tensor-parallel-size 1 - --qkv-format bshd - --attention-backend unfused - --no-masked-softmax-fusion - --no-rope-fusion - --no-bias-swiglu-fusion - --no-bias-dropout-fusion - --train-memory-margin-bytes 0 -) - -ray job submit ${RAY_NO_WAIT:+--no-wait} --address="http://127.0.0.1:8265" \ - ${WORKING_DIR:+--working-dir "${WORKING_DIR}"} \ - --runtime-env-json="${RUNTIME_ENV_JSON}" \ - -- python3 -m relax.entrypoints.train \ - "${MODEL_ARGS[@]}" \ - "${TRAIN_ARGS[@]}" 2>&1 | tee "log/llama3-8b-real-debug-gpu2-${DATE}.log" diff --git a/scripts/training/text/run-qwen3-4B-8xgpu.sh b/scripts/training/text/run-qwen3-4B-8xgpu.sh index e2715c0a..c5d902f4 100644 --- a/scripts/training/text/run-qwen3-4B-8xgpu.sh +++ b/scripts/training/text/run-qwen3-4B-8xgpu.sh @@ -70,8 +70,8 @@ EVAL_ARGS=( PERF_ARGS=( --tensor-model-parallel-size 2 --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 1 + --pipeline-model-parallel-size 2 + --context-parallel-size 2 --expert-model-parallel-size 1 --expert-tensor-parallel-size 1 @@ -79,6 +79,7 @@ PERF_ARGS=( --recompute-method uniform --recompute-num-layers 1 + --calculate-per-token-loss #--micro-batch-size 16 # avoid OOM --use-dynamic-batch-size --max-tokens-per-gpu 9216 diff --git a/scripts/training/text/run-qwen35-35B-A3B-16xgpu.sh b/scripts/training/text/run-qwen35-35B-A3B-16xgpu.sh index 712bea15..95d19976 100755 --- a/scripts/training/text/run-qwen35-35B-A3B-16xgpu.sh +++ b/scripts/training/text/run-qwen35-35B-A3B-16xgpu.sh @@ -66,11 +66,11 @@ EVAL_ARGS=( ) PERF_ARGS=( - --tensor-model-parallel-size 2 + --tensor-model-parallel-size 4 --sequence-parallel --pipeline-model-parallel-size 2 --context-parallel-size 1 - --expert-model-parallel-size 4 + --expert-model-parallel-size 8 --expert-tensor-parallel-size 1 --recompute-granularity full diff --git a/scripts/training/text/run-qwen35-9B-8xgpu-async.sh b/scripts/training/text/run-qwen35-9B-8xgpu-async.sh index 071b9601..ece31dd6 100755 --- a/scripts/training/text/run-qwen35-9B-8xgpu-async.sh +++ b/scripts/training/text/run-qwen35-9B-8xgpu-async.sh @@ -77,10 +77,9 @@ PERF_ARGS=( --recompute-method uniform --recompute-num-layers 1 - # --use-dynamic-batch-size - # --max-tokens-per-gpu 10240 - --micro-batch-size 1 # avoid OOM - --qkv-format bshd + --use-dynamic-batch-size + --max-tokens-per-gpu 10240 + # --micro-batch-size 1 # avoid OOM --no-rope-fusion ) diff --git a/scripts/training/text/run-qwen3-4B-2xgpu.sh b/scripts/training/text/run-qwen35-9B-8xgpu.sh similarity index 61% rename from scripts/training/text/run-qwen3-4B-2xgpu.sh rename to scripts/training/text/run-qwen35-9B-8xgpu.sh index b3235400..d404fdba 100644 --- a/scripts/training/text/run-qwen3-4B-2xgpu.sh +++ b/scripts/training/text/run-qwen35-9B-8xgpu.sh @@ -2,14 +2,17 @@ # Copyright (c) 2026 Relax Authors. All Rights Reserved. # -# Qwen3-4B 2xGPU colocate training script. +# Qwen3.5-9B 8xGPU colocate (sync) training script for DAPO math dataset. # # Usage: -# NUM_GPUS=2 bash scripts/training/text/run-qwen3-4B-2xgpu.sh +# bash scripts/training/text/run-qwen35-9B-8xgpu.sh set -ex set -o pipefail +now=$(date "+%Y-%m-%d-%H:%M:%S") +echo "当前时间: $now" + SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" source "${SCRIPT_DIR}/../../entrypoint/device_env.sh" @@ -25,51 +28,55 @@ fi if [ -z "${RELAX_ENTRYPOINT_MODE:-}" ]; then source "${SCRIPT_DIR}/../../entrypoint/local.sh" fi -source "${MODEL_CONFIG_DIR}/qwen3-4B.sh" -# Support setting env from outside -EXP_DIR="${MODEL_DIR:=/root/exps}" +source "${MODEL_CONFIG_DIR}/qwen35-9B.sh" + PROJECT_NAME="${PROJECT_NAME:=Relax/dev/dapo-math}" -DATE=$(date +%Y%m%d_%H%M%S) -NUM_ROLLOUT="${NUM_ROLLOUT:=4}" +EXP_DIR="${MODEL_DIR:=${SCRIPT_DIR}/../../../../exps}" +NUM_ROLLOUT="${NUM_ROLLOUT:=1000}" CKPT_ARGS=( - --hf-checkpoint ${EXP_DIR}/Qwen3-4B/ - --ref-load ${EXP_DIR}/Qwen3-4B/ + --hf-checkpoint ${EXP_DIR}/Qwen3.5-9B + --ref-load ${EXP_DIR}/Qwen3.5-9B --megatron-to-hf-mode bridge - --load ${EXP_DIR}/Qwen3-4B_mcore/ - --save ${EXP_DIR}/Qwen3-4B_mcore/ - --save-interval 100 - --rotate-ckpt - --async-save + + --load ${EXP_DIR}/Qwen3-9B_mcore_8xgpu/ + --save ${EXP_DIR}/Qwen3-9B_mcore_8xgpu/ + --save-interval 50 + --max-actor-ckpt-to-keep 1 ) PROMPT_SET=${EXP_DIR}/dapo-math-17k/dapo-math-17k.jsonl ROLLOUT_ARGS=( - --use-streaming-dataset - --streaming-buffer-size 10000 --prompt-data ${PROMPT_SET} --input-key prompt --label-key label --apply-chat-template --rollout-shuffle - --rm-type dapo --reward-key score - --num-rollout ${NUM_ROLLOUT} - --rollout-batch-size 2 + --rollout-batch-size 32 --n-samples-per-prompt 8 - --rollout-max-response-len 2048 - --rollout-temperature 0.8 - - --global-batch-size 16 + --rollout-max-response-len 8192 + --rollout-temperature 1 + --global-batch-size 256 --balance-data --use-fault-tolerance ) +EVAL_ARGS=( + --log-passrate + --skip-eval-before-train + --eval-interval 20 + --eval-prompt-data aime ${EXP_DIR}/aime-2024/aime-2024.jsonl + --n-samples-per-eval-prompt 8 + --eval-max-response-len 8192 + --eval-top-p 0.7 +) + PERF_ARGS=( - --tensor-model-parallel-size 1 + --tensor-model-parallel-size 4 --sequence-parallel --pipeline-model-parallel-size 1 --context-parallel-size 1 @@ -80,9 +87,11 @@ PERF_ARGS=( --recompute-method uniform --recompute-num-layers 1 - # --micro-batch-size 1 - # --use-dynamic-batch-size - --max-tokens-per-gpu 9216 + --use-dynamic-batch-size + --max-tokens-per-gpu 10240 + # --micro-batch-size 1 # avoid OOM + + --no-rope-fusion ) GRPO_ARGS=( @@ -93,7 +102,6 @@ GRPO_ARGS=( --entropy-coef 0.00 --eps-clip 0.2 --eps-clip-high 0.28 - --use-tis ) @@ -104,22 +112,23 @@ OPTIMIZER_ARGS=( --weight-decay 0.1 --adam-beta1 0.9 --adam-beta2 0.98 + + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer ) SGLANG_ARGS=( - --rollout-num-gpus-per-engine 1 - --sglang-mem-fraction-static 0.7 + --rollout-num-gpus-per-engine 2 + --sglang-mem-fraction-static 0.8 + --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256) ) WANDB_ARGS=( --use-clearml --use-metrics-service --tb-project-name ${PROJECT_NAME} - --tb-experiment-name qwen3-4b-GRPO-gpu2-${DATE} - # --use-wandb - # --wandb-project slime-dev - # --wandb-group qwen3-4B-test - # --wandb-key ${WANDB_KEY} + --tb-experiment-name qwen35-9B-8x-${now} ) MISC_ARGS=( @@ -134,13 +143,15 @@ MISC_ARGS=( ) mkdir -p log -ray job submit ${RAY_NO_WAIT:+--no-wait} --address="http://127.0.0.1:8265" \ +ray job submit ${RAY_NO_WAIT:+--no-wait} --address="http://${HOST_IP}:8265" \ ${WORKING_DIR:+--working-dir "${WORKING_DIR}"} \ --runtime-env-json="${RUNTIME_ENV_JSON}" \ -- python3 -m relax.entrypoints.train \ - --resource '{"actor": [1, 1], "rollout": [1, 1]}'\ + --resource '{"actor": [1, 8], "rollout": [1, 8]}' \ --max-staleness 0 \ - --num-data-storage-units 1 \ + --num-data-storage-units 1 \ + --colocate \ + --use-health-check \ "${MODEL_ARGS[@]}" \ "${CKPT_ARGS[@]}" \ "${ROLLOUT_ARGS[@]}" \ @@ -148,5 +159,6 @@ ray job submit ${RAY_NO_WAIT:+--no-wait} --address="http://127.0.0.1:8265" \ "${GRPO_ARGS[@]}" \ "${WANDB_ARGS[@]}" \ "${PERF_ARGS[@]}" \ + "${EVAL_ARGS[@]}" \ "${SGLANG_ARGS[@]}" \ - "${MISC_ARGS[@]}" 2>&1 | tee log/qwen3-4b-GRPO-gpu2-${DATE}.log + "${MISC_ARGS[@]}" 2>&1 | tee log/qwen35-9B-GRPO-gpu8-${now}.log diff --git a/tests/distributed/checkpoint_service/test_dcs_weight_conversion.py b/tests/distributed/checkpoint_service/test_dcs_weight_conversion.py index ffa8f41d..4a583a50 100644 --- a/tests/distributed/checkpoint_service/test_dcs_weight_conversion.py +++ b/tests/distributed/checkpoint_service/test_dcs_weight_conversion.py @@ -37,6 +37,12 @@ ExpertMLPDownProjMapping, ExpertMLPGateUpProjMapping, ) +from megatron.bridge.models.qwen_vl.qwen35_vl_bridge import ( # noqa: E402 + ExpertMLPDownProjMapping as Qwen35ExpertMLPDownProjMapping, +) +from megatron.bridge.models.qwen_vl.qwen35_vl_bridge import ( # noqa: E402 + ExpertMLPGateUpProjMapping as Qwen35ExpertMLPGateUpProjMapping, +) from relax.backends.megatron.misc_utils import strip_param_name_prefix # noqa: E402 from relax.backends.megatron.weight_conversion.processors import quantize_params, remove_padding # noqa: E402 @@ -77,7 +83,14 @@ def _noop_gather(self_m, megatron_weights, megatron_module, hf_param_name): return {str(hf_param_name): megatron_weights} saved_originals: dict = {} - patched_classes = [MegatronParamMapping, GatedMLPMapping, ExpertMLPGateUpProjMapping, ExpertMLPDownProjMapping] + patched_classes = [ + MegatronParamMapping, + GatedMLPMapping, + ExpertMLPGateUpProjMapping, + ExpertMLPDownProjMapping, + Qwen35ExpertMLPGateUpProjMapping, + Qwen35ExpertMLPDownProjMapping, + ] for cls in patched_classes: if "gather_from_ep_ranks" in cls.__dict__: saved_originals[cls] = cls.__dict__["gather_from_ep_ranks"] @@ -113,15 +126,35 @@ def _make_expert_down_mapping(layer_idx: int, expert_id: int) -> ExpertMLPDownPr return m +def _make_qwen35_expert_gate_up_mapping(layer_idx: int, expert_id: int) -> Qwen35ExpertMLPGateUpProjMapping: + """Create a real Qwen3.5 ExpertMLPGateUpProjMapping for testing.""" + return Qwen35ExpertMLPGateUpProjMapping( + megatron_param=f"language_model.decoder.layers.{layer_idx}.mlp.experts.linear_fc1.weight{expert_id}", + hf_param=f"model.language_model.layers.{layer_idx}.mlp.experts.gate_up_proj", + ) + + +def _make_qwen35_expert_down_mapping(layer_idx: int, expert_id: int) -> Qwen35ExpertMLPDownProjMapping: + """Create a real Qwen3.5 ExpertMLPDownProjMapping with eagerly initialized + inner mapping.""" + m = Qwen35ExpertMLPDownProjMapping( + megatron_param=f"language_model.decoder.layers.{layer_idx}.mlp.experts.linear_fc2.weight{expert_id}", + hf_param=f"model.language_model.layers.{layer_idx}.mlp.experts.down_proj", + ) + m._detected_type = "replicated" + m._mapping = m._get_or_create_mapping("replicated") + return m + + def _apply_expert_postprocessing( converted_dict: Dict[str, torch.Tensor], megatron_param_name: str, + bridge_expert_transposes_down: bool = True, ) -> List[Tuple[str, torch.Tensor]]: """Apply the same expert weight post-processing as ``_convert_to_hf_bridge``. - This calls the real production logic extracted from device_direct.py lines - 399-420. + Mirrors the production logic in device_direct.py. """ converted_named_tensors = list(converted_dict.items()) expert_id_match = re.search(r"weight(\d+)", megatron_param_name) @@ -130,14 +163,22 @@ def _apply_expert_postprocessing( postprocessed: list[tuple[str, torch.Tensor]] = [] for hf_name, tensor in converted_named_tensors: if hf_name.endswith(".experts.gate_up_proj"): - gate_tensor = tensor[0].transpose(-1, -2).contiguous() - up_tensor = tensor[1].transpose(-1, -2).contiguous() base = hf_name[: -len(".gate_up_proj")] + if tensor.ndim == 3: + gate_tensor = tensor[0].transpose(-1, -2).contiguous() + up_tensor = tensor[1].transpose(-1, -2).contiguous() + else: + gate_tensor, up_tensor = tensor.chunk(2, dim=0) postprocessed.append((f"{base}.{expert_id}.gate_proj.weight", gate_tensor)) postprocessed.append((f"{base}.{expert_id}.up_proj.weight", up_tensor)) elif hf_name.endswith(".experts.down_proj"): base = hf_name[: -len(".down_proj")] - postprocessed.append((f"{base}.{expert_id}.down_proj.weight", tensor.transpose(-1, -2).contiguous())) + if tensor.ndim == 2 and not bridge_expert_transposes_down: + postprocessed.append((f"{base}.{expert_id}.down_proj.weight", tensor)) + else: + postprocessed.append( + (f"{base}.{expert_id}.down_proj.weight", tensor.transpose(-1, -2).contiguous()) + ) else: postprocessed.append((hf_name, tensor)) converted_named_tensors = postprocessed @@ -657,3 +698,88 @@ def test_element_count_preserved(self): postprocessed = _apply_expert_postprocessing(bridge_output, "decoder.layers.0.mlp.experts.linear_fc1.weight0") total_numel = sum(t.numel() for _, t in postprocessed) assert total_numel == original_numel + + +# ─── Tests for Qwen3.5 Bridge (2D cat, no transpose) ───────────────────────── + + +class TestQwen35BridgeMappingOutput: + """Test Qwen3.5 Bridge mapping output format (2D cat, no transpose).""" + + def test_qwen35_gate_up_outputs_2d_cat(self): + """Qwen3.5 ExpertMLPGateUpProjMapping outputs 2D [2*H, D] via cat.""" + with _patch_gather_from_ep_ranks(): + m = _make_qwen35_expert_gate_up_mapping(layer_idx=0, expert_id=3) + H, D = 768, 2048 + fused = torch.randn(H * 2, D) + result = m.megatron_to_hf(fused, None) + + key = "model.language_model.layers.0.mlp.experts.gate_up_proj" + assert list(result.keys()) == [key] + tensor = result[key] + assert tensor.ndim == 2 + assert tensor.shape == (H * 2, D) + + def test_qwen35_down_proj_no_transpose(self): + """Qwen3.5 ExpertMLPDownProjMapping does not transpose.""" + with _patch_gather_from_ep_ranks(): + m = _make_qwen35_expert_down_mapping(layer_idx=0, expert_id=3) + D, H = 2048, 768 + param = torch.randn(D, H) + result = m.megatron_to_hf(param, None) + + key = "model.language_model.layers.0.mlp.experts.down_proj" + assert list(result.keys()) == [key] + tensor = result[key] + assert tensor.shape == (D, H) + assert torch.allclose(tensor, param) + + def test_qwen35_expert_transposes_down_detection(self): + """Qwen3.5 ExpertMLPDownProjMapping lacks megatron_to_hf override.""" + assert "megatron_to_hf" not in Qwen35ExpertMLPDownProjMapping.__dict__ + assert "megatron_to_hf" in ExpertMLPDownProjMapping.__dict__ + + +class TestQwen35PostProcessingCorrectness: + """Verify Qwen3.5 Bridge output + post-processing produces correct HF + weights.""" + + def test_qwen35_gate_up_postprocessed(self): + """Qwen3.5 gate_up 2D + post-processing produces correct gate/up.""" + H, D = 768, 2048 + expert_id = 3 + megatron_param = torch.randn(H * 2, D) + expected_gate, expected_up = megatron_param.chunk(2, dim=0) + + with _patch_gather_from_ep_ranks(): + mapping = _make_qwen35_expert_gate_up_mapping(layer_idx=0, expert_id=expert_id) + bridge_output = mapping.megatron_to_hf(megatron_param, None) + + megatron_name = f"language_model.decoder.layers.0.mlp.experts.linear_fc1.weight{expert_id}" + postprocessed = _apply_expert_postprocessing(bridge_output, megatron_name, bridge_expert_transposes_down=False) + + assert len(postprocessed) == 2 + assert postprocessed[0][0].endswith(f".experts.{expert_id}.gate_proj.weight") + assert postprocessed[1][0].endswith(f".experts.{expert_id}.up_proj.weight") + assert postprocessed[0][1].shape == (H, D) + assert postprocessed[1][1].shape == (H, D) + assert torch.allclose(postprocessed[0][1], expected_gate) + assert torch.allclose(postprocessed[1][1], expected_up) + + def test_qwen35_down_proj_postprocessed(self): + """Qwen3.5 down_proj passthrough (no transpose undo).""" + D, H = 2048, 768 + expert_id = 5 + megatron_param = torch.randn(D, H) + + with _patch_gather_from_ep_ranks(): + mapping = _make_qwen35_expert_down_mapping(layer_idx=0, expert_id=expert_id) + bridge_output = mapping.megatron_to_hf(megatron_param, None) + + megatron_name = f"language_model.decoder.layers.0.mlp.experts.linear_fc2.weight{expert_id}" + postprocessed = _apply_expert_postprocessing(bridge_output, megatron_name, bridge_expert_transposes_down=False) + + assert len(postprocessed) == 1 + assert postprocessed[0][0].endswith(f".experts.{expert_id}.down_proj.weight") + assert postprocessed[0][1].shape == (D, H) + assert torch.allclose(postprocessed[0][1], megatron_param) diff --git a/tests/utils/test_device_utils.py b/tests/utils/test_device_utils.py deleted file mode 100644 index a8ed2c51..00000000 --- a/tests/utils/test_device_utils.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) 2026 Relax Authors. All Rights Reserved. - -from relax.utils.device_utils import get_visible_devices, get_visible_devices_env_name, to_local_visible_device_index - - -def test_prefers_cuda_visible_devices(): - env = { - "CUDA_VISIBLE_DEVICES": "4,5", - "ROCR_VISIBLE_DEVICES": "1,2", - } - - assert get_visible_devices_env_name(env) == "CUDA_VISIBLE_DEVICES" - assert get_visible_devices(env) == ["4", "5"] - - -def test_falls_back_to_rocr_visible_devices(): - env = { - "ROCR_VISIBLE_DEVICES": "6,7", - } - - assert get_visible_devices_env_name(env) == "ROCR_VISIBLE_DEVICES" - assert get_visible_devices(env) == ["6", "7"] - - -def test_maps_physical_id_to_local_index(): - env = { - "ROCR_VISIBLE_DEVICES": "4,6,7", - } - - assert to_local_visible_device_index(6, env) == 1 - - -def test_accepts_local_index_when_already_remapped(): - env = { - "HIP_VISIBLE_DEVICES": "4,6,7", - } - - assert to_local_visible_device_index(2, env) == 2 diff --git a/third_party/TransformerEngine b/third_party/TransformerEngine deleted file mode 160000 index 72aab8e3..00000000 --- a/third_party/TransformerEngine +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 72aab8e34e002369f25d3fba213406f4630f6c02