perf: optimize logits allgather and parallelize eagle3 input projection#295
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 50603a6bba
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
50603a6 to 2249284
Reviewed commit: 22492842c2
Reviewed commit: 2e011953b5
Reviewed commit: 50056a0d89
… skippable entry sync

Adds a new collective entry point ``all_gather_inner`` next to the existing ``all_gather``. Unlike the outer (which gathers along the token dim and stages input through ``state.comm_buff``), the inner:

* Concatenates along the **hidden** dim: each rank contributes a column slab ``(total_tokens, hidden_list_in_group[rank])`` and the kernel multimem-broadcasts it into ``state.comm_buff`` at this rank's column band. Result is ``(total_tokens, sum(hidden_list_in_group))``.
* Reads ``hidden_states`` directly via the kernel's input pointer; no staging copy into ``state.comm_buff``. (The local rank's slot is populated by ``multimem.st``-to-self.)
* Accepts ``skip_entry_sync: bool`` (compile-time constexpr in the kernel) that elides the entry CAS barrier when the caller can externally guarantee cross-rank synchronization since the last buffer read.
* Uses CAS-based ``blockwise_barrier`` for both barriers (entry conditional via SKIP_ENTRY_SYNC constexpr, exit always). No change to barrier semantics, signal pad sizing, or symm-mem handle ownership.
* Mirrors the outer's ``rsag_resize_hidden_if_needed`` resize trick so a state oversized vs the active ``total_hidden`` returns a contiguous slice instead of a strided view.

API parallel to the outer: ``tp_hidden_dim`` (auto even-split, refuses remainder distribution because per-rank slices must be 8-aligned) or ``hidden_list_in_group`` (explicit per-rank widths). Each list entry must be ``> 0`` and a multiple of 8 bf16 (16-byte multimem.st alignment). Input must be contiguous bf16 with a 16-byte-aligned ``data_ptr()``; ``state.hidden_dim`` must be a multiple of 8.

NVIDIA-only; AMD is intentionally unsupported on this path.

Acks two rounds of codex review covering input/state alignment, zero-width shard rejection, and the precise skip_entry_sync safety contract.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Yicheng Qiang <qiangyicheng@icloud.com>
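The shard-width rules in the commit message above (even split with no remainder, every width positive and a multiple of 8 bf16 elements) can be sketched in plain Python. This is only an illustration under stated assumptions: `resolve_hidden_list`, its parameters, and `ALIGN_ELEMS` are hypothetical names made up for this sketch, not the functions added by the PR.

```python
ALIGN_ELEMS = 8  # 8 bf16 elements = 16 bytes, the alignment quoted in the commit message


def resolve_hidden_list(total_hidden, world_size, hidden_list=None):
    """Hypothetical helper: return (per-rank widths, per-rank column offsets).

    With hidden_list=None the hidden dim is split evenly and a remainder is
    rejected; otherwise the explicit per-rank widths are validated.
    """
    if hidden_list is None:
        if total_hidden % world_size != 0:
            raise ValueError("even split leaves a remainder; pass explicit widths")
        hidden_list = [total_hidden // world_size] * world_size
    if len(hidden_list) != world_size:
        raise ValueError("need exactly one width per rank")
    for rank, width in enumerate(hidden_list):
        if width <= 0:
            raise ValueError(f"rank {rank}: zero-width or negative shard")
        if width % ALIGN_ELEMS != 0:
            raise ValueError(f"rank {rank}: width {width} is not a multiple of {ALIGN_ELEMS}")
    if sum(hidden_list) != total_hidden:
        raise ValueError("widths do not sum to total_hidden")

    # Each rank writes its slab into the column band [offset, offset + width).
    offsets, start = [], 0
    for width in hidden_list:
        offsets.append(start)
        start += width
    return list(hidden_list), offsets


widths, offsets = resolve_hidden_list(48, 2, hidden_list=[16, 32])
print(widths, offsets)
```

The offsets are the column bands each rank would fill in the gathered `(total_tokens, sum(hidden_list_in_group))` result; the real kernel's buffer layout may differ in details this sketch does not model.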
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
50056a0 to dbfcafe
Reviewed commit: dbfcafed6f
The timeout of PR Test / ut-runtime-1gpu / linux-mi355-1gpu-lightseek (pull_request) is a pre-existing issue: #217 (comment)

Merging.
Summary
Hidden-dim allgather fast path that beats NCCL on small token counts, wired through `ColumnParallelLinear` + logits projection, and EAGLE3 input projection parallelized.

- `all_gather_inner` (Triton, NVIDIA-only): multimem.st-based push allgather over the hidden axis, with `skip_entry_sync` constexpr to drop the entry CAS barrier when the caller has externally guaranteed peer drain.
- `TritonRSAGBackend.all_gather(dim=-1)` dispatches to `all_gather_inner` on NVIDIA; `AutoBackend` routes 2-D hidden-dim allgathers through it. Comm-backend signatures cleaned up: `rank` is now derived internally via `group.index(dist.get_rank())`.
- `LogitsProcessor` uses `all_gather_inner` for T ≤ 128, NCCL fallback above.
- `Eagle3MlaModel.fc` (DSv3) and `Eagle3LlamaModel.fc` (Llama) switched from `ReplicatedLinear`/`nn.Linear` to `ColumnParallelLinear(gather_output=True)`, sharding the fc weight across attn TP. Expected fc time: ~50 μs → ~22 μs; saves ~231 MB/rank at TP=4.
- `simple_all_gather` stack: `FusionOp.AG_VOCAB`, `allgather_vocab`, the TRT-LLM workspace plumbing, the CUDA kernel (`all_gather.cu` + flashinfer `simple_all_gather` namespace), Python bindings, and the kernel test.
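As a rough illustration of why a column-sharded `fc` followed by a hidden-dim gather is a drop-in replacement for the replicated layer, the sketch below uses plain Python lists in place of tensors. Every function name here (`matmul`, `split_columns`, `gather_last_dim`, `column_parallel_linear`, `choose_gather_path`) is hypothetical, the ranks are simulated in a single process, and the 128-token threshold is simply the figure quoted in the summary.

```python
def matmul(x, w):
    """x: (tokens, in_dim), w: (in_dim, out_dim), both nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]


def split_columns(w, widths):
    """Give each simulated rank a contiguous block of output columns."""
    shards, start = [], 0
    for width in widths:
        shards.append([row[start:start + width] for row in w])
        start += width
    return shards


def gather_last_dim(parts):
    """Concatenate per-rank (tokens, width_r) outputs along the hidden axis."""
    return [sum((part[t] for part in parts), []) for t in range(len(parts[0]))]


def column_parallel_linear(x, w, widths):
    # Each rank multiplies by its own column shard, then the slabs are gathered.
    return gather_last_dim([matmul(x, shard) for shard in split_columns(w, widths)])


def choose_gather_path(num_tokens, threshold=128):
    """Fast path for small token counts, NCCL above (threshold from the summary)."""
    return "all_gather_inner" if num_tokens <= threshold else "nccl"


x = [[1, 2], [3, 4]]
w = [[1, 0, 2, 1], [0, 1, 1, 3]]
print(column_parallel_linear(x, w, [2, 2]) == matmul(x, w))
print(choose_gather_path(64), choose_gather_path(512))
```

Each simulated rank holds only its own columns of the weight, which is where the per-rank memory saving comes from; the gather then restores the full output width.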