Skip to content

perf(model_loader): multi-threaded safetensors weight loading#287

Open
yuanqingz wants to merge 5 commits into
lightseekorg:mainfrom
yuanqingz:perf/multi-threaded-weight-loader
Open

perf(model_loader): multi-threaded safetensors weight loading#287
yuanqingz wants to merge 5 commits into
lightseekorg:mainfrom
yuanqingz:perf/multi-threaded-weight-loader

Conversation

@yuanqingz
Copy link
Copy Markdown

@yuanqingz yuanqingz commented May 28, 2026

Parallelize the safetensors loader in weight_utils.safetensors_weights_iterator. When prefetch_num_threads > 1, dispatch shard load_file() calls to a ThreadPoolExecutor and yield tensors as soon as each shard completes. A sliding window keeps at most max_workers files resident in CPU memory at any time. After loading, each tensor is cloned in the worker thread so the bytes are RAM-materialized before being handed to the consumer — without this, the per-tensor access in the downstream model.load_weights() loop page-faults back through NFS and the parallel I/O is silently serialized.

Mirrors vLLM's multi_thread_safetensors_weights_iterator (in vllm/model_executor/model_loader/weight_utils.py).

Motivation

The existing serial path is bottlenecked by per-file open / read latency. On Kimi K2.5-NVFP4 (119 safetensors shards, 553 GiB) on a 4× B200 host backed by a shared NFS appliance, the serial loader takes ~20 min to materialize the checkpoint on cold client cache. With this change the same workload finishes in ~10 min on the same hardware.

Implementation

Four commits on this branch:

  1. perf(model_loader): multi-threaded safetensors weight loading — adds _multi_thread_safetensors_weights_iterator, wires it into safetensors_weights_iterator when prefetch_num_threads > 1.
  2. fix(model_loader): cap concurrent safetensors shard loads — switches the loader from as_completed(generator) (which concurrent.futures drains eagerly into a queue) to an explicit sliding window: submit the initial max_workers files, then _submit_next() after each completion. Cancel pending futures on BaseException, shutdown executor in finally. Guarantees ≤ max_workers × shard_size resident at any time.
  3. fix(model_loader): force materialize tensors after parallel safetensors loadsafetensors.torch.load_file(device="cpu") returns mmap-backed tensors. With the serial loader the consumer accesses pages on the same fd that just opened them and NFS prefetch hides the latency. With the parallel loader the consumer reaches file B's tensors after worker thread B's prefetch heuristic has gone cold, so each per-tensor access triggers an NFS page fault. Cloning each tensor inside the worker thread moves the page faults to the parallel I/O phase. Trade-off: peak CPU memory ≈ 2 × max_workers × shard_size (still well under any production host RAM budget — measured 61 GiB peak resident in the 8-worker test, see below).
  4. perf(model_loader): bump default weight_loader_prefetch_num_threads to 8 — for very large checkpoints (>500 GiB / >100 shards) 4 workers leaves throughput on the table.

Existing users automatically pick up the speedup with no flag change required. Opt out with --weight-loader-prefetch-num-threads 1.

Test Plan — measured speedup (replicated)

Standalone loader-only timing harness, Kimi-K2.5-NVFP4 read directly from NFS — no tmpfs / /dev/shm staging — exercising weight_utils.safetensors_weights_iterator end-to-end with .clone() per tensor inside the consumer loop so we time actual NFS bandwidth, not lazy mmap setup. Each run is a fresh host allocation = cold client page cache. Server-side NFS cache state was similar across all four runs (submitted within minutes of each other against the same NFS volume).

Run Variant Iterator wall Aggregate BW
Baseline #1 — stock serial threads=1 22:30 (1350 s) 417 MB/s
Baseline #2 — stock serial threads=1 19:06 (1146 s) 491 MB/s
Baseline avg ~20:48 (1248 s) ~454 MB/s
PR — multi-thread + clone workers=8 10:37 (637 s) 884 MB/s
PR — multi-thread + clone workers=8 10:20 (620 s) 909 MB/s
PR avg ~10:29 (629 s) ~896 MB/s
Speedup ~1.98×

