diff --git a/Source/Cmlx/include/mlx/c/fast.h b/Source/Cmlx/include/mlx/c/fast.h index dfdfad29..1b2f2f67 100644 --- a/Source/Cmlx/include/mlx/c/fast.h +++ b/Source/Cmlx/include/mlx/c/fast.h @@ -234,6 +234,18 @@ int mlx_fast_pread_into( const char* tensor_name, uint32_t expert_index); +// Like mlx_fast_pread_into, but writes the expert's bytes into the dst buffer +// starting at byte offset `dst_offset`. Reads exactly `bytes_per_expert` bytes +// (NOT the whole dst array). Use this to populate one slot of a stacked +// `[N_slots, ..., ...]` buffer, where `dst_offset = slot * bytes_per_expert`. +// Bounds check: dst_offset + bytes_per_expert <= dst.nbytes. +int mlx_fast_pread_into_offset( + mlx_array dst, + const char* safetensors_path, + const char* tensor_name, + uint32_t expert_index, + size_t dst_offset); + // mlx_fast_submit_prefetch (PAPPS Background Worker) int mlx_fast_submit_prefetch( const char* safetensors_path, diff --git a/Source/Cmlx/mlx-c/mlx/c/fast.cpp b/Source/Cmlx/mlx-c/mlx/c/fast.cpp index c4af7717..1e07f26f 100644 --- a/Source/Cmlx/mlx-c/mlx/c/fast.cpp +++ b/Source/Cmlx/mlx-c/mlx/c/fast.cpp @@ -1120,3 +1120,71 @@ extern "C" int mlx_fast_pread_into( } return 0; } + +// mlx_fast_pread_into_offset — variant that writes ONE expert into a slot of +// a stacked destination buffer. Used by SwitchGLU's stacked-buffer fast path +// (TEND_MOE_STACKED=1) to avoid `MLX.concatenated` cost when fusing per-expert +// matmuls into a single gatherQuantizedMM dispatch. +// +// dst_offset is bytes (not elements). Reads exactly `bytes_per_expert` bytes +// from the safetensors file at the requested expert index, into +// `dst.data() + dst_offset`. Bounds check (overflow-safe): +// dst_offset <= dst.nbytes() && bytes_per_expert <= dst.nbytes() - dst_offset. +// +// PAPPS fast path: if a background worker already preloaded this expert +// (cache_id = path|tname_, which is independent of dst_offset), +// take it via try_take() and memcpy into the slot, skipping the synchronous +// pread. Caller is expected to issue mlx_fast_submit_prefetch ahead of time +// (e.g. at last-token routing) to populate the PAPPS cache. +extern "C" int mlx_fast_pread_into_offset( + mlx_array dst, + const char* safetensors_path, + const char* tensor_name, + uint32_t expert_index, + size_t dst_offset) { + try { + std::string path(safetensors_path); + std::string tname(tensor_name); + std::string key = path + "|" + tname; + + STPReadEntry entry = get_safetensors_entry(path, tname, key); + + auto& arr = mlx_array_get_(dst); + void* base = const_cast(static_cast(arr.data())); + if (!base) throw std::runtime_error("[pread_into_offset] dst has no data pointer — call eval() first"); + size_t total_nbytes = arr.nbytes(); + size_t bpe = entry.bytes_per_expert; + if (dst_offset > total_nbytes || bpe > total_nbytes - dst_offset) { + throw std::runtime_error( + "[pread_into_offset] dst_offset (" + std::to_string(dst_offset) + + ") + bytes_per_expert (" + std::to_string(bpe) + + ") > dst.nbytes (" + std::to_string(total_nbytes) + ")"); + } + void* slot_buf = static_cast(base) + dst_offset; + off_t file_offset = static_cast(entry.data_start + (size_t)expert_index * bpe); + + // PAPPS fast path: try to absorb a previously-submitted prefetch. + // cache_id is keyed on (path,tname,file_offset) — same as full-buffer + // variant — so a single submit_prefetch call serves both consumers. + std::string cache_id = key + "_" + std::to_string(file_offset); + bool hit = false; + { + std::lock_guard lock(global_papps_mutex); + if (global_papps_queue) { + hit = global_papps_queue->try_take(cache_id, slot_buf, bpe); + } + } + if (hit) { + return 0; // memcpy from PAPPS cache complete; no syscall + } + + // Cache miss — synchronous pread into the slot. + ssize_t result = pread(entry.fd, slot_buf, bpe, file_offset); + if (result < 0 || (size_t)result != bpe) + throw std::runtime_error("[pread_into_offset] pread failed: got " + std::to_string(result) + " of " + std::to_string(bpe)); + } catch (std::exception& e) { + mlx_error(e.what()); + return 1; + } + return 0; +} diff --git a/Source/Cmlx/mlx-c/mlx/c/fast.h b/Source/Cmlx/mlx-c/mlx/c/fast.h index 4b1a4025..a25d14d8 100644 --- a/Source/Cmlx/mlx-c/mlx/c/fast.h +++ b/Source/Cmlx/mlx-c/mlx/c/fast.h @@ -249,6 +249,18 @@ int mlx_fast_pread_into( const char* tensor_name, uint32_t expert_index); +// Like mlx_fast_pread_into, but writes the expert's bytes into the dst buffer +// starting at byte offset `dst_offset`. Reads exactly `bytes_per_expert` bytes +// (NOT the whole dst array). Use this to populate one slot of a stacked +// `[N_slots, ..., ...]` buffer, where `dst_offset = slot * bytes_per_expert`. +// Bounds check: dst_offset + bytes_per_expert <= dst.nbytes. +int mlx_fast_pread_into_offset( + mlx_array dst, + const char* safetensors_path, + const char* tensor_name, + uint32_t expert_index, + size_t dst_offset); + /**@}*/ // ── SSD Flash-Stream metrics snapshot ──────────────────────────────────────── diff --git a/Source/MLX/MLXFast.swift b/Source/MLX/MLXFast.swift index eb6bdbbe..41dccd23 100644 --- a/Source/MLX/MLXFast.swift +++ b/Source/MLX/MLXFast.swift @@ -376,6 +376,31 @@ public enum MLXFast { } } + /// Like `preadInto`, but writes the expert's bytes into the destination at + /// byte-offset `dstOffset`. Reads exactly `bytes_per_expert` bytes (the + /// safetensors entry's per-expert slab size), NOT the whole dst. + /// + /// Use when you have a stacked `[N_slots, ..., ...]` MLXArray and want to + /// populate slot `k` via `dstOffset = k * bytesPerExpert`. Lets a single + /// `gatherQuantizedMM` call replace a per-expert loop, eliminating both + /// the per-expert kernel-launch overhead and the `MLX.concatenated` Metal + /// copy that would otherwise be needed to fuse N independent buffers. + @discardableResult + public static func preadIntoOffset( + _ dst: MLXArray, + safetensorsPath: String, + tensorName: String, + expertIndex: UInt32, + dstOffset: Int + ) -> Int32 { + precondition(dstOffset >= 0, "dstOffset must be non-negative") + return safetensorsPath.withCString { pathPtr in + tensorName.withCString { namePtr in + mlx_fast_pread_into_offset(dst.ctx, pathPtr, namePtr, expertIndex, dstOffset) + } + } + } + /// Submits an asynchronous background prefetch for a specific expert's weights. /// The fetch is handled by a persistent C++ background thread and placed in a unified memory arena. public static func pappsPrefetch(