Skip to content

[fsdp, megatron, trainer] fix: enhance mem footprint for forward_kl_topk OPD#6848

Merged
Luosuu merged 2 commits into
verl-project:mainfrom
dimjava:feature/reduce-opd-mem
Jun 30, 2026
Merged

[fsdp, megatron, trainer] fix: enhance mem footprint for forward_kl_topk OPD#6848
Luosuu merged 2 commits into
verl-project:mainfrom
dimjava:feature/reduce-opd-mem

Conversation

@dimjava

@dimjava dimjava commented Jun 25, 2026

Copy link
Copy Markdown
Contributor

What does this PR do?

Reduce VRAM and compute for OPD (forward_kl_topk + use_policy_gradient=False + use_task_rewards=False) by skipping redundant full-vocab log_probs and PPO-loss work when only the top-k distillation loss is needed.

Introduces a distillation_only flag (set in the trainer, consumed by actor engines) and scopes empty_cache() on Megatron to top-k distillation steps.

Target config: distillation.distillation_loss.loss_mode=forward_kl_topk with supervised distillation

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, veomni, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, fully_async, one_step_off
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

pytest tests/workers/test_megatron_distillation_only_on_cpu.py -q
pytest tests/workers/test_distillation_topk_symmetry_on_cpu.py -q

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

@CLAassistant

CLAassistant commented Jun 25, 2026

Copy link
Copy Markdown

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.


Dmitrii Choklia seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.

@dimjava dimjava changed the title [opd] enhance mem footprint for topk OPD [opd] enhance mem footprint for forward_kl_topk OPD Jun 25, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a distillation_only flag to optimize memory usage and skip policy loss calculations during supervised top-k distillation. The changes span the PPO trainers, distillation loss calculations, and both FSDP and Megatron transformer engine implementations to conditionally omit log_probs computation. A potential issue was identified in the FSDP engine implementation (verl/workers/engine/fsdp/transformer_impl.py), where a destructive in-place modification of logits during log-probability calculation could corrupt the logits before they are processed by the distillation logits processor.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread verl/workers/engine/fsdp/transformer_impl.py Outdated
@dimjava dimjava changed the title [opd] enhance mem footprint for forward_kl_topk OPD [opd, fsdp, megatron] enhance mem footprint for forward_kl_topk OPD Jun 25, 2026
@dimjava dimjava changed the title [opd, fsdp, megatron] enhance mem footprint for forward_kl_topk OPD [opd, fsdp, megatron] feat: enhance mem footprint for forward_kl_topk OPD Jun 25, 2026
@dimjava dimjava changed the title [opd, fsdp, megatron] feat: enhance mem footprint for forward_kl_topk OPD [opd, fsdp, megatron] fix: enhance mem footprint for forward_kl_topk OPD Jun 25, 2026
@dimjava

dimjava commented Jun 25, 2026

Copy link
Copy Markdown
Contributor Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a "distillation_only" mode to optimize memory footprint and computation when performing supervised top-k distillation without policy gradients or task rewards. It skips the calculation, gathering, and propagation of log_probs across FSDP and Megatron engines, and adds corresponding unit tests to verify this behavior. Additionally, it empties the PyTorch cache before the Megatron optimizer step when top-k distillation is active to mitigate potential OOM issues.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

@dimjava dimjava force-pushed the feature/reduce-opd-mem branch from 7482c90 to e18691f Compare June 29, 2026 08:06
@dimjava

dimjava commented Jun 29, 2026

Copy link
Copy Markdown
Contributor Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a distillation_only mode to skip policy loss computation and log probability calculation when performing supervised top-k distillation without task rewards or policy gradients. This optimization reduces memory footprint and prevents potential out-of-memory (OOM) errors, particularly on tight VRAM setups. Key changes include updating the FSDP and Megatron transformer engines to conditionally bypass log probability computation, clearing the PyTorch cache before optimizer steps in Megatron when top-k distillation is active, and adding corresponding unit tests to verify the new behavior. I have no feedback to provide as there are no review comments.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

@dimjava

dimjava commented Jun 29, 2026

Copy link
Copy Markdown
Contributor Author

@wucong25 could you pls approve for CI to run ?

@wuxibin89 wuxibin89 requested a review from Luosuu June 29, 2026 09:06

@Luosuu Luosuu left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks

@Luosuu

Luosuu commented Jun 30, 2026

Copy link
Copy Markdown
Collaborator

@dimjava please fix CPU tests then feel free to @ me to re-trigger CI for you

@dimjava dimjava changed the title [opd, fsdp, megatron] fix: enhance mem footprint for forward_kl_topk OPD [fsdp, megatron, trainer] fix: enhance mem footprint for forward_kl_topk OPD Jun 30, 2026
@dimjava

dimjava commented Jun 30, 2026

Copy link
Copy Markdown
Contributor Author

@dimjava please fix CPU tests then feel free to @ me to re-trigger CI for you

It seems there is a problem with pypi index, some of already merged MRs into main failed as well

ERROR: Could not find a version that satisfies the requirement TransferQueue==0.1.8 (from versions: none)

@dimjava dimjava force-pushed the feature/reduce-opd-mem branch from 88908d6 to 1625f3d Compare June 30, 2026 11:46
@dimjava

dimjava commented Jun 30, 2026

Copy link
Copy Markdown
Contributor Author

@Luosuu pls run CI once again

@Luosuu Luosuu merged commit 1e3b6ff into verl-project:main Jun 30, 2026
94 of 124 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants