feat(engine): compute_log_probs API for RL sequence scoring (RL-plan M2) #321
HJSang wants to merge 1 commit into
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: ff41c39519
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
    entries = meta_info.get("input_token_logprobs")
    if not entries or len(entries) != num_completion:
Wire input logprobs before parsing them
This API depends on meta_info["input_token_logprobs"], but the current generation output path constructs every BatchTokenIDOut with input_token_logprobs_val=[] and input_token_logprobs_idx=[] (python/tokenspeed/runtime/engine/generation_output_processor.py:807-808), so convert_logprob_style can only attach an empty list. For any non-empty completion this branch therefore raises ValueError instead of returning scores, making Engine.compute_log_probs unusable until the prefill input logprobs are actually propagated.
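A minimal sketch of the parse step this comment describes, assuming each entry in meta_info["input_token_logprobs"] is a (logprob, token_id) pair. The entry shape and the helper name parse_completion_logprobs are assumptions for illustration, not the PR's code. It shows why an empty list from the engine surfaces as a ValueError rather than as scores.

```python
from typing import Any, Dict, List, Tuple


def parse_completion_logprobs(
    meta_info: Dict[str, Any], num_completion: int
) -> List[float]:
    """Return one logprob per completion token, or raise if the count is wrong.

    Hypothetical helper: assumes entries are (logprob, token_id) pairs.
    """
    entries: List[Tuple[float, int]] = meta_info.get("input_token_logprobs") or []
    if len(entries) != num_completion:
        raise ValueError(
            f"expected {num_completion} input logprobs, got {len(entries)}"
        )
    return [float(logprob) for logprob, _token_id in entries]


# The engine populated the field: scores come back in completion order.
ok = parse_completion_logprobs(
    {"input_token_logprobs": [(-0.5, 5), (-1.25, 6), (-0.1, 7)]}, num_completion=3
)

# The field is an empty list, as the review describes: parsing fails loudly.
try:
    parse_completion_logprobs({"input_token_logprobs": []}, num_completion=3)
    failed = False
except ValueError:
    failed = True
```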
Update: engine now computes prompt/input-token logprobs (validated on B200)
GPU validation revealed the original
Validation (B200, Qwen3-0.6B, from-branch image build)
Engine prerequisites discovered along the way (documented): needs CUDA graphs (eager MHA decode is unsupported),
Remaining (follow-up): broader invariance matrix (batch/order/TP), vLLM cross-check, MoE model, and temperature-scaled logprobs.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 2e33dca92e
    if (
        req_input_lp_val is not None
        and request_state.input_token_logprobs_val is not None
    ):
Preserve input logprobs across prefill chunks
When a scoring request is chunked, non-final prefill forwards can already carry req_input_lp_val, but this new accumulation runs only after the existing prefill_finished guard. Those chunk results are therefore discarded instead of being stored on RequestState, so long prompt+completion sequences return only the final chunk's logprobs (or none) and compute_log_probs raises/incompletely scores. Accumulate the input logprobs before skipping chunked-prefill output.
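The ordering this comment asks for can be sketched as follows. RequestState here is a hypothetical stand-in with a single field, and on_prefill_chunk is an illustrative name. The real output processor carries far more state, so treat this only as a picture of "accumulate first, then apply the guard".

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class RequestState:
    """Hypothetical stand-in for the engine's per-request state."""

    input_token_logprobs_val: List[float] = field(default_factory=list)


def on_prefill_chunk(
    state: RequestState,
    chunk_logprobs: Optional[List[float]],
    prefill_finished: bool,
) -> Optional[List[float]]:
    """Store this chunk's input logprobs, then apply the chunked-prefill guard."""
    # Accumulate first, so non-final chunks are not dropped by the guard below.
    if chunk_logprobs is not None:
        state.input_token_logprobs_val.extend(chunk_logprobs)
    # Guard: non-final chunks produce no output for the caller.
    if not prefill_finished:
        return None
    return list(state.input_token_logprobs_val)


state = RequestState()
first = on_prefill_chunk(state, [-0.2, -0.4], prefill_finished=False)
final = on_prefill_chunk(state, [-0.6], prefill_finished=True)
```

With the accumulation placed after the guard instead, the first call would return early and its two values would never reach the state.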
    const std::int32_t rel = logprob_start_len - info.already_scheduled_len;
    op.extend_logprob_start_len = (rel >= 0 && rel < info.extend_len) ? rel : -1;
Keep collecting logprobs after the first chunk
For any prefill chunk that starts after logprob_start_len, rel is negative, so this sets extend_logprob_start_len to -1 and disables input-logprob collection for the rest of the sequence. A long prompt+completion whose requested logprob window spans multiple chunks will therefore omit all tokens in later chunks even though they are still part of the completion being scored. Chunks before the requested window should be skipped, but chunks after the start should collect from offset 0.
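To make the per-chunk arithmetic concrete, here is a Python transcription of the reviewed expression next to the behaviour the comment asks for. The function names and the chunk_plan helper are hypothetical, and the sequence sizes are illustrative numbers rather than engine output.

```python
from typing import List, Tuple


def start_offset_buggy(logprob_start_len: int, scheduled: int, extend_len: int) -> int:
    """Mirrors the reviewed expression: any chunk with rel < 0 is switched off."""
    rel = logprob_start_len - scheduled
    return rel if 0 <= rel < extend_len else -1


def start_offset_fixed(logprob_start_len: int, scheduled: int, extend_len: int) -> int:
    """Chunks before the window are skipped (-1); chunks after its start use 0."""
    rel = logprob_start_len - scheduled
    return max(rel, 0) if rel < extend_len else -1


def chunk_plan(total_len: int, chunk_size: int) -> List[Tuple[int, int]]:
    """Split a sequence into (already_scheduled_len, extend_len) chunks."""
    return [
        (start, min(chunk_size, total_len - start))
        for start in range(0, total_len, chunk_size)
    ]


# Illustrative numbers: 512-token prompt + 40-token completion, 128-token chunks,
# window starting at the last prompt token (index 511).
chunks = chunk_plan(total_len=552, chunk_size=128)
buggy = [start_offset_buggy(511, s, n) for s, n in chunks]
fixed = [start_offset_fixed(511, s, n) for s, n in chunks]
```

The two lists differ only in the last chunk, which begins at token 512: the reviewed expression turns collection off there, while the suggested one collects from offset 0.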
Validation matrix (B200, from-branch image)
Determinism and suffix-invariance are bitwise; cross-engine, TP, and MoE-vs-dense differences are all within expected bf16 numerical tolerance (~0.03–0.07) and track token-for-token, confirming correct off-by-one, vocab indexing, normalization, and TP gather. Prompt-logprob support is validated across dense + MoE and single + multi-GPU.
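As a reminder of what "off-by-one, vocab indexing, normalization" refer to, here is a small pure-Python sketch of scoring a completion from per-position logits. The function names and the toy logits are hypothetical; the engine does this on GPU tensors, so this only illustrates the indexing.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary row."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_norm for x in logits]


def score_completion(
    logits: List[List[float]], token_ids: List[int], prompt_len: int
) -> List[float]:
    """Gather the logprob of each completion token from the row that predicts it.

    Row t holds the logits produced after reading token t, so it scores
    token t + 1 (the off-by-one alignment).
    """
    scores = []
    for pos in range(prompt_len, len(token_ids)):
        row = log_softmax(logits[pos - 1])  # normalization over the vocabulary
        scores.append(row[token_ids[pos]])  # vocab indexing by the realized token
    return scores


# Toy 3-token vocabulary, prompt [0, 1], completion [2].
toy_logits = [[0.0, 0.0, 0.0], [0.0, 0.0, math.log(2.0)], [1.0, 2.0, 3.0]]
scores = score_completion(toy_logits, token_ids=[0, 1, 2], prompt_len=2)
```

The single completion token is scored from the row at index 1 (the last prompt position), not from its own row at index 2.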
Force-pushed from a9443ee to c41af08.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: c41af085c3
    if not all(s >= 0 for s in start_lens):
        return
Keep input-logprob collection per request
When a scoring request is batched with any normal prefill request, that normal request contributes -1 to extend_logprob_start_lens, so this batch-wide all(...) disables input-logprob computation for every request in the pure-extend batch. The scheduler can coalesce independent prefills, and the new docs say scoring can be interleaved with normal generation; in that case compute_log_probs receives no input_token_logprobs and raises instead of scoring the valid request. Handle non-negative windows per request rather than requiring the entire batch to request input logprobs.
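A rough sketch of the per-request handling suggested here, contrasted with the batch-wide gate. Both function names are hypothetical, and the convention that -1 marks "no input logprobs requested" is taken from the comment above.

```python
from typing import Dict, List


def requests_to_score(extend_logprob_start_lens: List[int]) -> Dict[int, int]:
    """Map batch index -> window start for requests that asked for input logprobs.

    A start of -1 marks a normal prefill request that wants no input logprobs.
    """
    return {
        idx: start
        for idx, start in enumerate(extend_logprob_start_lens)
        if start >= 0
    }


def batch_wide_gate(extend_logprob_start_lens: List[int]) -> bool:
    """The reviewed behaviour: one -1 disables scoring for the whole batch."""
    return all(s >= 0 for s in extend_logprob_start_lens)


# A scoring request (start 3) coalesced with a normal prefill (-1).
mixed = [3, -1]
selected = requests_to_score(mixed)
gate = batch_wide_gate(mixed)
```

In the mixed batch the gate is closed, yet the per-request selection still finds the one request that needs scoring.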
Expose Engine.compute_log_probs end-to-end so RL workflows can score prompt and completion tokens.

Builds on generate: adds a request-builder helper, parses completion logprobs from generate meta_info, and adds a compute_log_probs_core orchestration layer. Adds cross-language support for input/prompt-token logprobs by wiring the C++ scheduler, executor, and output path through to the Python engine.

Handle long sequences across chunked prefill: collect the input-logprob window from every prefill chunk it overlaps. Previously the C++ scheduler set extend_logprob_start_len=-1 once a chunk started past the window start (rel<0), and the output processor discarded non-final prefill chunks, so a prompt+completion split across chunks returned too few completion logprobs. Now later in-window chunks collect from offset 0 and their logprobs are accumulated before the chunked-prefill guard.

Documents the compute_log_probs API (including the prefill-only scoring engine config) and adds GPU-gated tests: an end-to-end test, a chunked-vs-single-chunk equivalence test for long sequences, and coverage for validate_sequence, negative temperature, and the dual temperature gates.

Validated on B200 (Qwen2-1.5B-Instruct): chunked (chunked_prefill_size=128, 512-token prompt + 40-token completion) matches the single-chunk reference within bf16 tolerance (max abs diff 0.039); full suite 16/16 pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Hejian Sang <sanghj0923@gmail.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 585fc4a805
    const std::int32_t rel = logprob_start_len - info.already_scheduled_len;
    op.extend_logprob_start_len = (rel < info.extend_len) ? std::max(rel, 0) : -1;
Recompute the predictor token when prefix cache hits
When a scoring request reuses a cached prompt prefix, the first scheduled prefill window can begin after logprob_start_len (the first-chunk transition sets window.begin from the prefix-cache match depth). This branch then treats rel < 0 as “collect from offset 0”, but the hidden state that scores the first uncached token is the previous token in the cached prefix and is not present in this forward. For example, after a prior request caches a page-aligned prompt, scoring the same prompt plus a completion starts forwarding at the completion, so the returned input_token_logprobs begin with the second completion token (or a trailing -1) and compute_log_probs silently returns misaligned scores. The scoring path needs to disable prefix-cache reuse for the requested window or force scheduling/recomputing the token at logprob_start_len.
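One way to picture the second remedy (forcing the token at logprob_start_len to be recomputed) is to cap how much of the cached prefix a scoring request may reuse. The sketch below is only an assumed policy: capped_prefix_reuse, first_offset, and the whole-page rounding are hypothetical and are not taken from the scheduler.

```python
def capped_prefix_reuse(
    matched_prefix_len: int, logprob_start_len: int, page_size: int
) -> int:
    """Limit cache reuse so the token at logprob_start_len is recomputed.

    Hypothetical policy: reuse at most logprob_start_len cached tokens, rounded
    down to a whole page, so the forward pass starts at or before the window.
    """
    usable = min(matched_prefix_len, logprob_start_len)
    return (usable // page_size) * page_size


def first_offset(logprob_start_len: int, forward_begin: int, extend_len: int) -> int:
    """Window offset inside the first forward; -1 means the window is not reached."""
    rel = logprob_start_len - forward_begin
    if rel < 0:
        # The hidden state that scores the first uncached token lives in the cache.
        raise ValueError("predictor token is cached, not in this forward")
    return rel if rel < extend_len else -1


# Illustrative case: a 64-token prompt is fully cached (page size 16) and is
# rescored with an 8-token completion, so the window starts at index 63.
prompt_len, completion_len, page_size = 64, 8, 16
logprob_start_len = prompt_len - 1
begin = capped_prefix_reuse(prompt_len, logprob_start_len, page_size)
offset = first_offset(logprob_start_len, begin, prompt_len + completion_len - begin)
```

Without the cap the forward would begin at token 64, one past the predictor token, which is the misalignment described above; the cost of the cap is recomputing the last cached page.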
Summary
Add compute_log_probs, the sequence-scoring primitive online-RL trainers (PPO, GRPO, any KL-penalised objective) need to score prompt + completion token sequences under the engine's current weights. Implements Milestone 2 of the RL plan.
Approach
The API is a thin wrapper, but GPU validation revealed the engine did not actually compute prompt/input-token logprobs: generate(return_logprob=True, logprob_start_len=N) always returned an empty input_token_logprobs (the extend_return_logprob extraction in logits_processor.py was never enabled, and the C++ scheduler had no logprob awareness). So this PR adds prompt-logprob support to the engine end-to-end, then layers the RL-facing API on top:
- RequestSpec.logprob_start_len → Request; per prefill-chunk extend_logprob_start_len = clamp(logprob_start_len − already_scheduled_len, 0, extend_len) (handles cached prefix + chunked prefill) exposed on FlatForwardOp. Recompiles via pip install tokenspeed-scheduler/.
- LogitsMetadata input-logprob fields (extend_return_logprob, start/seq/pruned lens, shifted target ids) for pure-extend scoring batches, threads them via ForwardContext, and surfaces LogitsProcessorOutput.input_token_logprobs through forward_step → ModelExecutionResult. The existing extraction math (logits_processor.py:242-403) is now reachable.
- generation_output_processor attaches per-request input logprobs and fills BatchTokenIDOut (previously a hard-coded empty stub); downstream meta_info wiring already existed.
- Engine.compute_log_probs builds a forward-only generate call (return_logprob=True, logprob_start_len=len(prompt)−1) and keeps the first M returned logprobs (the engine appends one trailing sampled-position entry). The pure request-building / response-parsing logic lives in a GPU-free helper module covered by fast CPU unit tests.
Files changed
- tokenspeed-scheduler/csrc/scheduler/{request_spec.h,request.h,request.cpp}: logprob_start_len on RequestSpec/Request + accessor
- tokenspeed-scheduler/csrc/scheduler/operations/forward.{h,cpp}: extend_logprob_start_len + SoA on FlatForwardOp
- tokenspeed-scheduler/bindings/python_module.cpp
- python/tokenspeed/runtime/engine/scheduler_utils.py, request_handler.py: logprob_start_len into make_spec
- python/tokenspeed/runtime/execution/{context.py,model_executor.py,cuda_graph_wrapper.py,types.py}: input_token_logprobs through forward_step
- python/tokenspeed/runtime/layers/logits_processor.py: from_forward_context copies the new fields onto LogitsMetadata
- python/tokenspeed/runtime/engine/generation_output_processor.py: BatchTokenIDOut
- python/tokenspeed/runtime/engine/compute_log_probs.py (new): build_score_kwargs, extract_completion_logprobs, compute_log_probs_core
- python/tokenspeed/runtime/entrypoints/{engine.py,engine_base.py}: Engine.compute_log_probs(sequences, temperature=1.0) + abstract method
- test/runtime/test_compute_log_probs.py (new)
- docs/guides/compute-log-probs.md (new) + sidebar
Validation (B200, from-branch image build)
- PYTHONPATH=python pytest test/runtime/test_compute_log_probs.py
- compute_log_probs([{[1,2,3,4]→[5,6,7]},{[10,11]→[12]}]) → tokens [[5,6,7],[12]], valid logprobs
- prompt_logprobs (same model + token ids, bf16)
Determinism and suffix-invariance are bitwise; the cross-engine, TP, and MoE-vs-dense differences are all within expected bf16 numerical tolerance (~0.03–0.07) and track token-for-token, confirming correct off-by-one alignment, vocab indexing, normalization, the TP gather, and the MoE path.
Behavior / limits (v1)
- log_probs[i][j] is the scalar log-prob of the realized completion token (gathered from the full distribution), not a per-vocab array.
- Temperature 1.0 only (raw log_softmax); other values raise NotImplementedError. Sampling-temperature scaling (for off-policy importance sampling) is a planned follow-up.
- num_extends == bs); empty prompt / empty completion are rejected.
Intentionally out of scope (follow-ups)
- Engine Python method is the only surface here; a native /compute_log_probs route + SMG gRPC land later when there's a consumer.
🤖 Generated with Claude Code