Skip to content

feat(engine): compute_log_probs API for RL sequence scoring (RL-plan M2)#321

Open
HJSang wants to merge 1 commit into
mainfrom
hejian/rl_api
Open

feat(engine): compute_log_probs API for RL sequence scoring (RL-plan M2)#321
HJSang wants to merge 1 commit into
mainfrom
hejian/rl_api

Conversation

@HJSang
Copy link
Copy Markdown
Collaborator

@HJSang HJSang commented May 30, 2026

Summary

Add compute_log_probs — the sequence-scoring primitive online-RL trainers (PPO, GRPO, any KL-penalised objective) need to score prompt + completion token sequences under the engine's current weights. Implements Milestone 2 of the RL plan.

out = engine.compute_log_probs(
    sequences=[{"prompt_token_ids": [...], "completion_token_ids": [...]}],
    temperature=1.0,
)
# out["log_probs"][i][j] == log P(completion_token_ids[i][j] | context)
# out["tokens"][i]       == completion_token_ids[i]

Approach

The API is a thin wrapper, but GPU validation revealed the engine did not actually compute prompt/input-token logprobsgenerate(return_logprob=True, logprob_start_len=N) always returned an empty input_token_logprobs (the extend_return_logprob extraction in logits_processor.py was never enabled, and the C++ scheduler had no logprob awareness). So this PR adds prompt-logprob support to the engine end-to-end, then layers the RL-facing API on top:

  1. C++ scheduler (additive)RequestSpec.logprob_start_lenRequest; per prefill-chunk extend_logprob_start_len = clamp(logprob_start_len − already_scheduled_len, 0, extend_len) (handles cached prefix + chunked prefill) exposed on FlatForwardOp. Recompiles via pip install tokenspeed-scheduler/.
  2. Executor — builds the LogitsMetadata input-logprob fields (extend_return_logprob, start/seq/pruned lens, shifted target ids) for pure-extend scoring batches, threads them via ForwardContext, and surfaces LogitsProcessorOutput.input_token_logprobs through forward_stepModelExecutionResult. The existing extraction math (logits_processor.py:242-403) is now reachable.
  3. Outputgeneration_output_processor attaches per-request input logprobs and fills BatchTokenIDOut (previously a hard-coded empty stub); downstream meta_info wiring already existed.
  4. APIEngine.compute_log_probs builds a forward-only generate call (return_logprob=True, logprob_start_len=len(prompt)−1) and keeps the first M returned logprobs (the engine appends one trailing sampled-position entry). The pure request-building / response-parsing logic lives in a GPU-free helper module covered by fast CPU unit tests.

Files changed

File Change
tokenspeed-scheduler/csrc/scheduler/{request_spec.h,request.h,request.cpp} logprob_start_len on RequestSpec/Request + accessor
tokenspeed-scheduler/csrc/scheduler/operations/forward.{h,cpp} per-chunk extend_logprob_start_len + SoA on FlatForwardOp
tokenspeed-scheduler/bindings/python_module.cpp nanobind exposure of the new fields
python/tokenspeed/runtime/engine/scheduler_utils.py, request_handler.py thread logprob_start_len into make_spec
python/tokenspeed/runtime/execution/{context.py,model_executor.py,cuda_graph_wrapper.py,types.py} build input-logprob metadata; surface input_token_logprobs through forward_step
python/tokenspeed/runtime/layers/logits_processor.py from_forward_context copies the new fields onto LogitsMetadata
python/tokenspeed/runtime/engine/generation_output_processor.py attach per-request input logprobs; fill BatchTokenIDOut
python/tokenspeed/runtime/engine/compute_log_probs.py (new) pure helpers: build_score_kwargs, extract_completion_logprobs, compute_log_probs_core
python/tokenspeed/runtime/entrypoints/{engine.py,engine_base.py} Engine.compute_log_probs(sequences, temperature=1.0) + abstract method
test/runtime/test_compute_log_probs.py (new) 14 CPU unit tests + 1 GPU-gated integration test
docs/guides/compute-log-probs.md (new) + sidebar usage doc

Validation (B200, from-branch image build)

Test Model Result
CPU unit tests 14 passed (PYTHONPATH=python pytest test/runtime/test_compute_log_probs.py)
End-to-end Qwen3-0.6B PASScompute_log_probs([{[1,2,3,4]→[5,6,7]},{[10,11]→[12]}]) → tokens [[5,6,7],[12]], valid logprobs
Determinism (score twice) Qwen3-0.6B PASS — max|Δ|=0.0 (bitwise)
Suffix/causal-invariance (append completion) Qwen3-0.6B PASS — max|Δ|=0.0 (earlier logprobs unchanged)
Cross-engine vs vLLM prompt_logprobs (same model + token ids, bf16) Qwen3-0.6B PASS — max|Δ|=0.073, mean=0.035 (within bf16 cross-engine tolerance)
TP-invariance tp1 vs tp2 (validates the TP logit all-gather) Qwen3-0.6B PASS — max|Δ|=0.032 (bf16 reduction-order)
MoE end-to-end (expert routing + gather) Qwen3-30B-A3B PASS — tokens align, valid, bitwise determinism

Determinism and suffix-invariance are bitwise; the cross-engine, TP, and MoE-vs-dense differences are all within expected bf16 numerical tolerance (~0.03–0.07) and track token-for-token, confirming correct off-by-one alignment, vocab indexing, normalization, the TP gather, and the MoE path.

Behavior / limits (v1)

  • log_probs[i][j] is the scalar log-prob of the realized completion token (gathered from the full distribution), not a per-vocab array.
  • Temperature: 1.0 only (raw log_softmax); other values raise NotImplementedError. Sampling-temperature scaling (for off-policy importance sampling) is a planned follow-up.
  • Speculative decoding: unsupported — raises a clear error.
  • Input logprobs are produced for pure-extend scoring batches (num_extends == bs); empty prompt / empty completion are rejected.

Intentionally out of scope (follow-ups)

  • HTTP / SMG exposure — the Engine Python method is the only surface here; a native /compute_log_probs route + SMG gRPC land later when there's a consumer.
  • Temperature-scaled and top-k / full-distribution logprobs.
  • Concurrent multi-request batch-invariance check (single-request, TP, MoE, and the vLLM oracle are covered).

🤖 Generated with Claude Code

@HJSang HJSang requested a review from a team as a code owner May 30, 2026 17:15
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ff41c39519

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +91 to +92
entries = meta_info.get("input_token_logprobs")
if not entries or len(entries) != num_completion:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Wire input logprobs before parsing them

This API depends on meta_info["input_token_logprobs"], but the current generation output path constructs every BatchTokenIDOut with input_token_logprobs_val=[] and input_token_logprobs_idx=[] (python/tokenspeed/runtime/engine/generation_output_processor.py:807-808), so convert_logprob_style can only attach an empty list. For any non-empty completion this branch therefore raises ValueError instead of returning scores, making Engine.compute_log_probs unusable until the prefill input logprobs are actually propagated.

Useful? React with 👍 / 👎.

@lightseek-bot lightseek-bot requested a review from qywu May 30, 2026 17:53
@HJSang
Copy link
Copy Markdown
Collaborator Author

HJSang commented May 30, 2026

Update: engine now computes prompt/input-token logprobs (validated on B200)

GPU validation revealed the original generate-based approach couldn't work: this build never produced input/prompt-token logprobsextend_return_logprob was never enabled and the C++ scheduler had no logprob awareness, so input_token_logprobs was always empty. Rather than work around it, this PR now adds prompt-logprob support to the engine end-to-end:

  • C++ scheduler: RequestSpec.logprob_start_lenRequest; per-prefill-chunk extend_logprob_start_len = clamp(logprob_start_len − already_scheduled_len, 0, extend_len) exposed on FlatForwardOp (additive, recompiles via pip install tokenspeed-scheduler/).
  • Executor: builds LogitsMetadata input-logprob fields for pure-extend scoring batches, threads them via ForwardContext, and surfaces LogitsProcessorOutput.input_token_logprobs through forward_stepModelExecutionResult.
  • Output: generation_output_processor attaches per-request input logprobs and fills BatchTokenIDOut (the prior empty stub); downstream meta_info wiring already existed.
  • compute_log_probs: logprob_start_len = len(prompt) − 1; keeps the first M entries (drops the trailing sampled-position entry).

Validation (B200, Qwen3-0.6B, from-branch image build)

  • Sweep confirmed alignment: logprob_start_len=len(prompt)−1 → tokens [5,6,7,-1], completion logprobs are the first M.
  • End-to-end: Engine.compute_log_probs([{prompt:[1,2,3,4],completion:[5,6,7]},{prompt:[10,11],completion:[12]}])tokens [[5,6,7],[12]], log_probs [[-2.41,-3.81,-7.46],[-8.76]], temperature guard fires. ✅
  • 14 CPU unit tests pass (fakes updated to mirror the verified one-position shift).

Engine prerequisites discovered along the way (documented): needs CUDA graphs (eager MHA decode is unsupported), --shm-size/--ipc=host.

Remaining (follow-up): broader invariance matrix (batch/order/TP), vLLM cross-check, MoE model, and temperature-scaled logprobs.

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 2e33dca92e

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +587 to +590
if (
req_input_lp_val is not None
and request_state.input_token_logprobs_val is not None
):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve input logprobs across prefill chunks

When a scoring request is chunked, non-final prefill forwards can already carry req_input_lp_val, but this new accumulation runs only after the existing prefill_finished guard. Those chunk results are therefore discarded instead of being stored on RequestState, so long prompt+completion sequences return only the final chunk's logprobs (or none) and compute_log_probs raises/incompletely scores. Accumulate the input logprobs before skipping chunked-prefill output.

Useful? React with 👍 / 👎.

