diff --git a/docs/.vitepress/config.mts b/docs/.vitepress/config.mts index 1289eb32d..a5c9cbbb3 100644 --- a/docs/.vitepress/config.mts +++ b/docs/.vitepress/config.mts @@ -37,7 +37,8 @@ export default defineConfig({ text: "Guides", items: [ { text: "Getting Started", link: "/guides/getting-started" }, - { text: "Launching a Server", link: "/guides/launching" } + { text: "Launching a Server", link: "/guides/launching" }, + { text: "Computing Log-Probabilities", link: "/guides/compute-log-probs" } ] }, { diff --git a/docs/guides/compute-log-probs.md b/docs/guides/compute-log-probs.md new file mode 100644 index 000000000..7f0cafb55 --- /dev/null +++ b/docs/guides/compute-log-probs.md @@ -0,0 +1,68 @@ +# Computing Log-Probabilities (RL Scoring) + +`Engine.compute_log_probs` scores `prompt + completion` token sequences under the +engine's current weights and returns one log-probability per completion token. It +is the core scoring primitive for online-RL trainers (PPO, GRPO, and any +KL-penalised objective) — for example to form importance-sampling ratios against +the policy that generated the rollouts. + +## Usage + +```python +from tokenspeed.runtime.entrypoints.engine import Engine + +# Scoring runs a pure-extend (prefill-only) forward. On backends that cannot +# serve a mixed prefill+decode batch eagerly (e.g. the default `mha` backend), +# launch the engine for scoring with a backend + scheduler config that keeps the +# request on a pure-extend path: +engine = Engine( + model="", + attention_backend="flashinfer", + enforce_eager=True, + disable_overlap_schedule=True, +) + +out = engine.compute_log_probs( + sequences=[ + {"prompt_token_ids": [1, 2, 3, 4], "completion_token_ids": [5, 6, 7]}, + {"prompt_token_ids": [10, 11], "completion_token_ids": [12]}, + ], + temperature=1.0, +) + +# out["log_probs"][i][j] == log P(completion_token_ids[i][j] | context) +# out["tokens"][i] == completion_token_ids[i] +out["log_probs"] # e.g. 
[[-0.12, -0.47, -0.31], [-2.03]] +out["tokens"] # [[5, 6, 7], [12]] +``` + +`log_probs[i][j]` is the log-probability of the realised completion token `j` in +sequence `i`, conditioned on everything before it (prompt + earlier completion +tokens). Only completion positions are scored; the prompt is context. + +## How it works + +It reuses the normal generation path: internally each sequence is sent through a +forward-only `generate` call (`max_new_tokens=0`, `return_logprob=True`, +`logprob_start_len=len(prompt) - 1`, because each token's logprob is read from the logits at the preceding position), and the per-token input logprobs are read back +from `meta_info["input_token_logprobs"]`. Logits are gathered across tensor-parallel +ranks before `log_softmax`, exactly as on the sampling path. No engine pause is +required; scoring requests can be interleaved with normal generation. + +Long sequences are handled across chunked prefill: when a `prompt + completion` +is split into multiple prefill chunks, the input-logprob window is collected from +every chunk it overlaps (not just the first), so the full set of completion +logprobs is returned regardless of `chunked_prefill_size`. + +## Limits (current) + +- **Temperature:** `temperature=1.0` only (raw `log_softmax`). Other values raise + `NotImplementedError`. Sampling-temperature scaling (for off-policy importance + sampling) is a planned follow-up. +- **Speculative decoding:** unavailable — `compute_log_probs` raises if the engine + was launched with a speculative algorithm (the generation path disables logprobs + in that mode). +- **Prompt/completion:** both must be non-empty (the first completion token needs + prior context to be scored). +- **Surface:** exposed as the `Engine` Python method. A native HTTP / SMG endpoint + is deferred until there is a consumer for it. 
diff --git a/python/tokenspeed/runtime/engine/compute_log_probs.py b/python/tokenspeed/runtime/engine/compute_log_probs.py new file mode 100644 index 000000000..a58740b6b --- /dev/null +++ b/python/tokenspeed/runtime/engine/compute_log_probs.py @@ -0,0 +1,142 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Pure, GPU-free helpers for the compute_log_probs API (RL-plan Milestone 2). + +The engine scores ``prompt + completion`` sequences by reusing the normal +generation path: a forward-only ``generate`` call with ``return_logprob=True`` +and ``logprob_start_len=len(prompt) - 1`` makes ``meta_info['input_token_logprobs']`` +carry the per-completion-token logprobs, followed by one trailing entry that is dropped. These helpers build that call +and parse its result; ``Engine.compute_log_probs`` wires them to ``self.generate``. 
+""" + +from __future__ import annotations + +from typing import Any, Callable + +DEFAULT_TEMPERATURE = 1.0 +# Set to 1 if the GPU spike shows max_new_tokens=0 is unsupported; the single +# generated token lands in output_token_logprobs, never input_token_logprobs. +SCORE_MAX_NEW_TOKENS = 0 + + +class InvalidSequenceError(ValueError): + """Raised when a sequence cannot be scored (empty prompt or completion).""" + + +def validate_sequence( + prompt_token_ids: list[int], completion_token_ids: list[int] +) -> None: + if not prompt_token_ids: + raise InvalidSequenceError( + "prompt_token_ids must be non-empty: the first completion token needs " + "prior context to be scored." + ) + if not completion_token_ids: + raise InvalidSequenceError( + "completion_token_ids must be non-empty: nothing to score." + ) + + +def build_score_kwargs( + prompt_token_ids: list[int], + completion_token_ids: list[int], + temperature: float = DEFAULT_TEMPERATURE, +) -> dict[str, Any]: + """Build the kwargs for an internal forward-only ``Engine.generate`` call.""" + validate_sequence(prompt_token_ids, completion_token_ids) + # Note: compute_log_probs_core separately gates on temperature != 1.0 for v1; + # the two checks serve different audiences (standalone helper vs. v1 core path), + # so the divergence is intentional, not accidental. + if temperature <= 0: + raise ValueError(f"temperature must be > 0, got {temperature}") + return { + "input_ids": list(prompt_token_ids) + list(completion_token_ids), + "sampling_params": { + "max_new_tokens": SCORE_MAX_NEW_TOKENS, + "temperature": temperature, + }, + "return_logprob": True, + # The logprob of completion token c_j is read from the logits at the + # *preceding* position, so scoring starts one token before the + # completion: logprob_start_len = len(prompt) - 1. 
The engine returns + # one entry per position from there to the end — the M completion + # logprobs followed by one trailing sampled-position entry (target token + # -1) that extract_completion_logprobs drops. (Verified on B200.) + "logprob_start_len": len(prompt_token_ids) - 1, + } + + +def extract_completion_logprobs( + meta_info: dict[str, Any], num_completion: int +) -> tuple[list[float], list[int]]: + """Split ``meta_info['input_token_logprobs']`` into (log_probs, tokens). + + Each entry is a ``(logprob, token_id, text_or_None)`` tuple. The engine + returns the M completion logprobs (aligned to ``logprob_start_len = + len(prompt) - 1``) followed by one trailing sampled-position entry, so we + keep the first ``num_completion``. Fewer than that means the logprob window + was wrong (or input logprobs were not produced), so we fail loudly rather + than return a silently-misaligned array. + """ + entries = meta_info.get("input_token_logprobs") + if not entries or len(entries) < num_completion: + got = 0 if entries is None else len(entries) + raise ValueError( + f"expected at least {num_completion} completion logprobs, got {got}; " + "check logprob_start_len alignment / input-logprob support." + ) + entries = entries[:num_completion] + log_probs = [float(e[0]) for e in entries] + tokens = [int(e[1]) for e in entries] + return log_probs, tokens + + +def compute_log_probs_core( + sequences: list[dict[str, list[int]]], + generate_fn: Callable[..., dict[str, Any]], + temperature: float = DEFAULT_TEMPERATURE, +) -> dict[str, list[list[float]]]: + """Score each sequence by calling ``generate_fn`` and parsing the result. + + ``generate_fn`` must have the signature of ``Engine.generate`` and return a + single result dict (non-streaming) carrying ``meta_info``. v1 supports only + ``temperature == 1.0`` (raw log_softmax), matching the engine's default + ``temp_scaled_logprobs=False`` path; other values raise ``NotImplementedError``. 
+ """ + if temperature != DEFAULT_TEMPERATURE: + raise NotImplementedError( + "compute_log_probs v1 supports temperature=1.0 (raw log_softmax) only; " + f"got {temperature}. Sampling-temperature scaling is a follow-up." + ) + + log_probs_out: list[list[float]] = [] + tokens_out: list[list[int]] = [] + for seq in sequences: + prompt_ids = seq["prompt_token_ids"] + completion_ids = seq["completion_token_ids"] + kwargs = build_score_kwargs(prompt_ids, completion_ids, temperature) + result = generate_fn(**kwargs) + log_probs, tokens = extract_completion_logprobs( + result["meta_info"], len(completion_ids) + ) + log_probs_out.append(log_probs) + tokens_out.append(tokens) + return {"log_probs": log_probs_out, "tokens": tokens_out} diff --git a/python/tokenspeed/runtime/engine/generation_output_processor.py b/python/tokenspeed/runtime/engine/generation_output_processor.py index e79c9f6a0..f1212a006 100644 --- a/python/tokenspeed/runtime/engine/generation_output_processor.py +++ b/python/tokenspeed/runtime/engine/generation_output_processor.py @@ -107,6 +107,12 @@ def __init__( self.output_token_logprobs_idx: list[int] | None = ( [] if return_logprob else None ) + # Input/prompt-token logprobs (populated once at prefill when requested). + self.input_token_logprobs_val: list[float] | None = ( + [] if return_logprob else None + ) + self.input_token_logprobs_idx: list[int] | None = [] if return_logprob else None + self.input_token_logprobs_sent: bool = False # --- Streaming bookkeeping (internal) --- self._surr_offset: int | None = None @@ -521,6 +527,20 @@ def post_process_forward_op( if model_execution_results.output_logprobs is not None else None ) + # Input/prompt-token logprobs for this forward (pure-extend scoring + # batches only): a flat (values, token_ids) pair over extend requests. 
+ _input_logprobs_pair = model_execution_results.input_token_logprobs + input_logprobs_val = ( + _input_logprobs_pair[0].tolist() + if _input_logprobs_pair is not None + else None + ) + input_logprobs_idx = ( + _input_logprobs_pair[1].tolist() + if _input_logprobs_pair is not None and _input_logprobs_pair[1] is not None + else None + ) + ilp_pt = 0 pt = 0 for i, rid in enumerate(forward_op.request_ids): output_length = model_execution_results.output_lengths[i].item() @@ -538,12 +558,41 @@ def post_process_forward_op( else: pt += output_length + # Slice this request's input/prompt logprobs and advance the flat + # pointer BEFORE any `continue`, so alignment holds across requests. + req_input_lp_val = None + req_input_lp_idx = None + if input_logprobs_val is not None and i < num_extends: + sl = int(forward_op.extend_logprob_start_lens[i]) + if sl >= 0: + plen = int(forward_op.input_lengths[i]) - sl + if plen > 0: + req_input_lp_val = input_logprobs_val[ilp_pt : ilp_pt + plen] + if input_logprobs_idx is not None: + req_input_lp_idx = input_logprobs_idx[ + ilp_pt : ilp_pt + plen + ] + ilp_pt += plen + if rid not in self.rid_to_state: # means it's delayed token, do not process continue request_state: RequestState = self.rid_to_state[rid] + # Accumulate input/prompt logprobs BEFORE the chunked-prefill guard + # below: when a scored sequence spans multiple prefill chunks, each + # non-final chunk contributes part of the requested window. Skipping + # this on chunk boundaries would drop those tokens and leave + # compute_log_probs with fewer logprobs than completion tokens. 
+ if ( + req_input_lp_val is not None + and request_state.input_token_logprobs_val is not None + ): + request_state.input_token_logprobs_val.extend(req_input_lp_val) + if req_input_lp_idx is not None: + request_state.input_token_logprobs_idx.extend(req_input_lp_idx) + # Do not output chunking result if not request_state.prefill_finished: continue @@ -706,6 +755,8 @@ def stream_output( output_extra_infos: list[dict] = [] output_token_logprobs_val: list[list[float]] = [] output_token_logprobs_idx: list[list[int]] = [] + input_token_logprobs_val: list[list[float]] = [] + input_token_logprobs_idx: list[list[int]] = [] for i, rs in enumerate(output_states): # For finished requests, always output (unless already output) @@ -785,6 +836,21 @@ def stream_output( output_token_logprobs_val.append([]) output_token_logprobs_idx.append([]) + # Input/prompt logprobs are produced once at prefill; ship them on + # the first output for this request, then mark sent so multi-token + # generations don't resend them every stream step. 
+ if ( + rs.return_logprob + and rs.input_token_logprobs_val + and not rs.input_token_logprobs_sent + ): + input_token_logprobs_val.append(list(rs.input_token_logprobs_val)) + input_token_logprobs_idx.append(list(rs.input_token_logprobs_idx)) + rs.input_token_logprobs_sent = True + else: + input_token_logprobs_val.append([]) + input_token_logprobs_idx.append([]) + # Don't send empty batch to detokenizer if len(rids_to_send) == 0: return @@ -804,8 +870,8 @@ def stream_output( completion_tokens=completion_tokens, cached_tokens=cached_tokens, spec_verify_ct=spec_verify_ct, - input_token_logprobs_val=[], - input_token_logprobs_idx=[], + input_token_logprobs_val=input_token_logprobs_val, + input_token_logprobs_idx=input_token_logprobs_idx, output_token_logprobs_val=output_token_logprobs_val, output_token_logprobs_idx=output_token_logprobs_idx, input_top_logprobs_val=[], diff --git a/python/tokenspeed/runtime/engine/request_handler.py b/python/tokenspeed/runtime/engine/request_handler.py index 1596e08aa..6cd549923 100644 --- a/python/tokenspeed/runtime/engine/request_handler.py +++ b/python/tokenspeed/runtime/engine/request_handler.py @@ -192,9 +192,18 @@ def handle_generate_request( if recv_req.bootstrap_port is None: recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port + # Input/prompt-token logprobs are requested only when return_logprob is + # set AND a non-negative logprob_start_len is given; otherwise -1 tells + # the scheduler to skip them (output-only or no logprobs). 
+ logprob_start_len = -1 + if recv_req.return_logprob and recv_req.logprob_start_len is not None: + if recv_req.logprob_start_len >= 0: + logprob_start_len = recv_req.logprob_start_len + req_spec = make_spec( rid=recv_req.rid, tokens=recv_req.input_ids, + logprob_start_len=logprob_start_len, ) req_state = RequestState.from_recv_req( recv_req, diff --git a/python/tokenspeed/runtime/engine/scheduler_utils.py b/python/tokenspeed/runtime/engine/scheduler_utils.py index 92394c45e..ca133a6c6 100644 --- a/python/tokenspeed/runtime/engine/scheduler_utils.py +++ b/python/tokenspeed/runtime/engine/scheduler_utils.py @@ -44,10 +44,12 @@ _TRUTHY_ENV_VALUES = {"1", "true", "yes", "on"} -def make_spec(rid: str, tokens: list[int]) -> RequestSpec: +def make_spec(rid: str, tokens: list[int], logprob_start_len: int = -1) -> RequestSpec: spec = RequestSpec() spec.request_id = rid spec.tokens = tokens + # -1 means input/prompt-token logprobs are not requested. + spec.logprob_start_len = logprob_start_len return spec diff --git a/python/tokenspeed/runtime/entrypoints/engine.py b/python/tokenspeed/runtime/entrypoints/engine.py index 83cef7b49..e8fec36ae 100755 --- a/python/tokenspeed/runtime/entrypoints/engine.py +++ b/python/tokenspeed/runtime/entrypoints/engine.py @@ -57,6 +57,7 @@ def _ignore_threading_atexit(*args, **kwargs) -> None: import torch import uvloop +from tokenspeed.runtime.engine.compute_log_probs import compute_log_probs_core from tokenspeed.runtime.engine.data_parallel_controller import ( run_data_parallel_controller_process, ) @@ -411,6 +412,26 @@ def resume_memory_occupation(self, tags: list[str] | None = None): obj = ResumeMemoryOccupationReqInput(tags=tags) return self.llm.run(self.tokenizer_manager.resume_memory_occupation(obj)) + def compute_log_probs( + self, + sequences: list[dict[str, list[int]]], + temperature: float = 1.0, + ) -> dict[str, list[list[float]]]: + """Score prompt+completion sequences under current weights (RL-plan M2). 
+ + ``sequences`` is a list of + ``{"prompt_token_ids": [...], "completion_token_ids": [...]}``. + Returns ``{"log_probs": [[...], ...], "tokens": [[completion ids], ...]}`` + where ``log_probs[i][j] == log P(completion_token_ids[i][j] | context)`` + at temperature 1.0 (raw log_softmax). + """ + if self.tokenizer_manager.server_args.speculative_algorithm is not None: + raise RuntimeError( + "compute_log_probs is unavailable when speculative decoding is " + "enabled (Engine.generate disables logprobs in that mode)." + ) + return compute_log_probs_core(sequences, self.generate, temperature) + """ Execute an RPC call on all scheduler processes. """ diff --git a/python/tokenspeed/runtime/entrypoints/engine_base.py b/python/tokenspeed/runtime/entrypoints/engine_base.py index 960ae403f..d4ff39b06 100755 --- a/python/tokenspeed/runtime/entrypoints/engine_base.py +++ b/python/tokenspeed/runtime/entrypoints/engine_base.py @@ -79,6 +79,14 @@ def release_memory_occupation(self) -> None: def resume_memory_occupation(self) -> None: """Resume GPU memory occupation which is previously released.""" + @abstractmethod + def compute_log_probs( + self, + sequences: list[dict[str, list[int]]], + temperature: float = 1.0, + ) -> dict[str, list[list[float]]]: + """Score prompt+completion sequences and return per-completion-token logprobs.""" + @abstractmethod def shutdown(self) -> None: """Shutdown the engine and clean up resources.""" diff --git a/python/tokenspeed/runtime/execution/context.py b/python/tokenspeed/runtime/execution/context.py index b4e80dae5..d55021273 100644 --- a/python/tokenspeed/runtime/execution/context.py +++ b/python/tokenspeed/runtime/execution/context.py @@ -60,3 +60,17 @@ class ForwardContext: # --- logits processor --- gather_ids: torch.Tensor | None = None + + # --- input/prompt-token logprobs --- + # When True, the LogitsProcessor also computes per-position logprobs for the + # input (prompt+completion) tokens, not just the sampled token. 
+ extend_return_logprob: bool = False + # Per-extend-request start position (within the extend tokens) from which to + # collect input logprobs, and the per-extend-request extend lengths. + extend_logprob_start_lens_cpu: list[int] | None = None + extend_seq_lens_cpu: list[int] | None = None + # Per-request count of kept input-logprob positions (extend_len - start_len). + extend_logprob_pruned_lens_cpu: list[int] | None = None + # Flat GPU tensor of target token ids (one per kept input position) whose + # logprob is gathered: the shifted input ids sliced to [start_len:extend_len]. + extend_input_logprob_token_ids_gpu: torch.Tensor | None = None diff --git a/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py b/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py index f28497c49..b08007f20 100644 --- a/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py +++ b/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py @@ -453,7 +453,10 @@ def run_once(): self.capturable_grammar.reset_state() global_graph_memory_pool = graph.pool() - return graph, out + # Captured decode graphs never produce input/prompt logprobs (4th + # element of _forward_step's return); keep the output buffers a 3-tuple + # to match the replay-path unpack. + return graph, out[:3] def _capture_paged_cache_block_tables(self, bs: int, pool) -> dict | None: specs = tuple(pool.paged_cache_group_specs) @@ -911,6 +914,8 @@ def __call__( if output_logprobs is not None else None ), + # Captured (decode) graphs never produce input/prompt logprobs. 
+ None, ) else: diff --git a/python/tokenspeed/runtime/execution/model_executor.py b/python/tokenspeed/runtime/execution/model_executor.py index 780cfafcc..6bef47140 100644 --- a/python/tokenspeed/runtime/execution/model_executor.py +++ b/python/tokenspeed/runtime/execution/model_executor.py @@ -491,7 +491,54 @@ def _forward_step( ] = next_round_input_ids.to(torch.int32) output_logprobs = logits_output.next_token_logprobs - return output_tokens, accept_lengths, output_logprobs + # Input/prompt-token logprobs (only produced in EXTEND when requested). + # Bundled as a single (val, token_ids) payload to keep the forward_step + # return arity stable across the captured-decode path (which has none). + input_logprobs = None + if logits_output.input_token_logprobs is not None: + input_logprobs = ( + logits_output.input_token_logprobs, + ctx.extend_input_logprob_token_ids_gpu, + ) + return output_tokens, accept_lengths, output_logprobs, input_logprobs + + def _maybe_set_input_logprob_ctx(self, ctx, forward_op, bs, num_extends): + """Enable input/prompt-token logprobs on the ForwardContext when every + request in a pure-extend batch requests them (the compute_log_probs + scoring path). + + Gated to pure-extend batches (num_extends == bs) so the per-request + hidden-state slicing in LogitsProcessor stays aligned — mixed batches + interleave decode rows. Target token ids are the already-shifted prefill + ids sliced to each request's [start_len:extend_len] window. 
+ """ + if num_extends == 0 or num_extends != bs: + return + start_lens = [ + int(s) for s in forward_op.extend_logprob_start_lens[:num_extends] + ] + if not all(s >= 0 for s in start_lens): + return + seq_lens = [ + int(x) + for x in self.input_buffers.extend_seq_lens_cpu[:num_extends].tolist() + ] + pruned = [el - sl for sl, el in zip(start_lens, seq_lens)] + if any(p <= 0 for p in pruned): + return + shifted = self.input_buffers.shifted_prefill_ids_buf + idx_slices = [] + pt = 0 + for sl, el in zip(start_lens, seq_lens): + idx_slices.append(shifted[pt + sl : pt + el]) + pt += el + ctx.extend_return_logprob = True + ctx.extend_logprob_start_lens_cpu = start_lens + ctx.extend_seq_lens_cpu = seq_lens + ctx.extend_logprob_pruned_lens_cpu = pruned + ctx.extend_input_logprob_token_ids_gpu = ( + torch.cat(idx_slices) if idx_slices else None + ) @nvtx_range("update_runtime_state", color="orange") def _update_runtime_state( @@ -1217,6 +1264,8 @@ def execute_forward_op( ctx.global_bs = dp_global_bs ctx.all_decode_or_idle = dp_all_decode_or_idle + self._maybe_set_input_logprob_ctx(ctx, forward_op, bs, num_extends) + with nvtx_range("sampling_prep", color="yellow"): sampling_info = self._build_sampling_info(bs, sampling_params_list) grammar_completion = setup_grammar_step( @@ -1279,7 +1328,12 @@ def execute_forward_op( device=self.device, num_reqs=bs, ) - output_tokens, output_lengths, output_logprobs = self.forward_step( + ( + output_tokens, + output_lengths, + output_logprobs, + input_logprobs, + ) = self.forward_step( bs=bs, ctx=ctx, sampling_info=sampling_info, @@ -1349,6 +1403,17 @@ def execute_forward_op( if output_logprobs is not None: output_logprobs = output_logprobs.to("cpu", non_blocking=True) + if input_logprobs is not None: + ilp_val, ilp_idx = input_logprobs + input_logprobs = ( + ilp_val.to("cpu", non_blocking=True), + ( + ilp_idx.to("cpu", non_blocking=True) + if ilp_idx is not None + else None + ), + ) + copy_event = torch.cuda.Event() copy_event.record() 
@@ -1356,6 +1421,7 @@ def execute_forward_op( output_tokens=output_tokens, output_lengths=output_lengths, output_logprobs=output_logprobs, + input_token_logprobs=input_logprobs, copy_event=copy_event, grammar_completion=grammar_completion, ) diff --git a/python/tokenspeed/runtime/execution/types.py b/python/tokenspeed/runtime/execution/types.py index 9cc0eedb6..47c5294a4 100644 --- a/python/tokenspeed/runtime/execution/types.py +++ b/python/tokenspeed/runtime/execution/types.py @@ -55,6 +55,10 @@ class ModelExecutionResult: # Populated unconditionally by the sampling backend so it's always # available if any request asks for it. output_logprobs: torch.Tensor | None = None + # Input/prompt-token logprobs for pure-extend scoring batches, as a + # (values, token_ids) CPU-tensor pair flat over all extend requests, or None + # when no request in the batch requested input logprobs. + input_token_logprobs: tuple[torch.Tensor, torch.Tensor | None] | None = None def sync(self) -> None: assert self.copy_event is not None diff --git a/python/tokenspeed/runtime/layers/logits_processor.py b/python/tokenspeed/runtime/layers/logits_processor.py index aac7e6c6d..973be19b5 100755 --- a/python/tokenspeed/runtime/layers/logits_processor.py +++ b/python/tokenspeed/runtime/layers/logits_processor.py @@ -123,6 +123,11 @@ def from_forward_context( capture_hidden_mode=ctx.capture_hidden_mode, gather_ids=ctx.gather_ids, extend_seq_lens=input_lengths, + extend_return_logprob=ctx.extend_return_logprob, + extend_logprob_start_lens_cpu=ctx.extend_logprob_start_lens_cpu, + extend_seq_lens_cpu=ctx.extend_seq_lens_cpu, + extend_logprob_pruned_lens_cpu=ctx.extend_logprob_pruned_lens_cpu, + extend_input_logprob_token_ids_gpu=ctx.extend_input_logprob_token_ids_gpu, ) diff --git a/test/runtime/test_compute_log_probs.py b/test/runtime/test_compute_log_probs.py new file mode 100644 index 000000000..2dd45e35d --- /dev/null +++ b/test/runtime/test_compute_log_probs.py @@ -0,0 +1,247 @@ +"""Unit tests for 
compute_log_probs pure helpers (CPU, no GPU required).""" + +from __future__ import annotations + +import math +import os + +import pytest + +from tokenspeed.runtime.engine.compute_log_probs import ( + InvalidSequenceError, + build_score_kwargs, + compute_log_probs_core, + extract_completion_logprobs, + validate_sequence, +) + +# --------------------------------------------------------------------------- +# build_score_kwargs tests +# --------------------------------------------------------------------------- + + +def test_build_score_kwargs_shape(): + kw = build_score_kwargs([1, 2, 3, 4], [5, 6, 7], temperature=1.0) + assert kw["input_ids"] == [1, 2, 3, 4, 5, 6, 7] + assert kw["sampling_params"] == {"max_new_tokens": 0, "temperature": 1.0} + assert kw["return_logprob"] is True + assert kw["logprob_start_len"] == 3 # len(prompt) - 1 (score from preceding token) + + +def test_build_score_kwargs_rejects_empty_prompt(): + with pytest.raises(InvalidSequenceError): + build_score_kwargs([], [5, 6], temperature=1.0) + + +def test_build_score_kwargs_rejects_empty_completion(): + with pytest.raises(InvalidSequenceError): + build_score_kwargs([1, 2], [], temperature=1.0) + + +def test_build_score_kwargs_rejects_nonpositive_temperature(): + with pytest.raises(ValueError): + build_score_kwargs([1, 2], [3], temperature=0.0) + + +def test_build_score_kwargs_rejects_negative_temperature(): + with pytest.raises(ValueError): + build_score_kwargs([1, 2], [3], temperature=-0.5) + + +# --------------------------------------------------------------------------- +# validate_sequence direct tests +# --------------------------------------------------------------------------- + + +def test_validate_sequence_rejects_empty_prompt(): + with pytest.raises(InvalidSequenceError): + validate_sequence([], [1]) + + +def test_validate_sequence_rejects_empty_completion(): + with pytest.raises(InvalidSequenceError): + validate_sequence([1], []) + + +def test_validate_sequence_accepts_nonempty(): + # 
Should not raise. + validate_sequence([1], [2]) + + +# --------------------------------------------------------------------------- +# extract_completion_logprobs tests +# --------------------------------------------------------------------------- + + +def _meta(entries): + # entries: list of (logprob, token_id, text-or-None) as produced by + # LogprobsProcessor.detokenize_logprob_tokens (engine/logprobs.py:148). + return {"input_token_logprobs": entries} + + +def test_extract_completion_logprobs_happy_path(): + meta = _meta([(-0.12, 5, None), (-0.47, 6, None), (-0.31, 7, None)]) + log_probs, tokens = extract_completion_logprobs(meta, num_completion=3) + assert log_probs == [-0.12, -0.47, -0.31] + assert tokens == [5, 6, 7] + + +def test_extract_completion_logprobs_length_mismatch_raises(): + meta = _meta([(-0.12, 5, None), (-0.47, 6, None)]) + with pytest.raises(ValueError): + extract_completion_logprobs(meta, num_completion=3) + + +def test_extract_completion_logprobs_missing_key_raises(): + with pytest.raises(ValueError): + extract_completion_logprobs({}, num_completion=3) + + +# --------------------------------------------------------------------------- +# compute_log_probs_core tests +# --------------------------------------------------------------------------- + + +def _fake_generate_factory(): + """Returns a generate_fn that echoes deterministic logprobs per sequence, + so we can assert ordering and slicing without a GPU.""" + + def generate_fn(*, input_ids, sampling_params, return_logprob, logprob_start_len): + # Mirror the engine (verified on B200): input_token_logprobs[k] is the + # logprob of the token at position logprob_start_len+1+k (one-position + # shift), followed by one trailing sampled-position entry (target -1). 
+ targets = input_ids[logprob_start_len + 1 :] + [-1] + entries = [(-0.1 * (i + 1), tok, None) for i, tok in enumerate(targets)] + return {"meta_info": {"input_token_logprobs": entries}} + + return generate_fn + + +def test_core_multi_sequence_ordering_and_slicing(): + seqs = [ + {"prompt_token_ids": [1, 2, 3], "completion_token_ids": [4, 5]}, + {"prompt_token_ids": [9], "completion_token_ids": [8, 7, 6]}, + ] + out = compute_log_probs_core(seqs, _fake_generate_factory(), temperature=1.0) + assert out["tokens"] == [[4, 5], [8, 7, 6]] + assert out["log_probs"][0] == [-0.1, -0.2] + assert out["log_probs"][1] == pytest.approx([-0.1, -0.2, -0.3]) + + +def test_core_empty_sequences_returns_empty(): + out = compute_log_probs_core([], _fake_generate_factory(), temperature=1.0) + assert out == {"log_probs": [], "tokens": []} + + +def test_core_rejects_non_unit_temperature(): + seqs = [{"prompt_token_ids": [1], "completion_token_ids": [2]}] + with pytest.raises(NotImplementedError): + compute_log_probs_core(seqs, _fake_generate_factory(), temperature=2.0) + + +# --------------------------------------------------------------------------- +# GPU integration (deferred lane). Skipped unless TOKENSPEED_RUN_GPU_TESTS=1 on a +# GPU box. Verifies the end-to-end Engine.compute_log_probs path: shape, ordering, +# and that returned values are valid log-probabilities. +# --------------------------------------------------------------------------- + +requires_gpu = pytest.mark.skipif( + os.environ.get("TOKENSPEED_RUN_GPU_TESTS") != "1", + reason="set TOKENSPEED_RUN_GPU_TESTS=1 on a GPU box to run", +) + +# Engine config for deterministic prefill-only scoring. The default `mha` +# attention backend cannot serve a mixed (prefill+decode) batch eagerly, and the +# captured-CUDA-graph decode path it falls back to rejects multi-token query +# shapes — both of which the scoring path (and chunked prefill in particular) +# hit. 
flashinfer + eager + no-overlap keeps scoring on a pure-extend path that +# every backend handles. Validated on B200 (Qwen2-1.5B-Instruct). Overridable via +# TOKENSPEED_TEST_ATTN_BACKEND for other GPUs/models. +_SCORING_ENGINE_KWARGS = { + "attention_backend": os.environ.get("TOKENSPEED_TEST_ATTN_BACKEND", "flashinfer"), + "enforce_eager": True, + "disable_overlap_schedule": True, + "log_level": "error", +} + + +@requires_gpu +def test_compute_log_probs_end_to_end(): + from tokenspeed.runtime.entrypoints.engine import Engine + + engine = Engine( + model=os.environ.get("TOKENSPEED_TEST_MODEL", "Qwen/Qwen2.5-0.5B-Instruct"), + **_SCORING_ENGINE_KWARGS, + ) + try: + seqs = [ + {"prompt_token_ids": [1, 2, 3, 4], "completion_token_ids": [5, 6, 7]}, + {"prompt_token_ids": [10, 11], "completion_token_ids": [12]}, + ] + out = engine.compute_log_probs(seqs, temperature=1.0) + assert out["tokens"] == [[5, 6, 7], [12]] + assert len(out["log_probs"]) == 2 + assert len(out["log_probs"][0]) == 3 and len(out["log_probs"][1]) == 1 + for row in out["log_probs"]: + for lp in row: + assert lp <= 0.0 and math.isfinite(lp) # valid log-probabilities + finally: + engine.shutdown() + + +@requires_gpu +def test_compute_log_probs_long_sequence_chunked_matches_single_chunk(): + """Regression: scoring a sequence whose logprob window spans >1 prefill chunk. + + ``logprob_start_len = len(prompt) - 1``, so the scored window covers the + last prompt token plus every completion token. When a long prompt+completion + is split across prefill chunks, that window straddles a chunk boundary. The + original wiring dropped every chunk after the first in-window one — the C++ + scheduler set ``extend_logprob_start_len = -1`` for ``rel < 0`` and the + output processor discarded non-final prefill chunks — so long sequences + returned too few logprobs (ValueError) instead of scoring. 
+ + We pin the prompt length to a multiple of a small ``chunked_prefill_size`` so + the completion begins exactly on a chunk boundary (guaranteeing the window + crosses it), then assert the chunked result matches the single-chunk path + (the configuration validated on B200) within bf16 reduction-order tolerance. + """ + from tokenspeed.runtime.entrypoints.engine import Engine + + model = os.environ.get("TOKENSPEED_TEST_MODEL", "Qwen/Qwen2.5-0.5B-Instruct") + chunk = 128 + # 512 = 4 * chunk -> completion starts on a chunk boundary; window crosses it. + prompt = [10 + (i % 1000) for i in range(4 * chunk)] + completion = [200 + i for i in range(40)] + seqs = [{"prompt_token_ids": prompt, "completion_token_ids": completion}] + + # Reference: large chunk size -> the whole sequence prefills in one chunk. + ref_engine = Engine( + model=model, chunked_prefill_size=8192, **_SCORING_ENGINE_KWARGS + ) + try: + ref = ref_engine.compute_log_probs(seqs, temperature=1.0) + finally: + ref_engine.shutdown() + + # Chunked: small chunk size -> the window spans multiple prefill chunks. + chunked_engine = Engine( + model=model, chunked_prefill_size=chunk, **_SCORING_ENGINE_KWARGS + ) + try: + got = chunked_engine.compute_log_probs(seqs, temperature=1.0) + finally: + chunked_engine.shutdown() + + # Correct count (the original bug raised here) and valid values. + assert got["tokens"] == [completion] + assert len(got["log_probs"][0]) == len(completion) + for lp in got["log_probs"][0]: + assert lp <= 0.0 and math.isfinite(lp) + + # Chunking the prefill must not change the scores. 
+ assert len(ref["log_probs"][0]) == len(completion) + max_abs_diff = max( + abs(a - b) for a, b in zip(ref["log_probs"][0], got["log_probs"][0]) + ) + assert max_abs_diff < 0.05, f"chunked vs single-chunk diverged: {max_abs_diff}" diff --git a/tokenspeed-scheduler/bindings/python_module.cpp b/tokenspeed-scheduler/bindings/python_module.cpp index eaa825b29..37c88a48e 100644 --- a/tokenspeed-scheduler/bindings/python_module.cpp +++ b/tokenspeed-scheduler/bindings/python_module.cpp @@ -248,7 +248,8 @@ NB_MODULE(tokenspeed_scheduler_ext, m) { .def_rw("request_id", &tokenspeed::RequestSpec::request_id) .def_rw("tokens", &tokenspeed::RequestSpec::tokens) .def_rw("rolling_hashes", &tokenspeed::RequestSpec::rolling_hashes) - .def_rw("storage_hit_pages", &tokenspeed::RequestSpec::storage_hit_pages); + .def_rw("storage_hit_pages", &tokenspeed::RequestSpec::storage_hit_pages) + .def_rw("logprob_start_len", &tokenspeed::RequestSpec::logprob_start_len); nb::module_ forward_event = m.def_submodule("ForwardEvent"); nb::class_(forward_event, "ExtendResult") @@ -320,6 +321,7 @@ NB_MODULE(tokenspeed_scheduler_ext, m) { flat_fwd_op.def_ro("input_ids", &tokenspeed::FlatForwardOperation::input_ids) .def_ro("shifted_input_ids", &tokenspeed::FlatForwardOperation::shifted_input_ids) .def_ro("extend_prefix_lens", &tokenspeed::FlatForwardOperation::extend_prefix_lens) + .def_ro("extend_logprob_start_lens", &tokenspeed::FlatForwardOperation::extend_logprob_start_lens) .def_prop_ro( "prefill_lengths", [](const tokenspeed::FlatForwardOperation& op) -> const std::vector& { diff --git a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp index fd48962fe..b7c5b7089 100644 --- a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp @@ -414,6 +414,27 @@ static PrefillOperation applyPrefillEvent(Request* request, Event event) { op.shifted_input_ids = 
std::move(info.shifted_input_ids); op.extend_prefix_len = info.already_scheduled_len; + // Rebase the prompt-relative logprob_start_len onto this prefill chunk's + // extend tokens. -1 ⇒ input logprobs not requested, or the requested region + // lies entirely after this chunk. + // + // When the requested window spans multiple prefill chunks (long + // prompt+completion), three cases arise per chunk: + // * rel >= extend_len : window starts after this chunk -> skip (-1). + // * 0 <= rel < extend_len : window starts inside this chunk -> collect + // from `rel`. + // * rel < 0 : the window opened in an earlier chunk and is still active, + // so this whole chunk is inside the window -> collect from offset 0. + // (Previously this set -1 and silently dropped every token after the + // first in-window chunk.) + const std::int32_t logprob_start_len = request->LogprobStartLen(); + if (logprob_start_len < 0) { + op.extend_logprob_start_len = -1; + } else { + const std::int32_t rel = logprob_start_len - info.already_scheduled_len; + op.extend_logprob_start_len = (rel < info.extend_len) ? std::max(rel, 0) : -1; + } + auto* mamba = request->GetLocalMambaAllocator(); if (mamba != nullptr && mamba->HasWorking()) { op.mamba_working_idx = mamba->WorkingIndex(); diff --git a/tokenspeed-scheduler/csrc/scheduler/operations/forward.h b/tokenspeed-scheduler/csrc/scheduler/operations/forward.h index 78c606361..14560ce5d 100644 --- a/tokenspeed-scheduler/csrc/scheduler/operations/forward.h +++ b/tokenspeed-scheduler/csrc/scheduler/operations/forward.h @@ -65,6 +65,9 @@ struct PrefillOperation : public ForwardOperationBase { std::vector input_ids; std::vector shifted_input_ids; std::int32_t extend_prefix_len; + // Start position (relative to this prefill chunk's extend tokens) from which + // input/prompt-token logprobs should be collected. -1 means not requested. 
+ std::int32_t extend_logprob_start_len{-1}; }; struct DecodeOperation : public ForwardOperationBase { @@ -89,6 +92,9 @@ struct FlatForwardOperation { std::vector input_ids; std::vector shifted_input_ids; std::vector extend_prefix_lens; + // Per-prefill (parallel to extend_prefix_lens): start position within the + // extend tokens from which to collect input logprobs. -1 = not requested. + std::vector extend_logprob_start_lens; std::vector decode_input_ids; std::vector hist_token_lens; @@ -139,6 +145,7 @@ struct FlatForwardOperation { shifted_input_ids.insert(shifted_input_ids.end(), prefill->shifted_input_ids.begin(), prefill->shifted_input_ids.end()); extend_prefix_lens.push_back(prefill->extend_prefix_len); + extend_logprob_start_lens.push_back(prefill->extend_logprob_start_len); } else if (auto* decode = std::get_if(&op)) { decode_input_ids.push_back(decode->decode_input_id); hist_token_lens.push_back(decode->hist_token_len); diff --git a/tokenspeed-scheduler/csrc/scheduler/request.cpp b/tokenspeed-scheduler/csrc/scheduler/request.cpp index 6aaa3c55a..0c59992df 100644 --- a/tokenspeed-scheduler/csrc/scheduler/request.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/request.cpp @@ -33,7 +33,8 @@ Request::Request(const RequestSpec& spec, std::int32_t page_size, Role role) page_size_{page_size}, state_{role == Role::kFused ? 
fsm::State{fsm::Submitted{&token_container_, page_size}} : fsm::State{fsm::Bootstrapping{&token_container_, page_size}}}, - storage_info_{spec.rolling_hashes, spec.storage_hit_pages} {} + storage_info_{spec.rolling_hashes, spec.storage_hit_pages}, + logprob_start_len_{spec.logprob_start_len} {} PrefillInfo Request::GetPrefillInfo() const { return std::visit(Overloaded{ diff --git a/tokenspeed-scheduler/csrc/scheduler/request.h b/tokenspeed-scheduler/csrc/scheduler/request.h index 89b770c68..39c3e8f29 100644 --- a/tokenspeed-scheduler/csrc/scheduler/request.h +++ b/tokenspeed-scheduler/csrc/scheduler/request.h @@ -54,6 +54,10 @@ class Request { std::string Id() const { return id_; } + // Full-sequence position from which input/prompt-token logprobs are + // requested. -1 means not requested. + std::int32_t LogprobStartLen() const { return logprob_start_len_; } + // Keep Apply the only non-const function in Request // The wrapper lambda converts any concrete state type returned by event's operator() // into fsm::State, allowing operator() to return specific state types instead of State. @@ -277,6 +281,7 @@ class Request { std::int32_t page_size_; fsm::State state_; StorageInfo storage_info_; + std::int32_t logprob_start_len_{-1}; }; using ConstRequestVector = std::vector; diff --git a/tokenspeed-scheduler/csrc/scheduler/request_spec.h b/tokenspeed-scheduler/csrc/scheduler/request_spec.h index eaf85ebda..7b3aac0e8 100644 --- a/tokenspeed-scheduler/csrc/scheduler/request_spec.h +++ b/tokenspeed-scheduler/csrc/scheduler/request_spec.h @@ -32,6 +32,10 @@ struct RequestSpec { std::vector tokens; std::vector rolling_hashes; std::int32_t storage_hit_pages{0}; + // Prompt position (in full-sequence coordinates) from which to return + // input/prompt-token logprobs. -1 (default) means input logprobs are not + // requested; only sampled-output logprobs (if any) are produced. + std::int32_t logprob_start_len{-1}; }; struct PrefillInfo {