
Add async scheduling for dynamic inference #4558

Draft
lmcafee-nvidia wants to merge 29 commits into NVIDIA:main from lmcafee-nvidia:context-cpu-async-schedule-20260429

Conversation

@lmcafee-nvidia
Contributor

What does this PR do?

Adds async scheduling for dynamic inference so that decode step N+1 can be launched while CPU bookkeeping for step N runs. It also includes the context bookkeeping groundwork needed to keep the GPU fed from CPU-owned metadata.

This PR includes:

  • CPU-side context bookkeeping and unified GPU transfer buffers for dynamic inference metadata.
  • Async scheduling configuration and engine gating predicates.
  • Split launch/bookkeep decode flow for launch-ahead scheduling.
  • Speculative context advance/restore, rejection handling, RNG rollback, and async D2H sample transfer.
  • Hybrid/Mamba SSM state save/restore support for speculative launch rollback.
  • Unit coverage for dynamic context speculative advance and related inference engine behavior.

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Validation

From commit history:

  • tests/unit_tests/inference/engines/test_dynamic_engine.py: 140 passed in repeated local runs while developing the async scheduling chain.
  • tests/unit_tests/inference/contexts/test_dynamic_context.py: speculative advance/restore coverage added; known environment-dependent 4-GPU failures remained consistent with baseline during development.
  • tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py: 12 passed during hybrid/Mamba rollback validation.

lmcafee-nvidia and others added 29 commits April 16, 2026 14:42
This is the "baseline" commit referenced in all future timing discussions
for the context-cpu work.  It branches directly from main and adds NVTX
ranges for the 5 inference loop stages that exist on main:

  - initialize_attention_state
  - forward_pass
  - sampling
  - active_request_mask
  - update_requests

All ranges nest inside the existing "Prefill"/"Decode" range from
dynamic_engine.py, enabling nsys analysis of per-stage timing.

Subsequent commits on this branch will add context-cpu-specific
optimizations AND extend the NVTX range set with 2 transfer stages
(transfer_bookkeeping_to_gpu, transfer_samples_to_cpu) that don't
exist on main.
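
A minimal sketch of the per-stage instrumentation this commit describes, using the raw torch.cuda.nvtx API; the actual code routes through Megatron's gated nvtx helpers (see the later gating commit), and only the five stage names come from this commit:

```python
import torch
from contextlib import contextmanager

@contextmanager
def nvtx_range(name: str):
    # Nested inside the engine's existing "Prefill"/"Decode" range when profiling with nsys.
    torch.cuda.nvtx.range_push(name)
    try:
        yield
    finally:
        torch.cuda.nvtx.range_pop()

# Per-step usage (stage bodies elided):
# with nvtx_range("initialize_attention_state"): ...
# with nvtx_range("forward_pass"): ...
# with nvtx_range("sampling"): ...
# with nvtx_range("active_request_mask"): ...
# with nvtx_range("update_requests"): ...
```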

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Move all per-request and per-token bookkeeping tensors in
DynamicInferenceContext from GPU to pinned CPU memory. Introduce
ContextGPUView as the single GPU interface for forward-pass code:
context.foo is always CPU (source of truth), context.gpu_view.foo
is always GPU (snapshot populated per-step by transfer_bookkeeping_to_gpu).

This eliminates CPU-GPU device mixing by establishing a clear
architectural boundary -- GPU code reads from gpu_view, CPU bookkeeping
reads from context directly. The gpu_view is populated once per step
with non-blocking pinned-memory copies.
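
A minimal sketch of the CPU-source-of-truth / GPU-snapshot split described above, assuming an illustrative field name and size rather than the real ContextGPUView layout:

```python
import torch

class ContextGPUViewSketch:
    """GPU snapshot of CPU-owned bookkeeping, refreshed once per step."""

    def __init__(self, max_tokens: int):
        # CPU is the source of truth; pinned so H2D copies can be non-blocking.
        self.token_to_pos_ids_cpu = torch.zeros(max_tokens, dtype=torch.int64, pin_memory=True)
        # GPU staging tensor that the forward pass reads.
        self.token_to_pos_ids_gpu = torch.zeros(max_tokens, dtype=torch.int64, device="cuda")

    def transfer_bookkeeping_to_gpu(self, active_token_count: int):
        # One async H2D copy per step; the forward pass only reads the [:n] slice.
        self.token_to_pos_ids_gpu[:active_token_count].copy_(
            self.token_to_pos_ids_cpu[:active_token_count], non_blocking=True
        )
```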

Key changes:
- New gpu_view.py with ContextGPUView (6 token-level + 3 request-level
  GPU staging tensors)
- All request/token tensors in dynamic_context.py moved to CPU with
  pin_memory=True
- transfer_bookkeeping_to_gpu() populates gpu_view each step
- text_generation_controller.py reads gpu_view for GPU-phase ops
  (sampling, verification, log-probs)
- Post-rewind code reads CPU context directly (not stale gpu_view)
- mamba_slot_allocator.py fixed for CPU bookkeeping indexing

302 tests pass, 0 failures (90 pre-existing skips).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Defer Mamba state zeroing and prefix cache restore from add_request()
and update_requests() to transfer_bookkeeping_to_gpu(), making both
methods 100% CPU for hybrid models. Mamba compute_and_store_offsets
remains immediate since commit_intermediate_states depends on its
CPU-side state.

Add _transfer_samples_to_cpu() to make the D2H boundary explicit.

302 tests pass (each suite run separately), 0 regressions.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Launch Mamba state zeroing/restore at the start of
initialize_attention_state() instead of transfer_bookkeeping_to_gpu().
This allows the GPU ops to overlap with the CPU work that follows
(batch dimension computation, MHA metadata, token padding).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Split mamba_metadata.update() into compute_cpu_metadata() (CPU, called
in initialize_attention_state) and load_from_cpu() (H2D copies, called
in transfer_bookkeeping_to_gpu). This eliminates GPU kernel launches
for batch indices, cu_seqlens, chunk boundaries, and conv1d metadata.
The intermediate state extraction still uses GPU (_update_intermediate_metadata).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Move _intermediate_offsets_gpu, _intermediate_counts_gpu,
_intermediate_block_ids_gpu, _eos_cache_block_id_gpu from GPU to CPU.
Block IDs and EOS block IDs are pure CPU bookkeeping (consumed by
commit_intermediate_states via .tolist()). Offsets and counts keep a
GPU buffer for _update_intermediate_metadata to consume; the H2D copy
is handled by that method on first use.

Eliminates ~5 GPU writes per add_request and 2 .tolist() D2H syncs
per commit_intermediate_states.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…init split

PR NVIDIA#4225 extracted argument parsing out of initialize_megatron(); call
parse_and_validate_args() separately and invoke initialize_megatron() with
no arguments.
Drop the 5 per-step GPU kernel launches in MHAMetadata.reset()
(query_lengths, cu_query_seq_lengths, cu_kv_seq_lengths, kv_seq_lengths,
block_table). The next update() / load_from_cpu() fully overwrites the
slice of each buffer that the forward pass will read (via state_data[:n]),
so clearing here is redundant paranoia from the old GPU-resident design.

Removes 10 vectorized_elementwise_kernel launches per step
(5 buffers x Graphed + NonGraphed metadata). See
lawrence/reports/20260417-bookkeeping-gpu-ops.md, Section 1.
adjust_batch_dims_for_expert_parallelism() runs on every inference step
to pick a CUDA graph batch dimension consistent across EP ranks. It
performed a torch.distributed.all_reduce(MAX) on a 4-int GPU tensor
sandwiched between an H2D copy (tensor construction) and a D2H copy
(.cpu() to read the result on the host).

Add a sync_all_reduce_max() method to AsyncZMQCommunicator that uses
blocking ZMQ send/recv on the CPU. When the engine has created the EP
ZMQ communicator, it is attached to the context via
DynamicInferenceContext.set_ep_zmq_communicator(), which in turn is
forwarded to match_graph_config() / adjust_batch_dims_for_expert_parallelism().
When present, the 4-int MAX is done on CPU with zero GPU kernels.
The torch.distributed fallback path is kept for standalone / non-engine
call sites.

Removes one NCCL AllReduce kernel plus one H2D and one D2H per step
(~102 us/step in the 2304-step nanov3 trace). Also removes the
stream-ordering barrier that the NCCL kernel introduced on the compute
stream. See lawrence/reports/20260417-bookkeeping-gpu-ops.md, Section 3.1.
The 9 .copy_(non_blocking=True) calls in DynamicInferenceContext
.transfer_bookkeeping_to_gpu() each incur ~15us of cudaMemcpyAsync launch
overhead for ~1us of actual transfer — the NVTX range is 270us of wallclock
with ~6% GPU utilization in the nanov3 trace.

Back all 9 bookkeeping fields (6 per-token + 3 per-request) with one
contiguous uint8 buffer on each side (pinned CPU + device GPU), with each
attribute as a dtype-correct view onto a slice of the buffer. Layout is
shared between DynamicInferenceContext._cpu_bookkeeping_buf and
ContextGPUView._buf; int64 fields are placed first so alignment is
automatic. The per-step transfer is now a single cudaMemcpyAsync of
32*max_tokens + 12*max_requests bytes (~71 KB for the benchmark config).

Per-token fields are aliased with the source-of-truth attributes because
the CPU-side bookkeeping and the GPU forward pass both use the same
[:n_tok] slice. Per-request fields have an extra staging area in the
coalesced buffer, refreshed each step from the persistent CPU tensors
(at [paused_request_count:total_request_count]) into [:n_active], since
the forward pass reads them at [:n_active] on GPU but CPU bookkeeping
keeps paused requests in [0:paused_request_count).
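
A minimal sketch of the coalesced-buffer idea with a hypothetical two-field layout, not the real field set; it shows dtype-correct views carved out of one pinned uint8 buffer so the per-step transfer collapses to a single copy:

```python
import torch

def build_views(buf: torch.Tensor, fields):
    """Carve dtype-correct views out of one flat uint8 buffer.
    fields: list of (name, dtype, numel); listing int64 fields first keeps alignment trivial."""
    views, offset = {}, 0
    for name, dtype, numel in fields:
        itemsize = torch.empty((), dtype=dtype).element_size()
        nbytes = numel * itemsize
        views[name] = buf[offset:offset + nbytes].view(dtype)
        offset += nbytes
    return views

# Illustrative layout only.
fields = [("token_to_pos_ids", torch.int64, 4096),
          ("request_kv_offsets", torch.int32, 512)]
total = sum(n * torch.empty((), dtype=d).element_size() for _, d, n in fields)

cpu_buf = torch.empty(total, dtype=torch.uint8, pin_memory=True)
gpu_buf = torch.empty(total, dtype=torch.uint8, device="cuda")
cpu_views, gpu_views = build_views(cpu_buf, fields), build_views(gpu_buf, fields)

# Per-step transfer: one cudaMemcpyAsync instead of one per field.
gpu_buf.copy_(cpu_buf, non_blocking=True)
```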

Correctness verified: test_reset, test_update_request, test_add_request,
test_initialize_dynamic_context all pass (8 tests × 4 ranks).
MHAMetadata no longer owns private GPU buffers; GraphedMHAMetadata and
NonGraphedMHAMetadata bind to shared views of ContextGPUView._buf (only
one is active per step, so sharing storage is safe). initialize_attention_state
writes the 5 MHA fields directly into pinned slots in _cpu_bookkeeping_buf,
and transfer_bookkeeping_to_gpu's single cudaMemcpyAsync now covers them
along with the existing token/request fields -- eliminating the 5 per-step
mha.load_from_cpu copies. Per-step state_data is rebuilt via set_state_data
using the freshly transferred GPU views plus Python-int max_seqlen scalars.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Extends _cpu_bookkeeping_buf / ContextGPUView._buf with a Mamba section
(9 int32 fields, hybrid-only) that mirrors MambaMetadata's per-step varlen
tensors. MambaMetadata.compute_cpu_metadata() now writes directly into the
bound pinned CPU views instead of allocating ephemeral tensors, and
load_from_cpu() drops all 9 .copy_() calls -- the coalesced H2D in
transfer_bookkeeping_to_gpu() covers the transfer, leaving load_from_cpu()
to just alias state attributes onto the freshly-transferred GPU views and
run the intermediate-extraction GPU computation.

The legacy MambaMetadata.update() path is preserved (it still owns the
standalone *_buffer tensors for unit tests that construct MambaMetadata
without a context); it's unused on the inference path, so the ~40KB of
redundant GPU memory is negligible.

Also wires mamba_chunk_size through _allocate_mamba_states so the
MambaMetadata's internal chunk_size matches the unified-buffer sizing in
ContextGPUView.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The ~550 µs decode-step gap between update_requests and
initialize_attention_state is dominated by step_end_event.synchronize()
(CPU waits for GPU to drain so elapsed_time() can be read) plus the
pre/post context_state dict builds. All of that work feeds only the
print block (dynamic_engine.py:1858) and the W&B metrics block
(dynamic_engine.py:1818), both already gated on
  logging_step_interval > 0 and step_count % logging_step_interval == 0.

Predict that same condition once at the top of async_forward as
`will_log_this_step` and skip the logging-only work on non-logging steps:

  - step_start/end events and elapsed_time (step_time = 0.0)
  - pre_step_context_state print-only fields (keep active_token_count
    and step_count, used by post_process_requests' pre_fwd_* args)
  - kvcache_util_stats computation
  - post_step_context_state dict (and drop the two dead fields
    padded_active_token_count, using_cuda_graph_this_step that no
    consumer reads)
  - the pre/post merge (minimal dict keeps kv_stats=None so the
    metrics-block gate at 1818 stays well-typed)

In post_process_requests, gate the TPOT update on step_time > 0 so
non-logging steps don't pollute request.tpot with zeros -- the metric
becomes a sparse sample aligned with the same cadence as logging.
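
A minimal sketch of the gating, assuming a simplified step function; will_log_this_step mirrors the commit text, everything around it is illustrative:

```python
import torch

def async_forward_step(step_count: int, logging_step_interval: int) -> float:
    # Decide once per step whether this step's logging-only work is needed.
    will_log_this_step = (
        logging_step_interval > 0 and step_count % logging_step_interval == 0
    )

    if will_log_this_step:
        step_start = torch.cuda.Event(enable_timing=True)
        step_end = torch.cuda.Event(enable_timing=True)
        step_start.record()

    # ... launch forward / sampling here ...

    if will_log_this_step:
        step_end.record()
        step_end.synchronize()  # CPU drains the GPU only on logging steps
        step_time = step_start.elapsed_time(step_end) / 1e3  # seconds
    else:
        step_time = 0.0  # non-logging steps skip the event sync entirely
    return step_time
```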

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
_dynamic_step_context_bookkeeping already produces sampled_tokens_cpu
via _transfer_samples_to_cpu(), but the outer dict built by
async_generate_output_tokens_dynamic_batch was handing out a fresh clone
of the GPU _sampled_tokens_cuda buffer. The engine's post_process_requests
then called sample.tolist() on that GPU tensor, forcing a D2H sync --
pure overhead inside the update_requests -> initialize_attention_state
critical-path gap.

Propagate the already-allocated CPU tensor instead: add "sample":
sampled_tokens_cpu to the _dynamic_step_context_bookkeeping return dict,
drop the outer GPU clone, and keep skip_bookkeeping=True behaving by
doing a one-shot .cpu() on that path. The CPU tensor is independent
storage (fresh .cpu() allocation, not a view) and isn't mutated by the
step -- update_requests only touches new_sample_copy, a separate clone.
Net: sample.tolist() becomes pure CPU-to-list, no sync.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
PR NVIDIA#3642 switched the engine-side NVTX labels from the direct
torch.cuda.nvtx API onto Megatron's gated nvtx_range_push helper. Training
flips the switch under --profile --nvtx-ranges; the inference server never
did, so bookkeeping/Decode/_ep_establish_consensus/etc. became silent
no-ops and the inter-step gap is blank in nsys. Mirror training's gating
in the server entry so the same flags enable the labels for inference.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
For async scheduling: after sampling on a pure-decode, non-speculative
step, copy the sampled tokens directly into the GPU buffer that the next
forward's CUDA graph reads as input_ids. The existing CPU+H2D path still
runs and overwrites this slot in transfer_bookkeeping_to_gpu, so behavior
is byte-identical to baseline. This commit only establishes the
GPU-resident write pattern; later commits use it to skip the CPU round
trip on the speculation hot path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
predict_decode_blocks_needed(advance=1): returns total new KV blocks the
active batch would need if every active request advanced by `advance`
decode tokens. Uses request_kv_length_offsets + request_query_lengths to
detect block-boundary crossings.

can_speculate_decode_step(advance=1): predicate combining is_decode_only,
predict_decode_blocks_needed, and the allocator's is_memory_available.
The engine in C3 calls this each step to decide whether to issue a
speculative forward or fall back to non-speculative.
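
A minimal sketch of the prediction, assuming per-request KV lengths and a fixed block size; the real methods operate on request_kv_length_offsets / request_query_lengths and consult the allocator's is_memory_available:

```python
import torch

def predict_decode_blocks_needed(kv_lengths: torch.Tensor, block_size: int, advance: int = 1) -> int:
    """Count new KV blocks the active batch would need if every request advanced by `advance` tokens."""
    blocks_now = (kv_lengths + block_size - 1) // block_size
    blocks_after = (kv_lengths + advance + block_size - 1) // block_size
    return int((blocks_after - blocks_now).sum().item())

def can_speculate_decode_step(kv_lengths: torch.Tensor, block_size: int,
                              blocks_free: int, is_decode_only: bool) -> bool:
    """Speculate only on pure-decode steps whose boundary crossings the allocator can cover."""
    return is_decode_only and predict_decode_blocks_needed(kv_lengths, block_size) <= blocks_free
```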

Standalone-correct: methods are not called from any production path yet;
adding them does not change behavior. C3 will wire them into the
async-scheduling launcher.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Configuration plumbing for the async-scheduling speculative chain:

- arguments.py: --inference-dynamic-batching-async-scheduling (default True)
  and --inference-dynamic-batching-finished-sync-period (default 32).
- InferenceConfig: async_scheduling and finished_sync_period fields with
  the same defaults.
- get_dynamic_inference_config: thread args through to InferenceConfig.
- DynamicInferenceEngine: read both fields off the config.
- DynamicInferenceEngine._async_scheduling_active(): predicate that
  combines the config flag with context.can_speculate_decode_step()
  from C2. Refined in C4 to also handle rejection events (add/pause/
  evict) and the periodic finished-pending sync.

Standalone-correct: the predicate is not yet called from any production
loop, and the config flag has no effect on behavior. The launch-ahead
refactor that consumes this predicate is split into a follow-up commit
to keep the engine-loop changes reviewable.
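
A minimal argparse sketch of the two flags, assuming a generic argument group; the real arguments.py wiring may use different group names and boolean conventions:

```python
import argparse

def add_async_scheduling_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    group = parser.add_argument_group("inference dynamic batching")  # illustrative group name
    group.add_argument(
        "--inference-dynamic-batching-async-scheduling",
        action=argparse.BooleanOptionalAction, default=True,
        help="Launch decode step N+1 while CPU bookkeeping for step N runs.")
    group.add_argument(
        "--inference-dynamic-batching-finished-sync-period", type=int, default=32,
        help="Sync finished requests every N steps during speculative chains.")
    return parser
```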

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Structural refactor: async_step now routes based on
_async_scheduling_active() between two private methods:

- _async_step_serial: today's behavior (await forward, await bookkeep).
- _async_step_speculative: stub that mirrors the serial path.

C4 fills in the speculative path with the launch-ahead logic and
rejection handling. Splitting the routing into its own commit keeps
the engine-loop changes reviewable and bisectable.

No behavior change: both paths are identical, so this commit is
byte-equivalent to the prior async_step implementation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Extends _async_scheduling_active with the two backstops that protect the
speculative chain from invalid state:

- _rejection_pending: set by CPU bookkeeping when an add/pause/evict
  event is detected. Read by _async_scheduling_active to drop out of
  speculation; cleared via _clear_rejection() after the engine drains
  in-flight forwards and applies the corrected state.
- _finished_pending_count + _finished_sync_due(): periodic sync
  triggered every finished_sync_period steps when finished requests
  are accumulated, so finished slots can't keep consuming compute
  indefinitely in the speculative chain.

New helpers _signal_rejection() and _clear_rejection() expose the
rejection lifecycle to the bookkeeping path.
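
A minimal sketch of the rejection / periodic-sync lifecycle, using the helper names from the commit but a simplified surrounding class:

```python
class AsyncSchedulingGateSketch:
    """Tracks the two backstops that drop the engine out of speculation."""

    def __init__(self, finished_sync_period: int):
        self._rejection_pending = False
        self._finished_pending_count = 0
        self._finished_sync_period = finished_sync_period

    def _signal_rejection(self):
        # Called from CPU bookkeeping when an add/pause/evict event is detected.
        self._rejection_pending = True

    def _clear_rejection(self):
        # Called after in-flight forwards are drained and corrected state is applied.
        self._rejection_pending = False

    def _finished_sync_due(self) -> bool:
        # Periodic sync so finished slots cannot keep consuming compute indefinitely.
        return (self._finished_pending_count > 0
                and self._finished_pending_count % self._finished_sync_period == 0)

    def _async_scheduling_active(self, config_flag: bool, can_speculate: bool) -> bool:
        return (config_flag and can_speculate
                and not self._rejection_pending
                and not self._finished_sync_due())
```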

Note: this commit adds the *framework* for rejection handling. The
launch-ahead engine-loop refactor (issuing forward N+1's launch during
step N's forward) and the drain-and-rewind logic are deferred to a
follow-up commit. The speculative-path stub from C3.5 is unchanged
here -- it routes through _async_scheduling_active correctly, but
both serial and speculative paths still execute forward then bookkeep
in sequence.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Refactor async_generate_output_tokens_dynamic_batch into two methods:
_launch_decode_step queues all GPU work (H2D + forward + sample) and
_bookkeep_decode_step performs the post-sample CPU work (D2H +
update_requests). The public method becomes a thin orchestrator with
identical behavior.

This is the structural prerequisite for C5.2: async scheduling needs
to invoke the launch path for step N+1 before running the bookkeep
path for step N, so that the GPU stays continuously busy while CPU
bookkeeping runs in parallel.
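
A minimal sketch of the split, with illustrative signatures and placeholder forward/sample helpers; the public method stays a thin orchestrator over the two halves:

```python
class ControllerSketch:
    def _launch_decode_step(self, context):
        # Queue all GPU work for this step: H2D bookkeeping transfer, forward, sample.
        context.transfer_bookkeeping_to_gpu()
        logits = self._forward(context)       # placeholder helper
        return self._sample(logits)           # placeholder helper

    def _bookkeep_decode_step(self, context, sampled_tokens_cuda):
        # Post-sample CPU work: D2H of the sample, then request bookkeeping.
        sampled_tokens_cpu = sampled_tokens_cuda.cpu()
        context.update_requests(sampled_tokens_cpu)
        return sampled_tokens_cpu

    def async_generate_output_tokens_dynamic_batch(self, context):
        # Serial behavior is unchanged; async scheduling later interleaves the two halves.
        sampled = self._launch_decode_step(context)
        return self._bookkeep_decode_step(context, sampled)
```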

Behavior change: none. All existing tests pass byte-identical.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add the context-side primitives the engine needs to pre-launch step N+1
ahead of step N's bookkeep:

- speculatively_advance_for_next_decode_step: bumps request_kv_length_offsets
  and pinned token-level fields to what update_requests would produce
  after a no-event pure-decode step. The advance is reversible.
- restore_after_speculative_advance: reverses the persistent-tensor part
  of the advance. Pinned fields are intentionally left advanced because
  step N's update_requests will overwrite them with canonical values.
- Tightened can_speculate_decode_step: returns False whenever a request
  would cross a block boundary or num_speculative_tokens > 0, deferring
  the engine to its serial path in those cases. The speculative advance
  cannot replay block-allocation safely.

Engine state: _pending_launch_state field added (initialized to None);
will hold step N+1's launch_state across iterations once the
launch-ahead loop is wired in C5.3.
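
A minimal sketch of the reversible advance over one illustrative per-request tensor; the real methods also bump pinned token-level fields, which are deliberately left advanced:

```python
import torch

class SpeculativeAdvanceSketch:
    def __init__(self, num_requests: int):
        self.request_kv_length_offsets = torch.zeros(num_requests, dtype=torch.int32)
        self._saved_kv_length_offsets = None

    def speculatively_advance_for_next_decode_step(self, advance: int = 1):
        # Save the persistent-tensor state, then bump it to what update_requests
        # would produce after a no-event pure-decode step.
        self._saved_kv_length_offsets = self.request_kv_length_offsets.clone()
        self.request_kv_length_offsets += advance

    def restore_after_speculative_advance(self):
        # Reverse only the persistent-tensor part; pinned fields self-heal when
        # step N's update_requests overwrites them with canonical values.
        self.request_kv_length_offsets.copy_(self._saved_kv_length_offsets)
        self._saved_kv_length_offsets = None
```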

Tests:
- test_speculatively_advance_and_restore_roundtrip — verifies kv_offsets,
  block offsets, and pinned token_to_pos_ids advance and restore correctly.
- test_can_speculate_decode_step_rejects_boundary_crossing — verifies
  the gating predicate fires regardless of allocator availability.

No behavior change yet: _async_step_speculative remains a stub (still
identical to serial). C5.3 wires the launch-ahead loop and the
synchronization needed to read sample_N before sample_{N+1} clobbers
the GPU buffer.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Lays in the engine and controller plumbing required by the launch-ahead
loop without yet wiring it. Behavior unchanged: _async_step_speculative
remains a stub identical to the serial path.

Engine changes:
- _pre_fetched_sample_buf: pinned CPU int64 buffer of size max_requests.
  The launch-ahead loop will D2H sample_N here once before queueing the
  speculative step N+1 launch (otherwise sample_{N+1} would clobber the
  shared _sampled_tokens_cuda buffer before update_requests reads it).
- _d2h_done_event / _h2d_done_event: pre-allocated CUDA events for
  gating CPU work on stream completion.
- _drain_pending_launch: synchronizes the stream and discards a stale
  speculative pre-launch on rejection. Wired into _async_step_serial.
- _async_step_serial now also calls _clear_rejection at entry so
  rejected speculative chains hand back a clean state to the serial
  path before async_forward runs.

Controller changes:
- _bookkeep_decode_step accepts sample_cpu_override; threads through to
  _dynamic_step_context_bookkeeping. When provided, the override skips
  the in-bookkeep D2H of _sampled_tokens_cuda. The launch-ahead loop
  will use this so step N's bookkeep reads the sample_N value the engine
  D2H'd before queueing step N+1's sample.

A naive launch-ahead wiring (sync sample_N, mirror it into pinned
token_to_input_ids, speculatively advance, pre-launch step N+1, run
update_requests in parallel) was prototyped on this branch and produced
divergent tokens vs. the serial baseline starting at the first decode
iteration following a mid-chain rejection. Diagnosis pointed at state
pollution from the discarded pre-launch (KV writes / attention metadata)
that the persistent-tensor restore alone does not reverse. Resolving it
cleanly requires either a shadow bookkeeping buffer or a more thorough
rollback path; deferred so the infrastructure can land separately and
keep this commit reviewable.

Verification: full test_dynamic_engine.py suite (140 passed, 94 skipped)
and existing test_dynamic_context.py speculative-advance round-trip
tests pass. No regressions vs. parent (3ebe644).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Wires up the actual launch-ahead engine loop. On pure-decode chains the
GPU stream is kept continuously busy with H2D -> forward -> sample ->
H2D -> forward -> sample, while CPU update_requests for step N runs in
parallel with GPU step N+1.
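
A minimal sketch of one steady-state iteration, compressing the engine's real bookkeeping into the essential ordering; helper names such as mirror_sample_into_input_ids and the launch_state object are illustrative:

```python
def speculative_decode_iteration(engine, context, controller):
    """One launch-ahead iteration (sketch): CPU bookkeeping for step N overlaps
    with the GPU running step N+1."""
    # 1. Acquire step N's launch state: consume the prior pre-launch, or launch fresh.
    launch_state = engine._pending_launch_state or controller._launch_decode_step(context)
    engine._pending_launch_state = None

    # 2. Blocking D2H of sample_N; the GPU has been busy with forward_N -> sample_N,
    #    so this wait introduces no GPU-side gap.
    sample_n_cpu = launch_state.sampled_tokens_cuda.cpu()

    if context.can_speculate_decode_step():
        # 3. Snapshot the sampling RNG so a discarded pre-launch can be rolled back.
        engine._spec_rng_snapshot = controller.sampling_rng.get_state()
        # 4. Feed sample_N to step N+1's input, then queue its H2D + forward + sample.
        context.mirror_sample_into_input_ids(sample_n_cpu)   # illustrative helper
        context.speculatively_advance_for_next_decode_step()
        engine._pending_launch_state = controller._launch_decode_step(context)
        context.restore_after_speculative_advance()

    # 5. Step N's CPU bookkeeping runs while the GPU executes step N+1.
    controller._bookkeep_decode_step(context, sample_cpu_override=sample_n_cpu)
```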

Engine (`_async_step_speculative`):
- Acquire step N's launch state — consume the prior iteration's pre-
  launch, or launch fresh on warmup / post-rejection.
- Synchronous .cpu() of _sampled_tokens_cuda blocks until sample_N
  completes; GPU has been busy with forward_N -> sample_N during the
  wait so this introduces no GPU-side gap.
- Snapshot `controller.sampling_rng` state before issuing the speculative
  pre-launch's sample so the RNG can be rolled back if the pre-launch
  is later discarded.
- Mirror sample_N into pinned `token_to_input_ids` so step N+1's H2D
  copies the correct input to gpu_view.
- `speculatively_advance_for_next_decode_step` then `_launch_decode_step`
  queues H2D, forward, and sample for step N+1 on the stream. Restore
  cancels the persistent-tensor advance afterwards (pinned token-level
  fields are intentionally left at the speculative values; step N's
  update_requests overwrites them with canonical values during bookkeep).
- `_bookkeep_decode_step` consumes `sample_cpu_override` so update_requests
  reads the pre-fetched sample_N rather than racing the next sample.
- After bookkeep, finished/paused/evicted detection signals a rejection
  so the next iteration drops to the serial path.

Engine (`_async_step_serial`):
- Drain a stale pre-launch (sync stream, restore RNG, clear rejection)
  only when one is actually pending — keeps the serial path byte-
  identical to the pre-async-scheduling code path when no rejection
  occurred.

Engine (`_async_scheduling_active`):
- Adds `len(waiting_request_ids) > 0` gate. New requests change batch
  composition at the next iteration; the speculative pre-launch was
  computed for the old composition and would be invalidated. Drop to
  serial so schedule_waiting_requests can integrate the new requests
  before launch.

Context (`can_speculate_decode_step`):
- Gates speculation off for hybrid (Mamba) models. Mamba SSM forward
  advances per-request `mamba_conv_states` / `mamba_ssm_states` in
  place; speculative rollback would require saving/restoring those
  GPU buffers per active request. Implementing that is left for a
  follow-up.

State-mutation audit findings driving the design:
- Sampling: `torch.multinomial(generator=self.sampling_rng)` advances
  the CUDA generator's state. A discarded pre-launch's sample leaves
  the RNG advanced past where the baseline serial path would be at
  the rejection step, producing a different sample on re-launch. Fixed
  by snapshotting the state before the speculative sample and
  restoring it on drain.
- Mamba SSM: `mamba_conv_states` and `mamba_ssm_states` are mutated
  in place by forward. Re-running forward on the same input advances
  the state machine, producing a different output. Gated off for now.
- KV cache (attention only): writes are deterministic given same Q/K/V
  inputs at the same position; idempotent under re-run.
- Pinned bookkeeping buffer: `update_requests` overwrites the same
  fields the speculative advance touched, so pinned state self-heals
  on each step boundary.

Verification:
- 3x runs of `tests/unit_tests/inference/engines/test_dynamic_engine.py`:
  140 passed each time, no flakiness.
- The 3 mamba `test_chunked_prefill_cuda_graphs` failures observed
  during prototyping are pre-existing flakiness (verified by 3x runs
  on the parent commit: 1 of 3 also showed the same 3 failures, the
  other 2 were clean).
- `tests/unit_tests/inference/contexts/attention_metadata/
  test_mamba_metadata.py`: 12 passed.
- `tests/unit_tests/inference/contexts/test_dynamic_context.py`:
  same 21 baseline 4-GPU env failures + 53 passed (including the
  C5.2 speculative-advance round-trip and gating tests).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Extend the speculative launch-ahead loop to hybrid (Mamba) models. The
Mamba SSM forward advances mamba_conv_states and mamba_ssm_states in
place per active-request slot; unlike attention's KV writes it is not
idempotent under re-run. This commit snapshots the active slices around
the speculative pre-launch and rolls them back on rejection.
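
A minimal sketch of the shadow snapshot/rollback, assuming (n_layers, max_requests, ...) state pools; the shapes and active-slot indexing are illustrative:

```python
import torch

class MambaSpecShadowSketch:
    """Illustrative: state pools are (n_layers, max_requests, ...)."""

    def __init__(self, conv_states: torch.Tensor, ssm_states: torch.Tensor):
        self.conv_states, self.ssm_states = conv_states, ssm_states
        # Shadow pools allocated once at construction, same shape as the live pools.
        self._conv_shadow = torch.empty_like(conv_states)
        self._ssm_shadow = torch.empty_like(ssm_states)
        self._saved_slots = None

    def save_mamba_state_for_speculation(self, active_slots: torch.Tensor):
        # Copy the active request slices into the shadow before the speculative forward.
        self._conv_shadow[:, active_slots] = self.conv_states[:, active_slots]
        self._ssm_shadow[:, active_slots] = self.ssm_states[:, active_slots]
        self._saved_slots = active_slots

    def restore_mamba_state_for_speculation(self):
        # Called after the stream is synchronized, when the pre-launch is discarded.
        if self._saved_slots is not None:
            slots = self._saved_slots
            self.conv_states[:, slots] = self._conv_shadow[:, slots]
            self.ssm_states[:, slots] = self._ssm_shadow[:, slots]
        self._saved_slots = None

    def drop_mamba_state_speculation(self):
        # Pre-launch consumed normally: discard the shadow without restoring.
        self._saved_slots = None
```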

Context changes (`dynamic_context.py`):
- Allocate `_spec_mamba_conv_shadow` and `_spec_mamba_ssm_shadow` (same
  shape as the live state pools) at construction time for hybrid models.
- Add `save_mamba_state_for_speculation`: scatter-copies the active
  request slices from live mamba state into the shadow, queued on the
  stream before the speculative forward.
- Add `restore_mamba_state_for_speculation`: scatters the saved slices
  back; called from the engine's drain helper after the stream has
  been synchronized.
- Add `drop_mamba_state_speculation`: discards the shadow without
  restoring; called when the pre-launch is consumed normally.
- Remove the `is_hybrid_model: return False` gate in
  `can_speculate_decode_step` — speculation now applies to hybrid
  models too.

Engine changes (`dynamic_engine.py`):
- `_async_step_speculative` calls `save_mamba_state_for_speculation`
  alongside the RNG snapshot before the speculative `_launch_decode_step`.
- On consume (pending pre-launch becomes the next iteration's bookkeep)
  we `drop_mamba_state_speculation` along with clearing the RNG snapshot.
- `_drain_pending_launch` calls `restore_mamba_state_for_speculation`
  after the stream synchronize so the serial re-launch's forward reads
  pre-spec state.

Memory cost: doubles the mamba state pool footprint (one shadow each
for conv and ssm). Acceptable for this engine's typical 70+GB GPU
budget; can be optimized later to a packed buffer sized for active
requests if needed.

Verification: tests/unit_tests/inference/engines/test_dynamic_engine.py
passes 140/140 with C5.5 active (mamba speculation no longer gated off).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
C5.5 used PyTorch fancy indexing to save just the active mamba slot
slices into the shadow buffer. The C5.5 nanov3 nsys trace showed
Decode iteration time regressing from 7828 µs (C5.4 baseline, mamba
gated off) to 9037 µs — a 1209 µs / 15% increase per iteration. With
speculation firing in roughly 96 of 1535 decode iterations (~6%), that
amortizes to ~19 ms per save+restore call: PyTorch's scatter access
pattern across the (n_layers, max_requests, ...) state pool is far
slower than a bandwidth-bound contiguous copy of the same buffer.

Switch save and restore to use ``tensor.copy_(...)`` over the full
shadow / live state pools. Inactive slot bytes get copied too, but
that is harmless: any future use of a free slot is preceded by a
zeroing pass via ``_execute_pending_mamba_ops`` before the slot is
actually read.
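
A minimal sketch contrasting the two approaches, with illustrative function names; the pools are assumed to have request slots on dim 1:

```python
import torch

def save_mamba_state(live: torch.Tensor, shadow: torch.Tensor) -> None:
    # Whole-pool contiguous copy (bandwidth-bound); inactive slot bytes are harmless
    # because free slots are zeroed again before they are next read.
    shadow.copy_(live, non_blocking=True)

def restore_mamba_state(live: torch.Tensor, shadow: torch.Tensor) -> None:
    live.copy_(shadow, non_blocking=True)

# vs. the C5.5 pattern: scattered access across (n_layers, max_requests, ...) is
# far slower than the contiguous copy above.
def save_active_slots(live: torch.Tensor, shadow: torch.Tensor, active_slots: torch.Tensor) -> None:
    shadow[:, active_slots] = live[:, active_slots]
```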

Memory and behavior are unchanged from C5.5 — same shadow allocation,
same gating, same engine integration. The save/drop/restore state
becomes a single bool (``_spec_mamba_saved``) instead of a saved
index tensor.

Verification: 140/140 tests/unit_tests/inference/engines/
test_dynamic_engine.py pass. Throughput regression on the nanov3
benchmark to be measured in the follow-up run.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The launch-ahead loop that landed in C5.4-C5.6 was correct under add/pause/
evict but signaled a rejection on every finished request. With staggered
request finishes (e.g., 100 prompts arriving over a 1 s window with
identical generation length), this dropped speculation onto the serial
fallback once per finish — and on the second-tier nanov3 trace this
meant "speculate / serial / speculate / serial" rather than a sustained
chain. The original plan's design called for finishes to flow through
without rejection (the finished slot keeps consuming compute until a
periodic sync); my implementation instead rejected on every finish,
which negated the benefit.

The cleaner fix here: detect "would any request finish this step" on
CPU before issuing the speculative pre-launch, and skip the pre-launch
when the answer is yes. The bookkeep then compacts the batch as before
(no deferred-finish complexity needed in update_requests), and the
*next* iteration starts fresh. Cost: one missed pre-launch per
finish-iter (vs. one wasted forward + one serial drain in the old
design).
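
A minimal sketch of the CPU-side finish check, with an assumed signature; the real helper mirrors the active-request-mask computation inside _dynamic_step_context_bookkeeping, including the optional stop-word callback omitted here:

```python
import torch

def _would_any_finish(sample_cpu: torch.Tensor,
                      generated_lengths_cpu: torch.Tensor,
                      active_count: int,
                      termination_id: int,
                      max_sequence_length: int) -> bool:
    """True if any active request would terminate this step, so the speculative
    pre-launch should be skipped."""
    sample = sample_cpu[:active_count]
    lengths = generated_lengths_cpu[:active_count]
    finished = (sample == termination_id) | (lengths >= max_sequence_length)
    return bool(finished.any().item())
```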

Engine changes:
- Add `_would_any_finish(sample_cpu, active_count)`: mirrors the
  active-request-mask computation done inside
  `_dynamic_step_context_bookkeeping` (sample == termination_id, length
  >= max_sequence_length, optional stop-word callback).
- `_async_step_speculative` gates the pre-launch on
  `not _would_any_finish(...)` in addition to
  `can_speculate_decode_step()`. Finishes detected here skip the
  pre-launch entirely; bookkeep proceeds as normal.
- After `schedule_waiting_requests`, compare `total_request_count`
  pre/post; if a new request was added this iter, drain the now-stale
  pending pre-launch and continue without it. This catches the
  add-rejection case the old `len(waiting_request_ids) > 0` gate was
  guarding (over-conservatively).
- Drop `_signal_rejection()` on finishes — finishes are now caught
  upfront. Pause/evict still signal rejection (those happen *inside*
  update_requests during bookkeep, after the pre-launch is already
  on the stream).
- Remove `_finished_pending_count` and `_finished_sync_due()` — the
  periodic-sync backstop is no longer needed since finishes don't
  accumulate. The `finished_sync_period` config field is kept for
  backward compat in argparse / config but the engine ignores it.

Verification on the 12b synthetic benchmark (100 prompts, 32 generated
tokens, --no-load-checkpoint) shows the speculation chain sustaining
across ~26 consecutive decode iterations rather than firing only twice
in 60. Throughput on this workload is unchanged because T_fwd dominates
(210 ms vs ~2 ms gap), but the trace confirms the chain is now correct.
nanov3 benchmark numbers to follow.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…keep rejection

Replace the prior `_would_any_finish` pre-launch gate with a finish detection
in the post-bookkeep rejection check. The gate cost a forward+sample on every
iter just before a request would terminate; the new path issues the pre-launch
unconditionally, lets bookkeep observe the finish, and signals rejection so
the next iter's serial path drains the wasted forward (one wasted forward per
finish event, far cheaper than serial fallback every iter).

Stream order per iter (steady state):
  sample_N → D2H_N → H2D_{N+1} → forward_{N+1} → sample_{N+1}

The async D2H of sample_N is queued on the stream right after sample_N (via
non_blocking copy_ into pinned token_to_input_ids), with a CUDA event
recorded after. The pre-launch's H2D for step N+1 is then queued behind the
D2H, so it sees the correct sample_N value when reading the pinned slot.
Bookkeep clones out of the pinned view (so subsequent iters' D2H doesn't
mutate this iter's sample) and waits on the D2H event before reading.
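
A minimal sketch of the event ordering, assuming a pre-allocated torch.cuda.Event and illustrative buffer names; the pre-launch's H2D for step N+1 is queued after the D2H on the same stream, so it observes the correct sample_N value in the pinned slot:

```python
import torch

def queue_sample_handoff(sampled_tokens_cuda: torch.Tensor,
                         token_to_input_ids_pinned: torch.Tensor,
                         d2h_done_event: torch.cuda.Event,
                         n_active: int) -> None:
    # sample_N -> D2H_N: async copy into the pinned slot, then record an event.
    token_to_input_ids_pinned[:n_active].copy_(sampled_tokens_cuda[:n_active], non_blocking=True)
    d2h_done_event.record()

def read_sample_for_bookkeep(token_to_input_ids_pinned: torch.Tensor,
                             d2h_done_event: torch.cuda.Event,
                             n_active: int) -> torch.Tensor:
    # Bookkeep waits for the D2H, then clones out of the pinned view so later
    # iterations' D2H copies cannot mutate this iteration's sample.
    d2h_done_event.synchronize()
    return token_to_input_ids_pinned[:n_active].clone()
```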

Also: add LOAD_CHECKPOINT toggle to gpt_dynamic_inference_{357m,12b}.sh and
guard load_checkpoint() in get_model_for_inference behind args.load is not
None, so the inference scripts can run without weights for perf debugging.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…_requests

When _async_step_speculative was entered (outer can_speculate=True at gate)
but the inner can_speculate check failed (e.g., schedule_waiting_requests
added a prefill request making is_decode_only=False), no pre-launch fired
and _pending_launch_state stayed None. If bookkeep then detected a finish
event, _signal_rejection set _rejection_pending=True. The next iter's
serial path only drains+clears when _pending_launch_state is not None, so
the rejection flag stuck and the engine ran the serial path forever.

In production (run_full_benchmark.sh nano-v3, 32k input × 64 reqs × 512
output via chunked prefill + continuous batching), this caused the
speculative path to fire on only 1% of decode iters (16/1551) — the rest
were silently serial. Hence the trace showing init_attn_state and
update_requests fully in the GPU-idle gap with zero overlap.

Two fixes:
  1. _async_step_speculative: only signal_rejection when a pre-launch is
     actually in flight. No pre-launch → nothing to drain → don't poison
     the flag.
  2. _async_step_serial: always clear_rejection on entry (even without a
     pending launch to drain). Belt-and-suspenders — once we're back in
     serial, the flag has served its purpose.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
copy-pr-bot Bot commented Apr 30, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.
