Skip to content

perf: optimize logits allgather and parallelize eagle3 input projection#295

Merged
syuoni merged 11 commits into
lightseekorg:mainfrom
syuoni:perf/logits-allgather
May 29, 2026
Merged

perf: optimize logits allgather and parallelize eagle3 input projection#295
syuoni merged 11 commits into
lightseekorg:mainfrom
syuoni:perf/logits-allgather

Conversation

@syuoni
Copy link
Copy Markdown
Member

@syuoni syuoni commented May 28, 2026

Summary

Hidden-dim allgather fast path that beats NCCL on small token counts, wired through ColumnParallelLinear + logits projection, and EAGLE3 input projection parallelized.

  • New kernel: all_gather_inner (Triton, NVIDIA-only) — multimem.st-based push allgather over the hidden axis, with skip_entry_sync constexpr to drop the entry CAS barrier when the caller has externally guaranteed peer drain.
  • Backend wiring: TritonRSAGBackend.all_gather(dim=-1) dispatches to all_gather_inner on NVIDIA; AutoBackend routes 2-D hidden-dim allgathers through it. Comm-backend signatures cleaned up — rank is now derived internally via group.index(dist.get_rank()).
  • Logits: LogitsProcessor uses all_gather_inner for T ≤ 128, NCCL fallback above.
  • EAGLE3 fc: Eagle3MlaModel.fc (DSv3) and Eagle3LlamaModel.fc (Llama) switched from ReplicatedLinear / nn.Linear to ColumnParallelLinear(gather_output=True) — sharding the fc weight across attn TP. Expected fc time: ~50 μs → ~22 μs; saves ~231 MB/rank at TP=4.
  • Cleanup: deleted the now-dead simple_all_gather stack — FusionOp.AG_VOCAB, allgather_vocab, the TRT-LLM workspace plumbing, the CUDA kernel (all_gather.cu + flashinfer simple_all_gather namespace), Python bindings, and the kernel test.

@syuoni syuoni requested a review from a team as a code owner May 28, 2026 09:56
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 50603a6bba

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread python/tokenspeed/runtime/distributed/comm_ops.py
Comment thread python/tokenspeed/runtime/distributed/comm_backend/auto.py
Comment thread python/tokenspeed/runtime/layers/logits_processor.py
@syuoni syuoni force-pushed the perf/logits-allgather branch from 50603a6 to 2249284 Compare May 28, 2026 10:02
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 22492842c2

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread python/tokenspeed/runtime/layers/logits_processor.py Outdated
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 2e011953b5

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread python/tokenspeed/runtime/distributed/comm_ops.py
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 50056a0d89

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread python/tokenspeed/runtime/distributed/comm_backend/triton_rsag.py
qiangyicheng and others added 11 commits May 29, 2026 05:32
… skippable entry sync

Adds a new collective entry point ``all_gather_inner`` next to the existing
``all_gather``. Unlike the outer (which gathers along the token dim and
stages input through ``state.comm_buff``), the inner:

* Concatenates along the **hidden** dim — each rank contributes a column
  slab ``(total_tokens, hidden_list_in_group[rank])`` and the kernel
  multimem-broadcasts it into ``state.comm_buff`` at this rank's column
  band. Result is ``(total_tokens, sum(hidden_list_in_group))``.
* Reads ``hidden_states`` directly via the kernel's input pointer; no
  staging copy into ``state.comm_buff``. (The local rank's slot is
  populated by ``multimem.st``-to-self.)
* Accepts ``skip_entry_sync: bool`` (compile-time constexpr in the
  kernel) that elides the entry CAS barrier when the caller can
  externally guarantee cross-rank synchronization since the last
  buffer read.
* Uses CAS-based ``blockwise_barrier`` for both barriers (entry
  conditional via SKIP_ENTRY_SYNC constexpr, exit always). No change
  to barrier semantics, signal pad sizing, or symm-mem handle
  ownership.
* Mirrors the outer's ``rsag_resize_hidden_if_needed`` resize trick so
  a state oversized vs the active ``total_hidden`` returns a
  contiguous slice instead of a strided view.

API parallel to the outer: ``tp_hidden_dim`` (auto even-split, refuses
remainder distribution because per-rank slices must be 8-aligned) or
``hidden_list_in_group`` (explicit per-rank widths). Each list entry
must be ``> 0`` and a multiple of 8 bf16 (16-byte multimem.st
alignment). Input must be contiguous bf16 with a 16-byte-aligned
``data_ptr()``; ``state.hidden_dim`` must be a multiple of 8.

NVIDIA-only; AMD is intentionally unsupported on this path.

Acks two rounds of codex review covering input/state alignment,
zero-width shard rejection, and the precise skip_entry_sync safety
contract.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Yicheng Qiang <qiangyicheng@icloud.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
@syuoni syuoni force-pushed the perf/logits-allgather branch from 50056a0 to dbfcafe Compare May 29, 2026 06:15
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: dbfcafed6f

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread python/tokenspeed/runtime/layers/logits_processor.py
@syuoni
Copy link
Copy Markdown
Member Author

syuoni commented May 29, 2026

The timeout of PR Test / ut-runtime-1gpu / linux-mi355-1gpu-lightseek (pull_request) is a pre-existing issue: #217 (comment)

Merging.

@syuoni syuoni merged commit 5241fd9 into lightseekorg:main May 29, 2026
108 of 122 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants