feat(memory-saver): optional CPU staging for round-trip weight preservation #275
Draft
qywu wants to merge 1 commit into
…vation

#272's /release_memory_occupation truthfully releases GPU memory, but the contents are gone: torch_memory_saver.pause() preserves virtual addresses only, not data. /resume_memory_occupation gets zeroed pages back, so any caller that wants the model to work after a release/resume cycle has to re-read the checkpoint from disk. For RLHF train↔serve handoff and similar "pause inference, return GPU, resume inference" flows, that's tens of GiB of disk I/O on every cycle.

Adds an opt-in CPU staging step:

    POST /release_memory_occupation {"stage_to_cpu": true}

- Before saver.pause(), copy every param and buffer of the (target and draft) model into a pre-allocated pinned host buffer.
- After saver.resume(), copy them back.

On Qwen2-72B bf16 (~145 GiB), the staging round-trip takes ~12 s over PCIe Gen4 x16 vs ~60-180 s to re-read the same weights from network storage.

Changes:
- New: tokenspeed/runtime/engine/memory_occupation_manager.py with the MemoryOccupationManager class: pin/unpin host buffers, drive saver.pause()/resume(), and reuse buffers across cycles.
- io_struct.py: add stage_to_cpu: bool = False to ReleaseMemoryOccupationReqInput.
- request_handler.py: replace direct memory_saver calls with the manager so the staging path runs in the scheduler process where the model_runner is live.
- event_loop.py: construct the manager with the real target/draft model_runner and pass it into RequestHandler.
- http_server.py: parse stage_to_cpu from the JSON body.

Trade-offs:
- Host RAM hold = ~sizeof(model) for the duration of the release.
- Staging adds a few seconds to the release path; without it /release completes in ~1 s.
- Does not stage KV cache or request pools (those are scratch; the engine flushes them before any reasonable use of this endpoint).

Stacked on top of #272.

Signed-off-by: Qingyang Wu <willqywu@gmail.com>
This was referenced May 27, 2026
Stacks on #272.
Summary
#272's `/release_memory_occupation` truthfully releases GPU memory, but the contents are gone: `torch_memory_saver.pause()` preserves virtual addresses only, not data. `/resume_memory_occupation` gets zeroed pages back, so any caller that wants the model to work after a release/resume cycle has to re-read the checkpoint from disk. For RLHF train↔serve handoff and similar "pause inference, return GPU, resume inference" flows, that's tens of GiB of disk I/O on every cycle.

This PR adds an opt-in CPU staging step:

- Before `saver.pause()`, copy every param and buffer of the (target and draft) model into a pre-allocated pinned host buffer.
- After `saver.resume()`, copy them back into the original GPU virtual addresses (so any CUDAGraph captures keep their argument pointers valid).

Changes
- `tokenspeed/runtime/engine/memory_occupation_manager.py`: `MemoryOccupationManager` class. Owns the pinned host buffers, drives `saver.pause()`/`resume()`, and reuses staging buffers across cycles so a steady-state RLHF loop doesn't reallocate ~145 GiB of pinned host RAM every iteration.
- `io_struct.py`: add `stage_to_cpu: bool = False` to `ReleaseMemoryOccupationReqInput`.
- `request_handler.py`: replace direct `memory_saver` calls with the manager so the staging path runs in the scheduler process where `model_runner.model` is live.
- `event_loop.py`: construct the manager with the real target + draft `model_runner` and pass it into `RequestHandler`.
- `http_server.py`: parse `stage_to_cpu` from the JSON body.

Verification on H100: round-trip correctness PASS
Measured on Qwen2-1.5B-Instruct, `gpu_memory_utilization=0.5`, `--attention-backend=triton`, `--enforce-eager`:

- `release(stage_to_cpu=true)`
- `resume()` (restores from CPU)

Generation parity:

- Before release: `"The capital of France is"` → `" ______.\nA. Paris\nB."`
- After resume: `"The capital of France is"` → `" ______.\nA. Paris\nB."`

`Outputs match: True`, so round-trip preservation works end-to-end.

For Qwen2-1.5B (~3 GiB bf16 weights), the 1.7 s release latency is dominated by the DtoH memcpy of weights into pinned host RAM. Extrapolating to Qwen2-72B (~145 GiB) at PCIe Gen4 x16 (~25 GiB/s pinned): release ≈ 6-8 s; resume ≈ 6-8 s. Compared to ~60-180 s to re-read a 145 GiB checkpoint from network storage, that's a ~10-20× speedup for the train↔serve handoff path.
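The extrapolation above is a size-over-bandwidth estimate, and a short calculation makes the arithmetic easy to rerun with other numbers. This is only a sketch: the function and variable names are illustrative, the inputs are the approximate figures quoted in this description, and real copies carry overhead beyond the ideal transfer time (which is why the quoted range is 6-8 s and not the ideal figure).

```python
def transfer_seconds(size_gib: float, bandwidth_gib_s: float) -> float:
    """Ideal time to move size_gib at a sustained bandwidth (no overhead)."""
    return size_gib / bandwidth_gib_s


# Approximate figures quoted in this PR description.
MODEL_GIB = 145.0            # Qwen2-72B bf16 weights
PINNED_GIB_S = 25.0          # PCIe Gen4 x16 with pinned host memory
REREAD_RANGE_S = (60.0, 180.0)  # re-reading the checkpoint from network storage

one_way_s = transfer_seconds(MODEL_GIB, PINNED_GIB_S)  # one DtoH or HtoD pass
round_trip_s = 2 * one_way_s                           # release + resume

print(f"one-way copy:  {one_way_s:.1f} s")
print(f"round trip:    {round_trip_s:.1f} s")
for reread_s in REREAD_RANGE_S:
    # Ratio against the resume-side copy only, and against the full round trip.
    print(
        f"re-read {reread_s:.0f} s -> "
        f"{reread_s / one_way_s:.1f}x vs one-way, "
        f"{reread_s / round_trip_s:.1f}x vs round trip"
    )
```

The ratio depends on whether the release-side copy is counted, so the speedup is best read as an order-of-magnitude figure.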
Trade-offs
- KV cache and `req_to_token_pool` are scratch and are not staged: the engine flushes outstanding requests before any reasonable use of this endpoint, so there's nothing in them worth preserving.
- The `tags` field is plumbed but unused; a future PR can use it to stage only a subset (e.g. weights but not KV).

Test plan
- Release with `stage_to_cpu=true`, resume, run inference; outputs match the pre-release run.
- `stage_to_cpu=false` continues to behave as #272 (feat: expose POST /release_memory_occupation and /resume_memory_occupation): resume returns zeroed pages, and the caller must reload weights.

Env caveats during testing (unrelated to this PR)
- A `scheduler_metadata` kwarg leak in `MHAAttnBackend` that breaks `--attention-backend=triton/fa4/flashinfer`.
- An `import triton_kernels.matmul` failure caused by upstream module renames.

Open questions
- `pin_memory` (currently default-on)? On hosts without enough pinned-RAM budget, unpinned staging works at ~6-8 GiB/s vs. ~25 GiB/s pinned.
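For readers weighing the `pin_memory` question, the staging step described in the Summary can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the code in this PR: `HostStager` and its methods are hypothetical names, the `saver.pause()`/`resume()` calls are replaced by zeroing the parameters, and the sketch falls back to pageable host memory when no CUDA runtime is present so it also runs on a CPU-only machine.

```python
import torch
from torch import nn


class HostStager:
    """Hypothetical helper: keeps a host-side copy of a module's params and buffers."""

    def __init__(self, module: nn.Module, pin_memory: bool = True):
        self.module = module
        # Pinned allocations need a CUDA runtime; otherwise use pageable memory.
        self.pin = pin_memory and torch.cuda.is_available()
        self.host = {}  # name -> host tensor, reused across release/resume cycles

    def _tensors(self):
        yield from self.module.named_parameters()
        yield from self.module.named_buffers()

    @torch.no_grad()
    def stage(self) -> None:
        """Copy every param and buffer into a (reused) host buffer."""
        for name, t in self._tensors():
            buf = self.host.get(name)
            if buf is None or buf.shape != t.shape or buf.dtype != t.dtype:
                buf = torch.empty(t.shape, dtype=t.dtype, device="cpu",
                                  pin_memory=self.pin)
                self.host[name] = buf
            buf.copy_(t, non_blocking=self.pin)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # make sure async copies have landed

    @torch.no_grad()
    def restore(self) -> None:
        """Copy staged data back in place, so tensor addresses do not change."""
        for name, t in self._tensors():
            t.copy_(self.host[name], non_blocking=self.pin)
        if torch.cuda.is_available():
            torch.cuda.synchronize()


# Usage: zeroing the weights stands in for the zeroed pages seen after resume.
model = nn.Linear(4, 2)
stager = HostStager(model, pin_memory=True)
before = {k: v.clone() for k, v in model.state_dict().items()}
ptrs = {k: v.data_ptr() for k, v in model.state_dict().items()}

stager.stage()
with torch.no_grad():
    for p in model.parameters():
        p.zero_()
stager.restore()

restored = all(torch.equal(model.state_dict()[k], v) for k, v in before.items())
same_ptrs = all(model.state_dict()[k].data_ptr() == p for k, p in ptrs.items())
print("restored:", restored, "same addresses:", same_ptrs)
```

Copying back with an in-place `copy_` is what keeps the destination addresses stable, which is the property the Summary relies on for CUDAGraph captures. Passing `pin_memory=False` only changes how the host buffers are allocated, so the trade-off is bandwidth against pinned-RAM budget.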