Comment on lines +424 to +425
const std::int32_t rel = logprob_start_len - info.already_scheduled_len;
op.extend_logprob_start_len = (rel >= 0 && rel < info.extend_len) ? rel : -1;
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep collecting logprobs after the first chunk

For any prefill chunk that starts after logprob_start_len, rel is negative, so this sets extend_logprob_start_len to -1 and disables input-logprob collection for the rest of the sequence. A long prompt+completion whose requested logprob window spans multiple chunks will therefore omit all tokens in later chunks even though they are still part of the completion being scored. Chunks before the requested window should be skipped, but chunks after the start should collect from offset 0.

Useful? React with 👍 / 👎.

@HJSang
Copy link
Copy Markdown
Collaborator Author

HJSang commented May 30, 2026

Validation matrix (B200, from-branch image)

Test Model Result
Determinism (score twice) Qwen3-0.6B PASS — max|Δ|=0.0 (bitwise)
Suffix/causal-invariance (append completion) Qwen3-0.6B PASS — max|Δ|=0.0 (earlier logprobs unchanged, bitwise)
Cross-engine vs vLLM prompt_logprobs (same model+token ids, bf16) Qwen3-0.6B PASS — max|Δ|=0.073, mean=0.035 (within bf16 cross-engine tolerance)
TP-invariance (tp1 vs tp2, validates the TP logit all-gather) Qwen3-0.6B PASS — max|Δ|=0.032 (bf16 reduction-order)
MoE end-to-end (expert routing + gather) Qwen3-30B-A3B PASS — tokens align, valid, bitwise determinism

Determinism and suffix-invariance are bitwise; cross-engine, TP, and MoE-vs-dense differences are all within expected bf16 numerical tolerance (~0.03–0.07) and track token-for-token, confirming correct off-by-one, vocab indexing, normalization, and TP gather. Prompt-logprob support is validated across dense + MoE and single + multi-GPU.

@HJSang HJSang force-pushed the hejian/rl_api branch 2 times, most recently from a9443ee to c41af08 Compare June 1, 2026 02:01
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c41af085c3

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +516 to +517
if not all(s >= 0 for s in start_lens):
return
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep input-logprob collection per request

When a scoring request is batched with any normal prefill request, that normal request contributes -1 to extend_logprob_start_lens, so this batch-wide all(...) disables input-logprob computation for every request in the pure-extend batch. The scheduler can coalesce independent prefills, and the new docs say scoring can be interleaved with normal generation; in that case compute_log_probs receives no input_token_logprobs and raises instead of scoring the valid request. Handle non-negative windows per request rather than requiring the entire batch to request input logprobs.

Useful? React with 👍 / 👎.

Expose Engine.compute_log_probs end-to-end so RL workflows can score
prompt and completion tokens. Builds on generate: adds a request-builder
helper, parses completion logprobs from generate meta_info, and adds a
compute_log_probs_core orchestration layer.

Adds cross-language support for input/prompt-token logprobs by wiring the
C++ scheduler, executor, and output path through to the Python engine.

Handle long sequences across chunked prefill: collect the input-logprob
window from every prefill chunk it overlaps. Previously the C++ scheduler
set extend_logprob_start_len=-1 once a chunk started past the window start
(rel<0), and the output processor discarded non-final prefill chunks, so a
prompt+completion split across chunks returned too few completion logprobs.
Now later in-window chunks collect from offset 0 and their logprobs are
accumulated before the chunked-prefill guard.

Documents the compute_log_probs API (including the prefill-only scoring
engine config) and adds GPU-gated tests: an end-to-end test, a
chunked-vs-single-chunk equivalence test for long sequences, and coverage
for validate_sequence, negative temperature, and the dual temperature
gates.

Validated on B200 (Qwen2-1.5B-Instruct): chunked (chunked_prefill_size=128,
512-token prompt + 40-token completion) matches the single-chunk reference
within bf16 tolerance (max abs diff 0.039); full suite 16/16 pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Hejian Sang <sanghj0923@gmail.com>
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 585fc4a805

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +434 to +435
const std::int32_t rel = logprob_start_len - info.already_scheduled_len;
op.extend_logprob_start_len = (rel < info.extend_len) ? std::max(rel, 0) : -1;
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Recompute the predictor token when prefix cache hits

When a scoring request reuses a cached prompt prefix, the first scheduled prefill window can begin after logprob_start_len (the first-chunk transition sets window.begin from the prefix-cache match depth). This branch then treats rel < 0 as “collect from offset 0”, but the hidden state that scores the first uncached token is the previous token in the cached prefix and is not present in this forward. For example, after a prior request caches a page-aligned prompt, scoring the same prompt plus a completion starts forwarding at the completion, so the returned input_token_logprobs begin with the second completion token (or a trailing -1) and compute_log_probs silently returns misaligned scores. The scoring path needs to disable prefix-cache reuse for the requested window or force scheduling/recomputing the token at logprob_start_len.

Useful? React with 👍 / 👎.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant