From 359978ad579c0fb704bc20d312b5e782db06df02 Mon Sep 17 00:00:00 2001 From: Eric Lake Date: Sun, 26 Apr 2026 15:20:44 -0700 Subject: [PATCH 1/2] feat(fast): preadIntoOffset for stacked-buffer MoE consumers + PAPPS try_take MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds `mlx_fast_pread_into_offset`, a variant of `mlx_fast_pread_into` that writes ONE expert's bytes into a destination MLXArray at a given byte offset rather than overwriting the whole array. This unlocks stacked-buffer fast paths in MoE consumers (e.g. mlx-swift-lm's SwitchGLU): allocate one `[CACHE_SLOTS, intermediate, hidden]` weight buffer per layer and populate slots in place via `slot * bytes_per_expert` byte offsets, instead of allocating one MLXArray per cached expert and concatenating at compute time. Bounds-checked: `dst_offset + bytes_per_expert <= dst.nbytes()`. Reads exactly `bytes_per_expert` (not the whole dst). PAPPS fast path: when the global PAPPS prefetcher is enabled, attempts `try_take()` from the PAPPS cache before falling through to synchronous `pread`. cache_id is keyed on (path, tname, file_offset) — independent of dst_offset — so a single `mlx_fast_submit_prefetch()` call serves both this offset variant and the existing full-buffer variant (`mlx_fast_pread_into`). Public API: - `mlx_fast_pread_into_offset(dst, path, tname, expert_index, dst_offset)` in `mlx-c/mlx/c/fast.h` and mirrored in `include/mlx/c/fast.h`. - Swift wrapper `MLXFast.preadIntoOffset(...)` in `Source/MLX/MLXFast.swift`. No behavior change for existing consumers; new function is additive. Co-Authored-By: Claude Opus 4.7 (1M context) --- Source/Cmlx/include/mlx/c/fast.h | 12 ++++++ Source/Cmlx/mlx-c/mlx/c/fast.cpp | 68 ++++++++++++++++++++++++++++++++ Source/Cmlx/mlx-c/mlx/c/fast.h | 12 ++++++ Source/MLX/MLXFast.swift | 24 +++++++++++ 4 files changed, 116 insertions(+) diff --git a/Source/Cmlx/include/mlx/c/fast.h b/Source/Cmlx/include/mlx/c/fast.h index dfdfad29..1b2f2f67 100644 --- a/Source/Cmlx/include/mlx/c/fast.h +++ b/Source/Cmlx/include/mlx/c/fast.h @@ -234,6 +234,18 @@ int mlx_fast_pread_into( const char* tensor_name, uint32_t expert_index); +// Like mlx_fast_pread_into, but writes the expert's bytes into the dst buffer +// starting at byte offset `dst_offset`. Reads exactly `bytes_per_expert` bytes +// (NOT the whole dst array). Use this to populate one slot of a stacked +// `[N_slots, ..., ...]` buffer, where `dst_offset = slot * bytes_per_expert`. +// Bounds check: dst_offset + bytes_per_expert <= dst.nbytes. +int mlx_fast_pread_into_offset( + mlx_array dst, + const char* safetensors_path, + const char* tensor_name, + uint32_t expert_index, + size_t dst_offset); + // mlx_fast_submit_prefetch (PAPPS Background Worker) int mlx_fast_submit_prefetch( const char* safetensors_path, diff --git a/Source/Cmlx/mlx-c/mlx/c/fast.cpp b/Source/Cmlx/mlx-c/mlx/c/fast.cpp index c4af7717..e3cdf027 100644 --- a/Source/Cmlx/mlx-c/mlx/c/fast.cpp +++ b/Source/Cmlx/mlx-c/mlx/c/fast.cpp @@ -1120,3 +1120,71 @@ extern "C" int mlx_fast_pread_into( } return 0; } + +// mlx_fast_pread_into_offset — variant that writes ONE expert into a slot of +// a stacked destination buffer. Used by SwitchGLU's stacked-buffer fast path +// (TEND_MOE_STACKED=1) to avoid `MLX.concatenated` cost when fusing per-expert +// matmuls into a single gatherQuantizedMM dispatch. +// +// dst_offset is bytes (not elements). Reads exactly `bytes_per_expert` bytes +// from the safetensors file at the requested expert index, into +// `dst.data() + dst_offset`. Bounds check: dst_offset + bytes_per_expert <= +// dst.nbytes(). +// +// PAPPS fast path: if a background worker already preloaded this expert +// (cache_id = path|tname_, which is independent of dst_offset), +// take it via try_take() and memcpy into the slot, skipping the synchronous +// pread. Caller is expected to issue mlx_fast_submit_prefetch ahead of time +// (e.g. at last-token routing) to populate the PAPPS cache. +extern "C" int mlx_fast_pread_into_offset( + mlx_array dst, + const char* safetensors_path, + const char* tensor_name, + uint32_t expert_index, + size_t dst_offset) { + try { + std::string path(safetensors_path); + std::string tname(tensor_name); + std::string key = path + "|" + tname; + + STPReadEntry entry = get_safetensors_entry(path, tname, key); + + auto& arr = mlx_array_get_(dst); + void* base = const_cast(static_cast(arr.data())); + if (!base) throw std::runtime_error("[pread_into_offset] dst has no data pointer — call eval() first"); + size_t total_nbytes = arr.nbytes(); + size_t bpe = entry.bytes_per_expert; + if (dst_offset + bpe > total_nbytes) { + throw std::runtime_error( + "[pread_into_offset] dst_offset (" + std::to_string(dst_offset) + + ") + bytes_per_expert (" + std::to_string(bpe) + + ") > dst.nbytes (" + std::to_string(total_nbytes) + ")"); + } + void* slot_buf = static_cast(base) + dst_offset; + off_t file_offset = static_cast(entry.data_start + (size_t)expert_index * bpe); + + // PAPPS fast path: try to absorb a previously-submitted prefetch. + // cache_id is keyed on (path,tname,file_offset) — same as full-buffer + // variant — so a single submit_prefetch call serves both consumers. + std::string cache_id = key + "_" + std::to_string(file_offset); + bool hit = false; + { + std::lock_guard lock(global_papps_mutex); + if (global_papps_queue) { + hit = global_papps_queue->try_take(cache_id, slot_buf, bpe); + } + } + if (hit) { + return 0; // memcpy from PAPPS cache complete; no syscall + } + + // Cache miss — synchronous pread into the slot. + ssize_t result = pread(entry.fd, slot_buf, bpe, file_offset); + if (result < 0 || (size_t)result != bpe) + throw std::runtime_error("[pread_into_offset] pread failed: got " + std::to_string(result) + " of " + std::to_string(bpe)); + } catch (std::exception& e) { + mlx_error(e.what()); + return 1; + } + return 0; +} diff --git a/Source/Cmlx/mlx-c/mlx/c/fast.h b/Source/Cmlx/mlx-c/mlx/c/fast.h index 4b1a4025..a25d14d8 100644 --- a/Source/Cmlx/mlx-c/mlx/c/fast.h +++ b/Source/Cmlx/mlx-c/mlx/c/fast.h @@ -249,6 +249,18 @@ int mlx_fast_pread_into( const char* tensor_name, uint32_t expert_index); +// Like mlx_fast_pread_into, but writes the expert's bytes into the dst buffer +// starting at byte offset `dst_offset`. Reads exactly `bytes_per_expert` bytes +// (NOT the whole dst array). Use this to populate one slot of a stacked +// `[N_slots, ..., ...]` buffer, where `dst_offset = slot * bytes_per_expert`. +// Bounds check: dst_offset + bytes_per_expert <= dst.nbytes. +int mlx_fast_pread_into_offset( + mlx_array dst, + const char* safetensors_path, + const char* tensor_name, + uint32_t expert_index, + size_t dst_offset); + /**@}*/ // ── SSD Flash-Stream metrics snapshot ──────────────────────────────────────── diff --git a/Source/MLX/MLXFast.swift b/Source/MLX/MLXFast.swift index eb6bdbbe..8ed478be 100644 --- a/Source/MLX/MLXFast.swift +++ b/Source/MLX/MLXFast.swift @@ -376,6 +376,30 @@ public enum MLXFast { } } + /// Like `preadInto`, but writes the expert's bytes into the destination at + /// byte-offset `dstOffset`. Reads exactly `bytes_per_expert` bytes (the + /// safetensors entry's per-expert slab size), NOT the whole dst. + /// + /// Use when you have a stacked `[N_slots, ..., ...]` MLXArray and want to + /// populate slot `k` via `dstOffset = k * bytesPerExpert`. Lets a single + /// `gatherQuantizedMM` call replace a per-expert loop, eliminating both + /// the per-expert kernel-launch overhead and the `MLX.concatenated` Metal + /// copy that would otherwise be needed to fuse N independent buffers. + @discardableResult + public static func preadIntoOffset( + _ dst: MLXArray, + safetensorsPath: String, + tensorName: String, + expertIndex: UInt32, + dstOffset: Int + ) -> Int32 { + safetensorsPath.withCString { pathPtr in + tensorName.withCString { namePtr in + mlx_fast_pread_into_offset(dst.ctx, pathPtr, namePtr, expertIndex, dstOffset) + } + } + } + /// Submits an asynchronous background prefetch for a specific expert's weights. /// The fetch is handled by a persistent C++ background thread and placed in a unified memory arena. public static func pappsPrefetch( From 15da2b517b8ac0e0429be9a9d47f65ed6ee9af02 Mon Sep 17 00:00:00 2001 From: Aegis AI Assistant Date: Sun, 26 Apr 2026 18:54:05 -0700 Subject: [PATCH 2/2] fix(preadIntoOffset): overflow-safe bounds check + Swift precondition guard MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses Copilot review on #10: 1. (🔴 must-fix) Rewrite size_t bounds check in overflow-safe form: `dst_offset > total_nbytes || bpe > total_nbytes - dst_offset` instead of `dst_offset + bpe > total_nbytes` which wraps on near-SIZE_MAX dst_offset values. 2. (🟡 should-fix) Add `precondition(dstOffset >= 0)` in the Swift wrapper to catch negative Int values before they are silently reinterpreted as huge unsigned size_t on the C side. No behavior change for valid inputs. --- Source/Cmlx/mlx-c/mlx/c/fast.cpp | 6 +++--- Source/MLX/MLXFast.swift | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/Source/Cmlx/mlx-c/mlx/c/fast.cpp b/Source/Cmlx/mlx-c/mlx/c/fast.cpp index e3cdf027..1e07f26f 100644 --- a/Source/Cmlx/mlx-c/mlx/c/fast.cpp +++ b/Source/Cmlx/mlx-c/mlx/c/fast.cpp @@ -1128,8 +1128,8 @@ extern "C" int mlx_fast_pread_into( // // dst_offset is bytes (not elements). Reads exactly `bytes_per_expert` bytes // from the safetensors file at the requested expert index, into -// `dst.data() + dst_offset`. Bounds check: dst_offset + bytes_per_expert <= -// dst.nbytes(). +// `dst.data() + dst_offset`. Bounds check (overflow-safe): +// dst_offset <= dst.nbytes() && bytes_per_expert <= dst.nbytes() - dst_offset. // // PAPPS fast path: if a background worker already preloaded this expert // (cache_id = path|tname_, which is independent of dst_offset), @@ -1154,7 +1154,7 @@ extern "C" int mlx_fast_pread_into_offset( if (!base) throw std::runtime_error("[pread_into_offset] dst has no data pointer — call eval() first"); size_t total_nbytes = arr.nbytes(); size_t bpe = entry.bytes_per_expert; - if (dst_offset + bpe > total_nbytes) { + if (dst_offset > total_nbytes || bpe > total_nbytes - dst_offset) { throw std::runtime_error( "[pread_into_offset] dst_offset (" + std::to_string(dst_offset) + ") + bytes_per_expert (" + std::to_string(bpe) + diff --git a/Source/MLX/MLXFast.swift b/Source/MLX/MLXFast.swift index 8ed478be..41dccd23 100644 --- a/Source/MLX/MLXFast.swift +++ b/Source/MLX/MLXFast.swift @@ -393,7 +393,8 @@ public enum MLXFast { expertIndex: UInt32, dstOffset: Int ) -> Int32 { - safetensorsPath.withCString { pathPtr in + precondition(dstOffset >= 0, "dstOffset must be non-negative") + return safetensorsPath.withCString { pathPtr in tensorName.withCString { namePtr in mlx_fast_pread_into_offset(dst.ctx, pathPtr, namePtr, expertIndex, dstOffset) }