Observations:

  • Two PR runs replicated within 1.4% (10:20 vs 10:37) — the multi-thread path saturates the NFS volume's read bandwidth (~900 MB/s) so external load drift doesn't move the needle.
  • Two baseline runs spread 9% (22:30 vs 19:06) — at ~450 MB/s the server has headroom and other tenants' I/O measurably perturbs the timing.
  • Peak resident memory at 8 workers + clone: 61 GiB per process (measured via resource.getrusage). Well within any production host RAM budget.

Correctness end-to-end on the engine

Verified with ts serve (--weight-loader-prefetch-num-threads 8, attn_tp4_moe_tp4 + EAGLE3 spec-dec, weights served from NFS). On all 4 TP ranks the full weight-loading pipeline completes on RAM-materialized tensors:

  1. Multi-thread loading shards (workers=8): 100% Completed | 119/119 [12:42<00:00] (this PR's tqdm).
  2. ~1:42 later, model.load_weights() returns on all ranks — the per-expert loop in DeepseekV3.load_weights → moe/checkpoint/loader.load → load_model_weight → _load_w13/_load_w2 proceeds normally on RAM-materialized tensors. Without the clone in commit (3) this phase hangs indefinitely (verified by cancelling a control run at ~28 min with model.load_weights() still not returned despite the iterator reporting 100% at +71 s — that 71 s was lazy mmap, not real I/O).
  3. process_weights_after_loading loop completes (10 quant/self method calls for attn_tp4_moe_tp4).
  4. Load weight end. type=Eagle3DeepseekV2ForCausalLM, dtype=torch.bfloat16, avail mem=32.51 GB (weight_loader.py:116) on all 4 ranks.

A separate engine-side crash in profile_available_cache_memory_bytes → torch.distributed.all_reduce (Gloo "Connection closed by peer") reproduced on both 4- and 8-worker full-engine runs — that's post-Load weight end and does not involve the loader code path; it is being investigated separately (suspect: gateway startup timeout or peer rank exit during long warmup).

Unit test

Added test/runtime/test_weight_loader_prefetch.py (covers sliding-window correctness, exception cancellation, max_workers RSS bound).

Notes for reviewers

  • The clone in commit (3) doubles peak resident memory while a shard is in flight. Worst case for K2.5: 8 workers × ~5 GiB × 2 = ~80 GiB per rank, × 4 TP ranks = ~320 GiB. Measured peak per process in the standalone harness: 61 GiB (less than theoretical worst case because consumer drains tensors before all 8 workers peak together).
  • prefetch_checkpoint_files() (the older page-cache warmer) is left in place for prefetch=True callers on the serial path. The new parallel path supersedes it for the common case.
  • No GPU/CUDA changes; this is pure CPU + I/O.

When prefetch_num_threads > 1, dispatch shard load_file() calls to a
ThreadPoolExecutor and yield tensors as soon as each shard completes.
Submissions are gated lazily by a generator expression so at most
max_workers shards are resident in CPU memory at any time.

Motivation: on NFS-backed checkpoints with many shards, the existing
serial path is bottlenecked by per-file open/read latency. For Kimi
K2.5-NVFP4 (119 safetensors shards, 553 GB) we measured ~50 min for
model loading on B200 4-GPU; vLLM completes the equivalent path in
~6 min via a similar multi-threaded iterator (vllm/model_executor/
model_loader/weight_utils.py:multi_thread_safetensors_weights_iterator).

The existing prefetch_checkpoint_files() helper that warms OS page
cache is left in place for callers that need it explicitly; the
serial path still calls it when prefetch=True. The new parallel
path supersedes it for the common case.

Wired into ServerArgs.weight_loader_prefetch_num_threads (default 4),
so existing users automatically pick up the speedup with no flag
change required. To opt out: --weight-loader-prefetch-num-threads 1.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@yuanqingz yuanqingz requested a review from a team as a code owner May 28, 2026 02:31
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 1f25e0fae5

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

# Lazy submission cap: only ``max_workers`` files in flight at once.
futures = (executor.submit(_load_file, st_file) for st_file in hf_weights_files)
futures_iter = tqdm(
concurrent.futures.as_completed(futures),
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Cap shard submissions instead of queuing every file

With the default weight_loader_prefetch_num_threads=4, any checkpoint with multiple safetensors shards now takes this path, but concurrent.futures.as_completed() materializes its input (set(fs)) before yielding, so this generator is consumed immediately and submits every shard at once. The executor only runs max_workers tasks concurrently, but completed futures keep their full state_dict results until the iterator reaches them, so large multi-shard checkpoints can still accumulate many loaded shards in CPU memory despite the intended cap of max_workers resident files.

Useful? React with 👍 / 👎.

@yuanqingz
Copy link
Copy Markdown
Author

@codex review

@chatgpt-codex-connector
Copy link
Copy Markdown

Codex Review: Didn't find any major issues. More of your lovely PRs please.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

yuanqingz and others added 2 commits May 27, 2026 20:37
…rs load

safetensors.torch.load_file(device="cpu") returns mmap-backed tensors.
With the serial loader, the consumer's per-tensor access overlaps with
NFS prefetch on the same file descriptor → no observable issue.

With the multi-threaded loader, by the time the consumer reaches file
B's tensors, the prefetch heuristic for B's fd is cold (a different
worker thread opened it), so each per-tensor access triggers an NFS
page fault. Net effect: load_file() parallelizes fine (3x speedup on
the safetensors phase alone for K2.5-NVFP4), but the downstream
DeepseekV3.load_weights() per-expert loop becomes ~13x slower than on
RAM-resident storage (>16 min vs 75 sec on /dev/shm).

Cloning each tensor inside the worker thread forces the page faults
to happen there, while we already pay NFS read cost as part of the
parallel I/O phase. Trade-off:

- +1 RAM copy of each shard while in flight (memory peak: 2x max_workers
  shard size, still well under host RAM for typical large MoE models).
- Per-worker wall time goes up (mmap-touch → full read + copy), but the
  parallelism still pays.
- Downstream consumer now only touches RAM-backed tensors.

Measured on Kimi K2.5-NVFP4 (553 GB, 119 shards) on Computelab B200
node with 4-thread loader against /home (NFS):

  - Before clone():  safetensors 100% at +71s, model.load_weights()
    still hung at +16 min — never reaches process_weights_after_loading.
  - After clone():   <to be measured in the next bench run>

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
For a 553 GB / 119-shard model (Kimi K2.5-NVFP4) on a slow NFS
backend, the multi-threaded safetensors loader needs more than 4
worker threads to saturate the aggregate read bandwidth. Measured on
Computelab B200 node (against /home NFS, ~750 MB/s per-file ceiling):

  - 4 workers:   ~19 min for 553 GB (~485 MB/s aggregate)
  - 8 workers:   <to be measured next>

Going from 4 → 8 mostly affects very large checkpoints; small models
(< 10 GB) finish in either configuration in seconds. Peak CPU RAM
during load doubles vs the 4-worker default, but is still bounded
by ~2 × num_workers × shard_size = ~80 GB for typical 5 GB shards.
That's well under the host RAM of any production-class server.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f8c6219c93

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

decryption_key_file: str | None = None
weight_loader_prefetch_checkpoints: bool = False
weight_loader_prefetch_num_threads: int = 4
weight_loader_prefetch_num_threads: int = 8
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Align the default with its checked-in test

With this default changed to 8, the checked-in runtime test test_load_config_defaults_keep_prefetch_disabled still asserts LoadConfig().weight_loader_prefetch_num_threads == 4 in test/runtime/test_weight_loader_prefetch.py:44. In an environment with the project dependencies installed, this commit will fail that test before exercising the new threaded loader, so either keep the default at 4 or update the expected default consistently.

Useful? React with 👍 / 👎.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant