From 7501a0b0fc8d0a917049f6739a38cda91abcf993 Mon Sep 17 00:00:00 2001 From: Shiv Tyagi Date: Thu, 11 Jun 2026 17:07:41 +0000 Subject: [PATCH 1/2] fix(spur-cloud-api): adapt GPU allocation to CDI-based device API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit spur PR #262 replaced the flat `alloc.gpus` field with a generic `alloc.devices` map keyed by device class. Update `get_gpu_capacity` to read allocated GPUs from `alloc.devices["gpu"]` and cross-reference device IDs against total resources for type resolution. Bump tonic 0.12→0.14 and prost 0.13→0.14 to match spur-proto. --- Cargo.lock | 305 +++++++++++------------ Cargo.toml | 12 +- crates/spur-cloud-api/src/spur_client.rs | 249 ++++++++++++++++-- 3 files changed, 390 insertions(+), 176 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4b69ed9..c49286b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -180,7 +180,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", - "axum-core", + "axum-core 0.4.5", "base64", "bytes", "futures-util", @@ -190,7 +190,7 @@ dependencies = [ "hyper", "hyper-util", "itoa", - "matchit", + "matchit 0.7.3", "memchr", "mime", "percent-encoding", @@ -210,6 +210,31 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31b698c5f9a010f6573133b09e0de5408834d0c82f8d7475a89fc1867a71cd90" +dependencies = [ + "axum-core 0.5.6", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "itoa", + "matchit 0.8.4", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "sync_wrapper", + "tower 0.5.3", + "tower-layer", + "tower-service", +] + [[package]] name = "axum-core" version = "0.4.5" @@ -231,14 +256,32 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", +] + [[package]] name = "axum-extra" version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c794b30c904f0a1c2fb7740f7df7f7972dfaa14ef6f57cb6178dc63e5dca2f04" dependencies = [ - "axum", - "axum-core", + "axum 0.7.9", + "axum-core 0.4.5", "bytes", "cookie", "fastrand", @@ -911,19 +954,13 @@ dependencies = [ "futures-core", "futures-sink", "http", - "indexmap 2.14.0", + "indexmap", "slab", "tokio", "tokio-util", "tracing", ] -[[package]] -name = "hashbrown" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" - [[package]] name = "hashbrown" version = "0.15.5" @@ -1149,7 +1186,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.6.3", + "socket2", "tokio", "tower-service", "tracing", @@ -1294,16 +1331,6 @@ dependencies = [ "icu_properties", ] -[[package]] -name = "indexmap" -version = "1.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" -dependencies = [ - "autocfg", - "hashbrown 0.12.3", -] - [[package]] name = "indexmap" version = "2.14.0" @@ -1646,6 +1673,12 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "md-5" version = "0.10.6" @@ -1916,12 +1949,13 @@ dependencies = [ [[package]] name = "petgraph" -version = "0.7.1" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" +checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" dependencies = [ "fixedbitset", - "indexmap 2.14.0", + "hashbrown 0.15.5", + "indexmap", ] [[package]] @@ -2028,9 +2062,9 @@ dependencies = [ [[package]] name = "prost" -version = "0.13.5" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" +checksum = "528ac67416ff8646872a3c02cad9cc4ee5dc9f9540c9b10771855c95cb2e5ae1" dependencies = [ "bytes", "prost-derive", @@ -2038,19 +2072,20 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.13.5" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" +checksum = "03da047801ff44bb6a4d407d4860c05fd70bb81714e6b2f3812603d5b145b042" dependencies = [ "heck", "itertools", "log", "multimap", - "once_cell", "petgraph", "prettyplease", "prost", "prost-types", + "pulldown-cmark", + "pulldown-cmark-to-cmark", "regex", "syn", "tempfile", @@ -2058,9 +2093,9 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.13.5" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" +checksum = "b570b25f7617e43d59005d0990ccb79e950a423952cea19671b7a876da390adf" dependencies = [ "anyhow", "itertools", @@ -2071,13 +2106,33 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.13.5" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" +checksum = "f94967dc7688f3054c7fac87473ffae4cc4c3904800e2d9f5b857246d8963b0a" dependencies = [ "prost", ] +[[package]] +name = "pulldown-cmark" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9f068eba8e7071c5f9511831b44f32c740d5adf574e990f946ddb53db2f314e" +dependencies = [ + "bitflags", + "memchr", + "unicase", +] + +[[package]] +name = "pulldown-cmark-to-cmark" +version = "22.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50793def1b900256624a709439404384204a5dc3a6ec580281bfaac35e882e90" +dependencies = [ + "pulldown-cmark", +] + [[package]] name = "quinn" version = "0.11.9" @@ -2091,7 +2146,7 @@ dependencies = [ "quinn-udp", "rustc-hash", "rustls", - "socket2 0.6.3", + "socket2", "thiserror 2.0.18", "tokio", "tracing", @@ -2107,7 +2162,7 @@ dependencies = [ "bytes", "getrandom 0.3.4", "lru-slab", - "rand 0.9.4", + "rand 0.9.3", "ring", "rustc-hash", "rustls", @@ -2128,9 +2183,9 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2 0.6.3", + "socket2", "tracing", - "windows-sys 0.60.2", + "windows-sys 0.52.0", ] [[package]] @@ -2167,9 +2222,9 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.4" +version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44c5af06bb1b7d3216d91932aed5265164bf384dc89cd6ba05cf59a35f5f76ea" +checksum = "7ec095654a25171c2124e9e3393a930bddbffdc939556c914957a4c3e0a87166" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.5", @@ -2353,9 +2408,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.40" +version = "0.23.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b" +checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" dependencies = [ "log", "once_cell", @@ -2625,7 +2680,7 @@ version = "0.9.34+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ - "indexmap 2.14.0", + "indexmap", "itoa", "ryu", "serde", @@ -2716,16 +2771,6 @@ dependencies = [ "serde", ] -[[package]] -name = "socket2" -version = "0.5.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" -dependencies = [ - "libc", - "windows-sys 0.52.0", -] - [[package]] name = "socket2" version = "0.6.3" @@ -2761,7 +2806,7 @@ version = "0.3.0" dependencies = [ "anyhow", "argon2", - "axum", + "axum 0.7.9", "axum-extra", "base64", "bytes", @@ -2808,12 +2853,13 @@ dependencies = [ [[package]] name = "spur-proto" version = "0.3.0" -source = "git+https://github.com/ROCm/spur.git?tag=v0.3.0#41de8e2fe09b9930bd570dc981b3eec3de23e768" +source = "git+https://github.com/ROCm/spur.git?rev=5781872920935f38b9b9c9734bbaefb29d9a6fb3#5781872920935f38b9b9c9734bbaefb29d9a6fb3" dependencies = [ "prost", "prost-types", "tonic", - "tonic-build", + "tonic-prost", + "tonic-prost-build", ] [[package]] @@ -2848,7 +2894,7 @@ dependencies = [ "futures-util", "hashbrown 0.15.5", "hashlink", - "indexmap 2.14.0", + "indexmap", "log", "memchr", "once_cell", @@ -3204,7 +3250,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2 0.6.3", + "socket2", "tokio-macros", "windows-sys 0.61.2", ] @@ -3294,7 +3340,7 @@ version = "0.22.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" dependencies = [ - "indexmap 2.14.0", + "indexmap", "serde", "serde_spanned", "toml_datetime", @@ -3310,13 +3356,12 @@ checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" [[package]] name = "tonic" -version = "0.12.3" +version = "0.14.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52" +checksum = "ac2a5518c70fa84342385732db33fb3f44bc4cc748936eb5833d2df34d6445ef" dependencies = [ - "async-stream", "async-trait", - "axum", + "axum 0.8.9", "base64", "bytes", "h2", @@ -3328,11 +3373,11 @@ dependencies = [ "hyper-util", "percent-encoding", "pin-project", - "prost", - "socket2 0.5.10", + "socket2", + "sync_wrapper", "tokio", "tokio-stream", - "tower 0.4.13", + "tower 0.5.3", "tower-layer", "tower-service", "tracing", @@ -3340,9 +3385,32 @@ dependencies = [ [[package]] name = "tonic-build" -version = "0.12.3" +version = "0.14.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9557ce109ea773b399c9b9e5dca39294110b74f1f342cb347a80d1fce8c26a11" +checksum = "c68f61875ac5293cf72e6c8cf0158086428c82c37229e98c840878f1706b0322" +dependencies = [ + "prettyplease", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tonic-prost" +version = "0.14.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50849f68853be452acf590cde0b146665b8d507b3b8af17261df47e02c209ea0" +dependencies = [ + "bytes", + "prost", + "tonic", +] + +[[package]] +name = "tonic-prost-build" +version = "0.14.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "654e5643eff75d7f8c99197ce1440ed19a3474eada74c12bbac488b2cafdae27" dependencies = [ "prettyplease", "proc-macro2", @@ -3350,6 +3418,8 @@ dependencies = [ "prost-types", "quote", "syn", + "tempfile", + "tonic-build", ] [[package]] @@ -3360,13 +3430,8 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" dependencies = [ "futures-core", "futures-util", - "indexmap 1.9.3", "pin-project", "pin-project-lite", - "rand 0.8.6", - "slab", - "tokio", - "tokio-util", "tower-layer", "tower-service", "tracing", @@ -3380,7 +3445,9 @@ checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" dependencies = [ "futures-core", "futures-util", + "indexmap", "pin-project-lite", + "slab", "sync_wrapper", "tokio", "tokio-util", @@ -3550,6 +3617,12 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" +[[package]] +name = "unicase" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" + [[package]] name = "unicode-bidi" version = "0.3.18" @@ -3766,7 +3839,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" dependencies = [ "anyhow", - "indexmap 2.14.0", + "indexmap", "wasm-encoder", "wasmparser", ] @@ -3779,7 +3852,7 @@ checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" dependencies = [ "bitflags", "hashbrown 0.15.5", - "indexmap 2.14.0", + "indexmap", "semver", ] @@ -3908,15 +3981,6 @@ dependencies = [ "windows-targets 0.52.6", ] -[[package]] -name = "windows-sys" -version = "0.60.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" -dependencies = [ - "windows-targets 0.53.5", -] - [[package]] name = "windows-sys" version = "0.61.2" @@ -3950,30 +4014,13 @@ dependencies = [ "windows_aarch64_gnullvm 0.52.6", "windows_aarch64_msvc 0.52.6", "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm 0.52.6", + "windows_i686_gnullvm", "windows_i686_msvc 0.52.6", "windows_x86_64_gnu 0.52.6", "windows_x86_64_gnullvm 0.52.6", "windows_x86_64_msvc 0.52.6", ] -[[package]] -name = "windows-targets" -version = "0.53.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" -dependencies = [ - "windows-link", - "windows_aarch64_gnullvm 0.53.1", - "windows_aarch64_msvc 0.53.1", - "windows_i686_gnu 0.53.1", - "windows_i686_gnullvm 0.53.1", - "windows_i686_msvc 0.53.1", - "windows_x86_64_gnu 0.53.1", - "windows_x86_64_gnullvm 0.53.1", - "windows_x86_64_msvc 0.53.1", -] - [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -3986,12 +4033,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" - [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -4004,12 +4045,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" -[[package]] -name = "windows_aarch64_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" - [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -4022,24 +4057,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" -[[package]] -name = "windows_i686_gnu" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" - [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" -[[package]] -name = "windows_i686_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" - [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -4052,12 +4075,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" -[[package]] -name = "windows_i686_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" - [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -4070,12 +4087,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" -[[package]] -name = "windows_x86_64_gnu" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" - [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -4088,12 +4099,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" - [[package]] name = "windows_x86_64_msvc" version = "0.48.5" @@ -4106,12 +4111,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" -[[package]] -name = "windows_x86_64_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" - [[package]] name = "winnow" version = "0.7.15" @@ -4149,7 +4148,7 @@ checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" dependencies = [ "anyhow", "heck", - "indexmap 2.14.0", + "indexmap", "prettyplease", "syn", "wasm-metadata", @@ -4180,7 +4179,7 @@ checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" dependencies = [ "anyhow", "bitflags", - "indexmap 2.14.0", + "indexmap", "log", "serde", "serde_derive", @@ -4199,7 +4198,7 @@ checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" dependencies = [ "anyhow", "id-arena", - "indexmap 2.14.0", + "indexmap", "log", "semver", "serde", diff --git a/Cargo.toml b/Cargo.toml index 3c5fc9c..1e26306 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,14 +18,14 @@ tokio = { version = "1", features = ["full"] } tokio-stream = "0.1" # gRPC (client to spurctld) -tonic = "0.12" -prost = "0.13" -prost-types = "0.13" +tonic = "0.14" +prost = "0.14" +prost-types = "0.14" # REST API axum = { version = "0.7", features = ["json", "ws"] } axum-extra = { version = "0.9", features = ["typed-header", "cookie"] } -tower = "0.4" +tower = { version = "0.4", features = ["util"] } tower-http = { version = "0.5", features = ["cors", "trace"] } # CLI @@ -72,8 +72,8 @@ rand = "0.8.6" sha2 = "0.10" semver = "1" -# Spur proto — pinned to ROCm/spur release tag (see Cargo.lock for exact commit) -spur-proto = { git = "https://github.com/ROCm/spur.git", tag = "v0.3.0" } +# Spur proto — pinned to ROCm/spur commit (includes CDI device API from PR #262) +spur-proto = { git = "https://github.com/ROCm/spur.git", rev = "5781872920935f38b9b9c9734bbaefb29d9a6fb3" } # Internal spur-cloud-common = { path = "crates/spur-cloud-common" } diff --git a/crates/spur-cloud-api/src/spur_client.rs b/crates/spur-cloud-api/src/spur_client.rs index 083c27f..8d0789e 100644 --- a/crates/spur-cloud-api/src/spur_client.rs +++ b/crates/spur-cloud-api/src/spur_client.rs @@ -403,13 +403,16 @@ fn is_node_schedulable(state: NodeState) -> bool { pub async fn get_gpu_capacity( client: &mut SlurmControllerClient, ) -> anyhow::Result> { - use spur_cloud_common::gpu_types::{GpuNodeInfo, GpuPool}; let resp = client.get_nodes(GetNodesRequest::default()).await?; - let nodes = resp.into_inner().nodes; + Ok(compute_gpu_pools(&nodes)) +} + +fn compute_gpu_pools(nodes: &[NodeInfo]) -> Vec { + use spur_cloud_common::gpu_types::{GpuNodeInfo, GpuPool}; let mut pools: HashMap = HashMap::new(); - for node in &nodes { + for node in nodes { let total_res = node.total_resources.as_ref(); let alloc_res = node.alloc_resources.as_ref(); let node_state = node.state(); @@ -434,18 +437,28 @@ pub async fn get_gpu_capacity( } } - if let Some(alloc) = alloc_res { - for gpu in &alloc.gpus { - if let Some(pool) = pools.get_mut(&gpu.gpu_type) { - if schedulable { - pool.allocated += 1; + if let (Some(alloc), Some(total)) = (alloc_res, total_res) { + if let Some(devs) = alloc.devices.get("gpu") { + for dev in &devs.devices { + if let Some(gpu) = total.gpus.iter().find(|g| g.device_id == dev.device_id) { + if let Some(pool) = pools.get_mut(&gpu.gpu_type) { + if schedulable { + pool.allocated += dev.count as u32; + } + } else { + warn!( + node = %node.name, + gpu_type = %gpu.gpu_type, + "allocated GPU type not found in node's total resources - data inconsistency" + ); + } + } else { + warn!( + node = %node.name, + device_id = dev.device_id, + "allocated device_id not found in node's total GPU resources" + ); } - } else { - warn!( - node = %node.name, - gpu_type = %gpu.gpu_type, - "allocated GPU type not found in node's total resources - data inconsistency" - ); } } } @@ -459,8 +472,20 @@ pub async fn get_gpu_capacity( let mut alloc_counts: HashMap = HashMap::new(); if let Some(alloc) = alloc_res { - for gpu in &alloc.gpus { - *alloc_counts.entry(gpu.gpu_type.clone()).or_insert(0) += 1; + if let Some(devs) = alloc.devices.get("gpu") { + for dev in &devs.devices { + if let Some(gpu) = total.gpus.iter().find(|g| g.device_id == dev.device_id) + { + *alloc_counts.entry(gpu.gpu_type.clone()).or_insert(0) += + dev.count as u32; + } else { + warn!( + node = %node.name, + device_id = dev.device_id, + "per-node alloc: device_id not found in total GPU resources" + ); + } + } } } @@ -495,12 +520,202 @@ pub async fn get_gpu_capacity( pool.available = pool.total.saturating_sub(pool.allocated); } - Ok(pools.into_values().collect()) + pools.into_values().collect() } #[cfg(test)] mod tests { use super::*; + use spur_cloud_common::gpu_types::GpuPool; + + fn make_gpu(device_id: u32, gpu_type: &str, memory_mb: u64) -> GpuResource { + GpuResource { + device_id, + gpu_type: gpu_type.into(), + memory_mb, + peer_gpus: vec![], + link_type: 0, + } + } + + fn make_node_info( + name: &str, + state: i32, + gpus: Vec, + alloc_devices: Vec, + ) -> NodeInfo { + let alloc = if alloc_devices.is_empty() { + None + } else { + let mut devices = HashMap::new(); + devices.insert( + "gpu".to_string(), + DeviceAllocations { + devices: alloc_devices, + }, + ); + Some(ResourceAllocations { + cpus: 0, + memory_mb: 0, + devices, + }) + }; + NodeInfo { + name: name.into(), + state, + total_resources: Some(ResourceSet { + cpus: 32, + memory_mb: 128000, + gpus, + generic: Default::default(), + }), + alloc_resources: alloc, + ..Default::default() + } + } + + fn find_pool<'a>(pools: &'a [GpuPool], gpu_type: &str) -> &'a GpuPool { + pools.iter().find(|p| p.gpu_type == gpu_type).unwrap() + } + + #[test] + fn gpu_capacity_basic_idle_nodes() { + let nodes = vec![ + make_node_info( + "node1", + NodeState::NodeIdle as i32, + vec![make_gpu(0, "mi300x", 196608), make_gpu(1, "mi300x", 196608)], + vec![], + ), + make_node_info( + "node2", + NodeState::NodeIdle as i32, + vec![make_gpu(0, "mi300x", 196608), make_gpu(1, "mi300x", 196608)], + vec![], + ), + ]; + let pools = compute_gpu_pools(&nodes); + assert_eq!(pools.len(), 1); + let pool = find_pool(&pools, "mi300x"); + assert_eq!(pool.total, 4); + assert_eq!(pool.allocated, 0); + assert_eq!(pool.available, 4); + assert_eq!(pool.memory_mb, 196608); + assert_eq!(pool.nodes.len(), 2); + } + + #[test] + fn gpu_capacity_with_allocations() { + let nodes = vec![make_node_info( + "node1", + NodeState::NodeMixed as i32, + vec![ + make_gpu(0, "mi300x", 196608), + make_gpu(1, "mi300x", 196608), + make_gpu(2, "mi300x", 196608), + make_gpu(3, "mi300x", 196608), + ], + vec![ + AllocatedDevice { + device_id: 0, + count: 1, + }, + AllocatedDevice { + device_id: 2, + count: 1, + }, + ], + )]; + let pools = compute_gpu_pools(&nodes); + let pool = find_pool(&pools, "mi300x"); + assert_eq!(pool.total, 4); + assert_eq!(pool.allocated, 2); + assert_eq!(pool.available, 2); + assert_eq!(pool.nodes[0].available_gpus, 2); + } + + #[test] + fn gpu_capacity_non_schedulable_excluded() { + let nodes = vec![make_node_info( + "down-node", + NodeState::NodeDown as i32, + vec![make_gpu(0, "mi300x", 196608), make_gpu(1, "mi300x", 196608)], + vec![AllocatedDevice { + device_id: 0, + count: 1, + }], + )]; + let pools = compute_gpu_pools(&nodes); + let pool = find_pool(&pools, "mi300x"); + // Down nodes don't count toward total or allocated + assert_eq!(pool.total, 0); + assert_eq!(pool.allocated, 0); + assert_eq!(pool.available, 0); + // But still appears in node list with 0 available + assert_eq!(pool.nodes[0].available_gpus, 0); + } + + #[test] + fn gpu_capacity_multiple_types() { + let nodes = vec![make_node_info( + "mixed-node", + NodeState::NodeIdle as i32, + vec![ + make_gpu(0, "mi300x", 196608), + make_gpu(1, "mi300x", 196608), + make_gpu(2, "h100", 81920), + make_gpu(3, "h100", 81920), + ], + vec![AllocatedDevice { + device_id: 2, + count: 1, + }], + )]; + let pools = compute_gpu_pools(&nodes); + assert_eq!(pools.len(), 2); + let mi = find_pool(&pools, "mi300x"); + assert_eq!(mi.total, 2); + assert_eq!(mi.allocated, 0); + let h1 = find_pool(&pools, "h100"); + assert_eq!(h1.total, 2); + assert_eq!(h1.allocated, 1); + assert_eq!(h1.available, 1); + } + + #[test] + fn gpu_capacity_count_greater_than_one() { + let nodes = vec![make_node_info( + "node1", + NodeState::NodeAllocated as i32, + vec![make_gpu(0, "mi300x", 196608), make_gpu(1, "mi300x", 196608)], + vec![AllocatedDevice { + device_id: 0, + count: 2, + }], + )]; + let pools = compute_gpu_pools(&nodes); + let pool = find_pool(&pools, "mi300x"); + assert_eq!(pool.allocated, 2); + } + + #[test] + fn gpu_capacity_unresolved_device_id_skipped() { + // device_id 99 doesn't exist in total — should be silently skipped (with warning) + let nodes = vec![make_node_info( + "node1", + NodeState::NodeIdle as i32, + vec![make_gpu(0, "mi300x", 196608)], + vec![AllocatedDevice { + device_id: 99, + count: 1, + }], + )]; + let pools = compute_gpu_pools(&nodes); + let pool = find_pool(&pools, "mi300x"); + assert_eq!(pool.total, 1); + assert_eq!(pool.allocated, 0); + assert_eq!(pool.available, 1); + } #[test] fn schedulable_states() { From 9b3c94b2f57944aa78842ebfd0942f511cd11e2b Mon Sep 17 00:00:00 2001 From: Shiv Tyagi Date: Fri, 12 Jun 2026 05:54:47 +0000 Subject: [PATCH 2/2] fix(spur-cloud-api): address review feedback on GPU pool accounting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Eliminate double warning on unresolved device_id by resolving allocations through a per-node id_to_type lookup in a single pass. Warn when a node has alloc_resources but no total_resources instead of silently dropping. Use saturating arithmetic for u64→u32 count conversion and pool accumulation. Add tests for alloc-without-total and multi-node tally consistency. --- crates/spur-cloud-api/src/spur_client.rs | 238 ++++++++++++++++------- 1 file changed, 163 insertions(+), 75 deletions(-) diff --git a/crates/spur-cloud-api/src/spur_client.rs b/crates/spur-cloud-api/src/spur_client.rs index 8d0789e..095d35b 100644 --- a/crates/spur-cloud-api/src/spur_client.rs +++ b/crates/spur-cloud-api/src/spur_client.rs @@ -418,40 +418,53 @@ fn compute_gpu_pools(nodes: &[NodeInfo]) -> Vec = total + .gpus + .iter() + .map(|g| (g.device_id, g.gpu_type.as_str())) + .collect(); + + for gpu in &total.gpus { + let pool = pools + .entry(gpu.gpu_type.clone()) + .or_insert_with(|| GpuPool { + gpu_type: gpu.gpu_type.clone(), + total: 0, + available: 0, + allocated: 0, + memory_mb: gpu.memory_mb, + nodes: Vec::new(), + }); + if schedulable { + pool.total += 1; } } - if let (Some(alloc), Some(total)) = (alloc_res, total_res) { + // Resolve allocations using the shared lookup. + let mut alloc_counts: HashMap = HashMap::new(); + if let Some(alloc) = alloc_res { if let Some(devs) = alloc.devices.get("gpu") { for dev in &devs.devices { - if let Some(gpu) = total.gpus.iter().find(|g| g.device_id == dev.device_id) { - if let Some(pool) = pools.get_mut(&gpu.gpu_type) { + let count = u32::try_from(dev.count).unwrap_or(u32::MAX); + if let Some(&gpu_type) = id_to_type.get(&dev.device_id) { + if let Some(pool) = pools.get_mut(gpu_type) { if schedulable { - pool.allocated += dev.count as u32; + pool.allocated = pool.allocated.saturating_add(count); } - } else { - warn!( - node = %node.name, - gpu_type = %gpu.gpu_type, - "allocated GPU type not found in node's total resources - data inconsistency" - ); } + let entry = alloc_counts.entry(gpu_type.to_string()).or_insert(0); + *entry = entry.saturating_add(count); } else { warn!( node = %node.name, @@ -463,59 +476,35 @@ fn compute_gpu_pools(nodes: &[NodeInfo]) -> Vec = HashMap::new(); - for gpu in &total.gpus { - *gpu_counts.entry(gpu.gpu_type.clone()).or_insert(0) += 1; - } - - let mut alloc_counts: HashMap = HashMap::new(); - if let Some(alloc) = alloc_res { - if let Some(devs) = alloc.devices.get("gpu") { - for dev in &devs.devices { - if let Some(gpu) = total.gpus.iter().find(|g| g.device_id == dev.device_id) - { - *alloc_counts.entry(gpu.gpu_type.clone()).or_insert(0) += - dev.count as u32; - } else { - warn!( - node = %node.name, - device_id = dev.device_id, - "per-node alloc: device_id not found in total GPU resources" - ); - } - } - } - } + // Build per-node info — always include for visibility, even non-schedulable nodes. + let mut gpu_counts: HashMap = HashMap::new(); + for gpu in &total.gpus { + *gpu_counts.entry(gpu.gpu_type.clone()).or_insert(0) += 1; + } - for (gpu_type, total_count) in gpu_counts { - // Ensure the pool exists even for non-schedulable nodes (for node info) - let pool = pools.entry(gpu_type.clone()).or_insert_with(|| GpuPool { - gpu_type: gpu_type.clone(), - total: 0, - available: 0, - allocated: 0, - memory_mb: 0, - nodes: Vec::new(), - }); - let alloc_count = alloc_counts.get(&gpu_type).copied().unwrap_or(0); - pool.nodes.push(GpuNodeInfo { - name: node.name.clone(), - total_gpus: total_count, - // Non-schedulable nodes show 0 available - available_gpus: if schedulable { - total_count.saturating_sub(alloc_count) - } else { - 0 - }, - state: format!("{:?}", node_state), - }); - } + for (gpu_type, total_count) in gpu_counts { + let pool = pools.entry(gpu_type.clone()).or_insert_with(|| GpuPool { + gpu_type: gpu_type.clone(), + total: 0, + available: 0, + allocated: 0, + memory_mb: 0, + nodes: Vec::new(), + }); + let alloc_count = alloc_counts.get(&gpu_type).copied().unwrap_or(0); + pool.nodes.push(GpuNodeInfo { + name: node.name.clone(), + total_gpus: total_count, + available_gpus: if schedulable { + total_count.saturating_sub(alloc_count) + } else { + 0 + }, + state: format!("{:?}", node_state), + }); } } - // Compute available = total - allocated (both already filtered to schedulable nodes) for pool in pools.values_mut() { pool.available = pool.total.saturating_sub(pool.allocated); } @@ -682,6 +671,9 @@ mod tests { assert_eq!(h1.available, 1); } + // count > 1 per device_id doesn't happen today (GPUs use injectable alloc with + // count=1), but spur-cloud shouldn't assume that — future time-slicing or MIG + // support could advertise multiple units per physical device. #[test] fn gpu_capacity_count_greater_than_one() { let nodes = vec![make_node_info( @@ -717,6 +709,102 @@ mod tests { assert_eq!(pool.available, 1); } + #[test] + fn gpu_capacity_alloc_without_total_warns_and_skips() { + let mut devices = HashMap::new(); + devices.insert( + "gpu".to_string(), + DeviceAllocations { + devices: vec![AllocatedDevice { + device_id: 0, + count: 1, + }], + }, + ); + let nodes = vec![NodeInfo { + name: "orphan-node".into(), + state: NodeState::NodeMixed as i32, + total_resources: None, + alloc_resources: Some(ResourceAllocations { + cpus: 4, + memory_mb: 8192, + devices, + }), + ..Default::default() + }]; + let pools = compute_gpu_pools(&nodes); + assert!(pools.is_empty()); + } + + #[test] + fn gpu_capacity_multi_node_tallies_consistent() { + let nodes = vec![ + make_node_info( + "node1", + NodeState::NodeMixed as i32, + vec![ + make_gpu(0, "mi300x", 196608), + make_gpu(1, "mi300x", 196608), + make_gpu(2, "mi300x", 196608), + make_gpu(3, "mi300x", 196608), + ], + vec![ + AllocatedDevice { + device_id: 0, + count: 1, + }, + AllocatedDevice { + device_id: 1, + count: 1, + }, + ], + ), + make_node_info( + "node2", + NodeState::NodeAllocated as i32, + vec![make_gpu(0, "mi300x", 196608), make_gpu(1, "mi300x", 196608)], + vec![ + AllocatedDevice { + device_id: 0, + count: 1, + }, + AllocatedDevice { + device_id: 1, + count: 1, + }, + ], + ), + make_node_info( + "node3", + NodeState::NodeIdle as i32, + vec![make_gpu(0, "mi300x", 196608), make_gpu(1, "mi300x", 196608)], + vec![], + ), + ]; + let pools = compute_gpu_pools(&nodes); + assert_eq!(pools.len(), 1); + let pool = find_pool(&pools, "mi300x"); + assert_eq!(pool.total, 8); + assert_eq!(pool.allocated, 4); + assert_eq!(pool.available, 4); + assert_eq!(pool.nodes.len(), 3); + + let per_node: HashMap<&str, &spur_cloud_common::gpu_types::GpuNodeInfo> = + pool.nodes.iter().map(|n| (n.name.as_str(), n)).collect(); + assert_eq!(per_node["node1"].total_gpus, 4); + assert_eq!(per_node["node1"].available_gpus, 2); + assert_eq!(per_node["node2"].total_gpus, 2); + assert_eq!(per_node["node2"].available_gpus, 0); + assert_eq!(per_node["node3"].total_gpus, 2); + assert_eq!(per_node["node3"].available_gpus, 2); + + // Global tallies must equal sum of per-node tallies + let sum_total: u32 = pool.nodes.iter().map(|n| n.total_gpus).sum(); + let sum_available: u32 = pool.nodes.iter().map(|n| n.available_gpus).sum(); + assert_eq!(pool.total, sum_total); + assert_eq!(pool.available, sum_available); + } + #[test] fn schedulable_states() { assert!(is_node_schedulable(NodeState::NodeIdle));