diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index d3f51dfbe2..8ecb9d026b 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -340,6 +340,20 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(true) + val COMET_SHUFFLE_DIRECT_NATIVE_ENABLED: ConfigEntry[Boolean] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.directNative.enabled") + .category(CATEGORY_SHUFFLE) + .doc( + "When enabled, the native shuffle writer will directly execute the child native plan " + + "instead of reading intermediate batches via JNI. This optimization avoids the " + + "JNI round-trip for native plans whose inputs are all native scans " + + "(CometNativeScanExec, CometIcebergNativeScanExec). Supports single and multi-source " + + "plans (e.g., joins over native scans). " + + "This is an experimental feature and is disabled by default.") + .internal() + .booleanConf + .createWithDefault(false) + val COMET_SHUFFLE_DIRECT_READ_ENABLED: ConfigEntry[Boolean] = conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.directRead.enabled") .category(CATEGORY_SHUFFLE) diff --git a/native/Cargo.lock b/native/Cargo.lock index ae2d6b074c..5b69db69b0 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -230,9 +230,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d441fdda254b65f3e9025910eb2c2066b6295d9c8ed409522b8d2ace1ff8574c" +checksum = "607e64bb911ee4f90483e044fe78f175989148c2892e659a2cd25429e782ec54" dependencies = [ "arrow-arith", "arrow-array", @@ -251,9 +251,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ced5406f8b720cc0bc3aa9cf5758f93e8593cda5490677aa194e4b4b383f9a59" +checksum = "e754319ed8a85d817fe7adf183227e0b5308b82790a737b426c1124626b48118" dependencies = [ "arrow-array", "arrow-buffer", @@ -265,9 +265,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "772bd34cacdda8baec9418d80d23d0fb4d50ef0735685bd45158b83dfeb6e62d" +checksum = "841321891f247aa86c6112c80d83d89cb36e0addd020fa2425085b8eb6c3f579" dependencies = [ "ahash", "arrow-buffer", @@ -276,7 +276,7 @@ dependencies = [ "chrono", "chrono-tz", "half", - "hashbrown 0.16.1", + "hashbrown 0.17.0", "num-complex", "num-integer", "num-traits", @@ -284,9 +284,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "898f4cf1e9598fdb77f356fdf2134feedfd0ee8d5a4e0a5f573e7d0aec16baa4" +checksum = "f955dfb73fae000425f49c8226d2044dab60fb7ad4af1e24f961756354d996c9" dependencies = [ "bytes", "half", @@ -296,9 +296,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0127816c96533d20fc938729f48c52d3e48f99717e7a0b5ade77d742510736d" +checksum = "ca5e686972523798f76bef355145bc1ae25a84c731e650268d31ab763c701663" dependencies = [ "arrow-array", "arrow-buffer", @@ -318,9 +318,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "58.1.0" +version = "58.2.0" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca025bd0f38eeecb57c2153c0123b960494138e6a957bbda10da2b25415209fe" +checksum = "86c276756867fc8186ec380c72c290e6e3b23a1d4fb05df6b1d62d2e62666d48" dependencies = [ "arrow-array", "arrow-cast", @@ -333,9 +333,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42d10beeab2b1c3bb0b53a00f7c944a178b622173a5c7bcabc3cb45d90238df4" +checksum = "db3b5846209775b6dc8056d77ff9a032b27043383dd5488abd0b663e265b9373" dependencies = [ "arrow-buffer", "arrow-schema", @@ -346,9 +346,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "609a441080e338147a84e8e6904b6da482cefb957c5cdc0f3398872f69a315d0" +checksum = "fd8907ddd8f9fbabf91ec2c85c1d81fe2874e336d2443eb36373595e28b98dd5" dependencies = [ "arrow-array", "arrow-buffer", @@ -361,15 +361,16 @@ dependencies = [ [[package]] name = "arrow-json" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ead0914e4861a531be48fe05858265cf854a4880b9ed12618b1d08cba9bebc8" +checksum = "f4518c59acc501f10d7dcae397fe12b8db3d81bc7de94456f8a58f9165d6f502" dependencies = [ "arrow-array", "arrow-buffer", "arrow-cast", - "arrow-data", + "arrow-ord", "arrow-schema", + "arrow-select", "chrono", "half", "indexmap 2.14.0", @@ -385,9 +386,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "763a7ba279b20b52dad300e68cfc37c17efa65e68623169076855b3a9e941ca5" +checksum = "efa70d9d6b1356f1fb9f1f651b84a725b7e0abb93f188cf7d31f14abfa2f2e6f" dependencies = [ "arrow-array", "arrow-buffer", @@ -398,9 +399,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e14fe367802f16d7668163ff647830258e6e0aeea9a4d79aaedf273af3bdcd3e" +checksum = "faec88a945338192beffbbd4be0def70135422930caa244ac3cec0cd213b26b4" dependencies = [ "arrow-array", "arrow-buffer", @@ -411,9 +412,9 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c30a1365d7a7dc50cc847e54154e6af49e4c4b0fddc9f607b687f29212082743" +checksum = "18aa020f6bc8e5201dcd2d4b7f98c68f8a410ef37128263243e6ff2a47a67d4f" dependencies = [ "bitflags 2.11.1", "serde_core", @@ -422,9 +423,9 @@ dependencies = [ [[package]] name = "arrow-select" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78694888660a9e8ac949853db393af2a8b8fc82c19ce333132dfa2e72cc1a7fe" +checksum = "a657ab5132e9c8ca3b24eb15a823d0ced38017fe3930ff50167466b02e2d592c" dependencies = [ "ahash", "arrow-array", @@ -436,9 +437,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e04a01f8bb73ce54437514c5fd3ee2aa3e8abe4c777ee5cc55853b1652f79e" +checksum = "f6de2efbbd1a9f9780ceb8d1ff5d20421b35863b361e3386b4f571f1fc69fcb8" dependencies = [ "arrow-array", "arrow-buffer", @@ -492,9 +493,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.41" +version = "0.4.42" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0f9ee0f6e02ffd7ad5816e9464499fba7b3effd01123b515c41d1697c43dad1" +checksum = "e79b3f8a79cccc2898f31920fc69f304859b3bd567490f75ebf51ae1c792a9ac" dependencies = [ "compression-codecs", "compression-core", @@ -1125,9 +1126,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.8.4" +version = "1.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d2d5991425dfd0785aed03aedcf0b321d61975c9b5b3689c774a2610ae0b51e" +checksum = "0aa83c34e62843d924f905e0f5c866eb1dd6545fc4d719e803d9ba6030371fce" dependencies = [ "arrayref", "arrayvec", @@ -1496,9 +1497,9 @@ dependencies = [ [[package]] name = "compression-codecs" -version = "0.4.37" +version = "0.4.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb7b51a7d9c967fc26773061ba86150f19c50c0d65c887cb1fbe295fd16619b7" +checksum = "ce2548391e9c1929c21bf6aa2680af86fe4c1b33e6cea9ac1cfeec0bd11218cf" dependencies = [ "bzip2", "compression-core", @@ -1511,9 +1512,9 @@ dependencies = [ [[package]] name = "compression-core" -version = "0.4.31" +version = "0.4.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75984efb6ed102a0d42db99afb6c1948f0380d1d91808d5529916e6c08b49d8d" +checksum = "cc14f565cf027a105f7a44ccf9e5b424348421a1d8952a8fc9d499d313107789" [[package]] name = "concurrent-queue" @@ -1580,9 +1581,9 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpp_demangle" -version = "0.5.1" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0667304c32ea56cb4cd6d2d7c0cfe9a2f8041229db8c033af7f8d69492429def" +checksum = "f2bb79cb74d735044c972aae58ed0aaa9a837e85b01106a54c39e42e97f62253" dependencies = [ "cfg-if", ] @@ -2813,9 +2814,9 @@ dependencies = [ [[package]] name = "digest" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4850db49bf08e663084f7fb5c87d202ef91a3907271aff24a94eb97ff039153c" +checksum = "f1dd6dbb5841937940781866fa1281a1ff7bd3bf827091440879f9994983d5c2" dependencies = [ "block-buffer 0.12.0", "const-oid 0.10.2", @@ -3264,9 +3265,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.13" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" +checksum = "171fefbc92fe4a4de27e0698d6a5b392d6a0e333506bc49133760b3bcf948733" dependencies = [ "atomic-waker", "bytes", @@ -3388,7 +3389,7 @@ version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6303bc9732ae41b04cb554b844a762b4115a61bfaa81e3e83050991eeb56863f" dependencies = [ - "digest 0.11.2", + "digest 0.11.3", ] [[package]] @@ -3469,9 +3470,9 @@ checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" [[package]] name = "hybrid-array" -version = "0.4.10" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3944cf8cf766b40e2a1a333ee5e9b563f854d5fa49d6a8ca2764e97c6eddb214" +checksum = "08d46837a0ed51fe95bd3b05de33cd64a1ee88fc797477ca48446872504507c5" dependencies = [ "typenum", ] @@ -3742,9 +3743,9 @@ dependencies = [ [[package]] name = "idna_adapter" -version = "1.2.1" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +checksum = 
"cb68373c0d6620ef8105e855e7745e18b0d00d3bdb07fb532e434244cdb9a714" dependencies = [ "icu_normalizer", "icu_properties", @@ -3822,16 +3823,6 @@ version = "2.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" -[[package]] -name = "iri-string" -version = "0.7.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25e659a4bb38e810ebc252e53b5814ff908a8c58c2a9ce2fae1bbec24cbf4e20" -dependencies = [ - "memchr", - "serde", -] - [[package]] name = "is-terminal" version = "0.4.17" @@ -3884,9 +3875,9 @@ dependencies = [ [[package]] name = "jiff" -version = "0.2.23" +version = "0.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a3546dc96b6d42c5f24902af9e2538e82e39ad350b0c766eb3fbf2d8f3d8359" +checksum = "f00b5dbd620d61dfdcb6007c9c1f6054ebd75319f163d886a9055cec1155073d" dependencies = [ "jiff-static", "jiff-tzdb-platform", @@ -3901,9 +3892,9 @@ dependencies = [ [[package]] name = "jiff-static" -version = "0.2.23" +version = "0.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a8c8b344124222efd714b73bb41f8b5120b27a7cc1c75593a6ff768d9d05aa4" +checksum = "e000de030ff8022ea1da3f466fbb0f3a809f5e51ed31f6dd931c35181ad8e6d7" dependencies = [ "proc-macro2", "quote", @@ -4013,9 +4004,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.95" +version = "0.3.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2964e92d1d9dc3364cae4d718d93f227e3abb088e747d92e0395bfdedf1c12ca" +checksum = "a1840c94c045fbcf8ba2812c95db44499f7c64910a912551aaaa541decebcacf" dependencies = [ "cfg-if", "futures-util", @@ -4537,7 +4528,7 @@ dependencies = [ "md-5", "parking_lot", "percent-encoding", - "quick-xml 0.39.2", + "quick-xml 0.39.3", "rand 0.10.1", "reqwest 0.12.28", "ring", @@ -4658,7 +4649,7 @@ dependencies = [ "percent-encoding", "quick-xml 0.38.4", "reqsign-core", - "reqwest 0.13.2", + "reqwest 0.13.3", "serde", "serde_json", "tokio", @@ -4796,9 +4787,9 @@ dependencies = [ [[package]] name = "parquet" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d3f9f2205199603564127932b89695f52b62322f541d0fc7179d57c2e1c9877" +checksum = "43d7efd3052f7d6ef601085559a246bc991e9a8cc77e02753737df6322ce35f1" dependencies = [ "ahash", "arrow-array", @@ -4814,7 +4805,7 @@ dependencies = [ "flate2", "futures", "half", - "hashbrown 0.16.1", + "hashbrown 0.17.0", "lz4_flex", "num-bigint", "num-integer", @@ -4836,23 +4827,25 @@ dependencies = [ [[package]] name = "parquet-variant" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bf493f3c9ddd984d0efb019f67343e4aa4bab893931f6a14b82083065dc3d28" +checksum = "262fd51760f388670dbab2283efaadd0f4ed87ad584e60bd0db7fb79d527f045" dependencies = [ + "arrow", "arrow-schema", "chrono", "half", "indexmap 2.14.0", + "num-traits", "simdutf8", "uuid", ] [[package]] name = "parquet-variant-compute" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ac038d46a503a7d563b4f5df5802c4315d5343d009feab195d15ac512b4cb27" +checksum = "4c94fc2c2c077a00b3d5232f965037cee3455c432567a78d66db101daa035689" dependencies = [ "arrow", "arrow-schema", @@ -4867,9 +4860,9 @@ dependencies = [ [[package]] name = "parquet-variant-json" -version = "58.1.0" +version = "58.2.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "015a09c2ffe5108766c7c1235c307b8a3c2ea64eca38455ba1a7f3a7f32f16e2" +checksum = "7ed1077da4aeb4e4141aa2f9858ac354975595eb30f907762894587941e8f2f7" dependencies = [ "arrow-schema", "base64", @@ -4958,18 +4951,18 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.1.11" +version = "1.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1749c7ed4bcaf4c3d0a3efc28538844fb29bcdd7d2b67b2be7e20ba861ff517" +checksum = "cbf0d9e68100b3a7989b4901972f265cd542e560a3a8a724e1e20322f4d06ce9" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.11" +version = "1.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6" +checksum = "a990e22f43e84855daf260dded30524ef4a9021cc7541c26540500a50b624389" dependencies = [ "proc-macro2", "quote", @@ -5289,9 +5282,9 @@ dependencies = [ [[package]] name = "quick-xml" -version = "0.39.2" +version = "0.39.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "958f21e8e7ceb5a1aa7fa87fab28e7c75976e0bfe7e23ff069e0a260f894067d" +checksum = "721da970c312655cde9b4ffe0547f20a8494866a4af5ff51f18b7c633d0c870b" dependencies = [ "memchr", "serde", @@ -5633,9 +5626,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.13.2" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab3f43e3283ab1488b624b44b0e988d0acea0b3214e694730a055cb6b2efa801" +checksum = "62e0021ea2c22aed41653bc7e1419abb2c97e038ff2c33d0e1309e49a97deec0" dependencies = [ "base64", "bytes", @@ -5694,9 +5687,9 @@ dependencies = [ [[package]] name = "roaring" -version = "0.11.3" +version = "0.11.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ba9ce64a8f45d7fc86358410bb1a82e8c987504c0d4900e9141d69a9f26c885" +checksum = "1dedc5658c6ecb3bdb5ef5f3295bb9253f42dcf3fd1402c03f6b1f7659c3c4a9" dependencies = [ "bytemuck", "byteorder", @@ -5788,9 +5781,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.38" +version = "0.23.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69f9466fb2c14ea04357e91413efb882e2a6d4a406e625449bc0a5d360d53a21" +checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b" dependencies = [ "aws-lc-rs", "once_cell", @@ -5815,9 +5808,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.14.0" +version = "1.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" +checksum = "30a7197ae7eb376e574fe940d068c30fe0462554a3ddbe4eca7838e049c937a9" dependencies = [ "web-time", "zeroize", @@ -5825,13 +5818,13 @@ dependencies = [ [[package]] name = "rustls-platform-verifier" -version = "0.6.2" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784" +checksum = "26d1e2536ce4f35f4846aa13bff16bd0ff40157cdb14cc056c7b14ba41233ba0" dependencies = [ "core-foundation", "core-foundation-sys", - "jni 0.21.1", + "jni 0.22.4", "log", "once_cell", "rustls", @@ -5852,9 +5845,9 @@ checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" [[package]] name = "rustls-webpki" -version = "0.103.12" +version = "0.103.13" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "8279bb85272c9f10811ae6a6c547ff594d6a7f3c6c6b02ee9726d1d0dcfcdd06" +checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e" dependencies = [ "aws-lc-rs", "ring", @@ -6075,9 +6068,9 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.18.0" +version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd5414fad8e6907dbdd5bc441a50ae8d6e26151a03b1de04d89a5576de61d01f" +checksum = "f05839ce67618e14a09b286535c0d9c94e85ef25469b0e13cb4f844e5593eb19" dependencies = [ "base64", "chrono", @@ -6094,9 +6087,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.18.0" +version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3db8978e608f1fe7357e211969fd9abdcae80bac1ba7a3369bb7eb6b404eb65" +checksum = "cf2ebbe86054f9b45bc3881e865683ccfaccce97b9b4cb53f3039d67f355a334" dependencies = [ "darling 0.23.0", "proc-macro2", @@ -6147,7 +6140,7 @@ checksum = "446ba717509524cb3f22f17ecc096f10f4822d76ab5c0b9822c5f9c284e825f4" dependencies = [ "cfg-if", "cpufeatures 0.3.0", - "digest 0.11.2", + "digest 0.11.3", ] [[package]] @@ -6212,9 +6205,9 @@ dependencies = [ [[package]] name = "siphasher" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" +checksum = "8ee5873ec9cce0195efcb7a4e9507a04cd49aec9c83d0389df45b1ef7ba2e649" [[package]] name = "slab" @@ -6337,9 +6330,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "symbolic-common" -version = "12.18.1" +version = "12.18.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1f3cdeaae6779ecba2567f20bf7716718b8c4ce6717c9def4ced18786bb11ea" +checksum = "332615d90111d8eeaf86a84dc9bbe9f65d0d8c5cf11b4caccedc37754eb0dcfd" dependencies = [ "debugid", "memmap2", @@ -6349,9 +6342,9 @@ dependencies = [ [[package]] name = "symbolic-demangle" -version = "12.18.1" +version = "12.18.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "672c6ad9cb8fce6a1283cc9df9070073cccad00ae241b80e3686328a64e3523b" +checksum = "912017718eb4d21930546245af9a3475c9dccf15675a5c215664e76621afc471" dependencies = [ "cpp_demangle", "rustc-demangle", @@ -6588,9 +6581,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.52.1" +version = "1.52.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b67dee974fe86fd92cc45b7a95fdd2f99a36a6d7b0d431a231178d3d670bbcc6" +checksum = "110a78583f19d5cdb2c5ccf321d1290344e71313c6c37d43520d386027d18386" dependencies = [ "bytes", "libc", @@ -6666,20 +6659,20 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.8" +version = "0.6.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" +checksum = "68d6fdd9f81c2819c9a8b0e0cd91660e7746a8e6ea2ba7c6b2b057985f6bcb51" dependencies = [ "bitflags 2.11.1", "bytes", "futures-util", "http 1.4.0", "http-body 1.0.1", - "iri-string", "pin-project-lite", "tower", "tower-layer", "tower-service", + "url", ] [[package]] @@ -6777,9 +6770,9 @@ dependencies = [ [[package]] name = "typenum" -version = "1.19.0" +version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" +checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" [[package]] name = "typetag" @@ -6947,11 +6940,11 @@ checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasip2" -version = "1.0.2+wasi-0.2.9" +version = "1.0.3+wasi-0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +checksum = "20064672db26d7cdc89c7798c48a0fdfac8213434a1186e5ef29fd560ae223d6" dependencies = [ - "wit-bindgen", + "wit-bindgen 0.57.1", ] [[package]] @@ -6960,14 +6953,14 @@ version = "0.4.0+wasi-0.3.0-rc-2026-01-06" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" dependencies = [ - "wit-bindgen", + "wit-bindgen 0.51.0", ] [[package]] name = "wasm-bindgen" -version = "0.2.118" +version = "0.2.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf938a0bacb0469e83c1e148908bd7d5a6010354cf4fb73279b7447422e3a89" +checksum = "df52b6d9b87e0c74c9edfa1eb2d9bf85e5d63515474513aa50fa181b3c4f5db1" dependencies = [ "cfg-if", "once_cell", @@ -6978,9 +6971,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.68" +version = "0.4.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f371d383f2fb139252e0bfac3b81b265689bf45b6874af544ffa4c975ac1ebf8" +checksum = "af934872acec734c2d80e6617bbb5ff4f12b052dd8e6332b0817bce889516084" dependencies = [ "js-sys", "wasm-bindgen", @@ -6988,9 +6981,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.118" +version = "0.2.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eeff24f84126c0ec2db7a449f0c2ec963c6a49efe0698c4242929da037ca28ed" +checksum = "78b1041f495fb322e64aca85f5756b2172e35cd459376e67f2a6c9dffcedb103" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -6998,9 +6991,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.118" +version = "0.2.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d08065faf983b2b80a79fd87d8254c409281cf7de75fc4b773019824196c904" +checksum = "9dcd0ff20416988a18ac686d4d4d0f6aae9ebf08a389ff5d29012b05af2a1b41" dependencies = [ "bumpalo", "proc-macro2", @@ -7011,9 +7004,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.118" +version = "0.2.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fd04d9e306f1907bd13c6361b5c6bfc7b3b3c095ed3f8a9246390f8dbdee129" +checksum = "49757b3c82ebf16c57d69365a142940b384176c24df52a087fb748e2085359ea" dependencies = [ "unicode-ident", ] @@ -7080,9 +7073,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.95" +version = "0.3.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f2dfbb17949fa2088e5d39408c48368947b86f7834484e87b73de55bc14d97d" +checksum = "2eadbac71025cd7b0834f20d1fe8472e8495821b4e9801eb0a60bd1f19827602" dependencies = [ "js-sys", "wasm-bindgen", @@ -7458,6 +7451,12 @@ dependencies = [ "wit-bindgen-rust-macro", ] +[[package]] +name = "wit-bindgen" +version = "0.57.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" + [[package]] name = "wit-bindgen-core" version = "0.51.0" diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index f27d021ac4..75f175cca3 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -32,17 +32,26 @@ import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriteMetricsR import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition} -import org.apache.spark.sql.comet.{CometExec, CometMetricNode} +import org.apache.spark.sql.comet.{CometExec, CometMetricNode, PlanDataInjector} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.comet.CometConf +import org.apache.comet.{CometConf, CometExecIterator} import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, QueryPlanSerde} import org.apache.comet.serde.OperatorOuterClass.{CompressionCodec, Operator} import org.apache.comet.serde.QueryPlanSerde.serializeDataType /** * A [[ShuffleWriter]] that will delegate shuffle write to native shuffle. + * + * @param childNativePlan + * When provided, the shuffle writer will execute this native plan directly and pipe its output + * to the ShuffleWriter, avoiding the JNI round-trip for intermediate batches. Used when all + * input sources are native scans (CometNativeScanExec, CometIcebergNativeScanExec). + * @param commonByKey + * Common planning data (schemas, filters) keyed by source identifier, for PlanDataInjector. + * @param perPartitionByKey + * Per-partition planning data (file lists) keyed by source identifier, for PlanDataInjector. 
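+ * @note commonByKey and perPartitionByKey are only consulted when both are non-empty; the
+ *   writer then injects the current partition's file data into the plan before serializing it.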
*/ class CometNativeShuffleWriter[K, V]( outputPartitioning: Partitioning, @@ -53,7 +62,10 @@ class CometNativeShuffleWriter[K, V]( mapId: Long, context: TaskContext, metricsReporter: ShuffleWriteMetricsReporter, - rangePartitionBounds: Option[Seq[InternalRow]] = None) + rangePartitionBounds: Option[Seq[InternalRow]] = None, + childNativePlan: Option[Operator] = None, + commonByKey: Map[String, Array[Byte]] = Map.empty, + perPartitionByKey: Map[String, Array[Array[Byte]]] = Map.empty) extends ShuffleWriter[K, V] with Logging { @@ -75,6 +87,18 @@ class CometNativeShuffleWriter[K, V]( // Call native shuffle write val nativePlan = getNativePlan(tempDataFilename, tempIndexFilename) + // Inject per-partition file data if this is a direct native execution plan + val actualPlan = if (commonByKey.nonEmpty && perPartitionByKey.nonEmpty) { + val partitionIdx = context.partitionId() + val partitionByKey = perPartitionByKey.map { case (key, arr) => + key -> arr(partitionIdx) + } + val injected = PlanDataInjector.injectPlanData(nativePlan, commonByKey, partitionByKey) + CometExec.serializeNativePlan(injected) + } else { + CometExec.serializeNativePlan(nativePlan) + } + val detailedMetrics = Seq( "elapsed_compute", "encode_time", @@ -96,15 +120,14 @@ class CometNativeShuffleWriter[K, V]( // Getting rid of the fake partitionId val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any, Any]]].map(_._2) - val cometIter = CometExec.getCometIterator( + val cometIter = new CometExecIterator( + CometExec.newIterId, Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]), outputAttributes.length, - nativePlan, + actualPlan, nativeMetrics, numParts, - context.partitionId(), - broadcastedHadoopConfForEncryption = None, - encryptedFilePaths = Seq.empty) + context.partitionId()) while (cometIter.hasNext) { cometIter.next() @@ -163,162 +186,162 @@ class CometNativeShuffleWriter[K, V]( } private def getNativePlan(dataFile: String, indexFile: String): Operator = { - val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput") - val opBuilder = OperatorOuterClass.Operator.newBuilder() - - val scanTypes = outputAttributes.flatten { attr => - serializeDataType(attr.dataType) - } - - if (scanTypes.length == outputAttributes.length) { + // When childNativePlan is provided, we use it directly as the input to ShuffleWriter. + // Otherwise, we create a Scan operator that reads from JNI input ("ShuffleWriterInput"). 
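+    // The composed plan therefore takes one of two shapes:
+    //   ShuffleWriter <- childNativePlan            (direct native execution, no JNI input)
+    //   ShuffleWriter <- Scan("ShuffleWriterInput") (batches fed from the JVM over JNI)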
+ val inputOperator: Operator = childNativePlan.getOrElse { + val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput") + val scanTypes = outputAttributes.flatten { attr => + serializeDataType(attr.dataType) + } + if (scanTypes.length != outputAttributes.length) { + throw new UnsupportedOperationException( + s"$outputAttributes contains unsupported data types for CometShuffleExchangeExec.") + } scanBuilder.addAllFields(scanTypes.asJava) + OperatorOuterClass.Operator.newBuilder().setScan(scanBuilder).build() + } - val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder() - shuffleWriterBuilder.setOutputDataFile(dataFile) - shuffleWriterBuilder.setOutputIndexFile(indexFile) + val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder() + shuffleWriterBuilder.setOutputDataFile(dataFile) + shuffleWriterBuilder.setOutputIndexFile(indexFile) - if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) { - val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match { - case "zstd" => CompressionCodec.Zstd - case "lz4" => CompressionCodec.Lz4 - case "snappy" => CompressionCodec.Snappy - case other => throw new UnsupportedOperationException(s"invalid codec: $other") - } - shuffleWriterBuilder.setCodec(codec) - } else { - shuffleWriterBuilder.setCodec(CompressionCodec.None) + if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) { + val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match { + case "zstd" => CompressionCodec.Zstd + case "lz4" => CompressionCodec.Lz4 + case "snappy" => CompressionCodec.Snappy + case other => throw new UnsupportedOperationException(s"invalid codec: $other") } - shuffleWriterBuilder.setCompressionLevel( - CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get) - shuffleWriterBuilder.setWriteBufferSize( - CometConf.COMET_SHUFFLE_WRITE_BUFFER_SIZE.get().min(Int.MaxValue).toInt) + shuffleWriterBuilder.setCodec(codec) + } else { + shuffleWriterBuilder.setCodec(CompressionCodec.None) + } + shuffleWriterBuilder.setCompressionLevel( + CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get) + shuffleWriterBuilder.setWriteBufferSize( + CometConf.COMET_SHUFFLE_WRITE_BUFFER_SIZE.get().min(Int.MaxValue).toInt) - outputPartitioning match { - case p if isSinglePartitioning(p) => - val partitioning = PartitioningOuterClass.SinglePartition.newBuilder() + outputPartitioning match { + case p if isSinglePartitioning(p) => + val partitioning = PartitioningOuterClass.SinglePartition.newBuilder() - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setSinglePartition(partitioning).build()) - case _: HashPartitioning => - val hashPartitioning = outputPartitioning.asInstanceOf[HashPartitioning] + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setSinglePartition(partitioning).build()) + case _: HashPartitioning => + val hashPartitioning = outputPartitioning.asInstanceOf[HashPartitioning] - val partitioning = PartitioningOuterClass.HashPartition.newBuilder() - partitioning.setNumPartitions(outputPartitioning.numPartitions) + val partitioning = PartitioningOuterClass.HashPartition.newBuilder() + partitioning.setNumPartitions(outputPartitioning.numPartitions) - val partitionExprs = hashPartitioning.expressions - .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) + val partitionExprs = hashPartitioning.expressions 
+ .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) - if (partitionExprs.length != hashPartitioning.expressions.length) { - throw new UnsupportedOperationException( - s"Partitioning $hashPartitioning is not supported.") - } + if (partitionExprs.length != hashPartitioning.expressions.length) { + throw new UnsupportedOperationException( + s"Partitioning $hashPartitioning is not supported.") + } - partitioning.addAllHashExpression(partitionExprs.asJava) - - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setHashPartition(partitioning).build()) - case _: RangePartitioning => - val rangePartitioning = outputPartitioning.asInstanceOf[RangePartitioning] - - val partitioning = PartitioningOuterClass.RangePartition.newBuilder() - partitioning.setNumPartitions(outputPartitioning.numPartitions) - - // Detect duplicates by tracking expressions directly, similar to DataFusion's LexOrdering - // DataFusion will deduplicate identical sort expressions in LexOrdering, - // so we need to transform boundary rows to match the deduplicated structure - val seenExprs = mutable.HashSet[Expression]() - val deduplicationMap = mutable.ArrayBuffer[(Int, Boolean)]() // (originalIndex, isKept) - - rangePartitioning.ordering.zipWithIndex.foreach { case (sortOrder, idx) => - if (seenExprs.contains(sortOrder.child)) { - deduplicationMap += (idx -> false) // Will be deduplicated by DataFusion - } else { - seenExprs += sortOrder.child - deduplicationMap += (idx -> true) // Will be kept by DataFusion - } + partitioning.addAllHashExpression(partitionExprs.asJava) + + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setHashPartition(partitioning).build()) + case _: RangePartitioning => + val rangePartitioning = outputPartitioning.asInstanceOf[RangePartitioning] + + val partitioning = PartitioningOuterClass.RangePartition.newBuilder() + partitioning.setNumPartitions(outputPartitioning.numPartitions) + + // Detect duplicates by tracking expressions directly, similar to DataFusion's LexOrdering + // DataFusion will deduplicate identical sort expressions in LexOrdering, + // so we need to transform boundary rows to match the deduplicated structure + val seenExprs = mutable.HashSet[Expression]() + val deduplicationMap = mutable.ArrayBuffer[(Int, Boolean)]() // (originalIndex, isKept) + + rangePartitioning.ordering.zipWithIndex.foreach { case (sortOrder, idx) => + if (seenExprs.contains(sortOrder.child)) { + deduplicationMap += (idx -> false) // Will be deduplicated by DataFusion + } else { + seenExprs += sortOrder.child + deduplicationMap += (idx -> true) // Will be kept by DataFusion } + } - { - // Serialize the ordering expressions for comparisons - val orderingExprs = rangePartitioning.ordering - .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) - if (orderingExprs.length != rangePartitioning.ordering.length) { - throw new UnsupportedOperationException( - s"Partitioning $rangePartitioning is not supported.") - } - partitioning.addAllSortOrders(orderingExprs.asJava) + { + // Serialize the ordering expressions for comparisons + val orderingExprs = rangePartitioning.ordering + .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) + if (orderingExprs.length != rangePartitioning.ordering.length) { + throw new UnsupportedOperationException( + s"Partitioning $rangePartitioning is not supported.") } + 
partitioning.addAllSortOrders(orderingExprs.asJava) + } - // Convert Spark's sequence of InternalRows that represent partitioning boundaries to - // sequences of Literals, where each outer entry represents a boundary row, and each - // internal entry is a value in that row. In other words, these are stored in row major - // order, not column major - val boundarySchema = rangePartitioning.ordering.flatMap(e => Some(e.dataType)) - - // Transform boundary rows to match DataFusion's deduplicated structure - val transformedBoundaryExprs: Seq[Seq[Literal]] = - rangePartitionBounds.get.map((row: InternalRow) => { - // For every InternalRow, map its values to Literals - val allLiterals = - row.toSeq(boundarySchema).zip(boundarySchema).map { case (value, valueType) => - Literal(value, valueType) - } - - // Keep only the literals that correspond to non-deduplicated expressions - allLiterals - .zip(deduplicationMap) - .filter(_._2._2) // Keep only where isKept = true - .map(_._1) // Extract the literal + // Convert Spark's sequence of InternalRows that represent partitioning boundaries to + // sequences of Literals, where each outer entry represents a boundary row, and each + // internal entry is a value in that row. In other words, these are stored in row major + // order, not column major + val boundarySchema = rangePartitioning.ordering.flatMap(e => Some(e.dataType)) + + // Transform boundary rows to match DataFusion's deduplicated structure + val transformedBoundaryExprs: Seq[Seq[Literal]] = + rangePartitionBounds.get.map((row: InternalRow) => { + // For every InternalRow, map its values to Literals + val allLiterals = + row.toSeq(boundarySchema).zip(boundarySchema).map { case (value, valueType) => + Literal(value, valueType) + } + + // Keep only the literals that correspond to non-deduplicated expressions + allLiterals + .zip(deduplicationMap) + .filter(_._2._2) // Keep only where isKept = true + .map(_._1) // Extract the literal + }) + + { + // Convert the sequences of Literals to a collection of serialized BoundaryRows + val boundaryRows: Seq[PartitioningOuterClass.BoundaryRow] = transformedBoundaryExprs + .map((rowLiterals: Seq[Literal]) => { + // Serialize each sequence of Literals as a BoundaryRow + val rowBuilder = PartitioningOuterClass.BoundaryRow.newBuilder(); + val serializedExprs = + rowLiterals.map(lit_value => + QueryPlanSerde.exprToProto(lit_value, outputAttributes).get) + rowBuilder.addAllPartitionBounds(serializedExprs.asJava) + rowBuilder.build() }) + partitioning.addAllBoundaryRows(boundaryRows.asJava) + } - { - // Convert the sequences of Literals to a collection of serialized BoundaryRows - val boundaryRows: Seq[PartitioningOuterClass.BoundaryRow] = transformedBoundaryExprs - .map((rowLiterals: Seq[Literal]) => { - // Serialize each sequence of Literals as a BoundaryRow - val rowBuilder = PartitioningOuterClass.BoundaryRow.newBuilder(); - val serializedExprs = - rowLiterals.map(lit_value => - QueryPlanSerde.exprToProto(lit_value, outputAttributes).get) - rowBuilder.addAllPartitionBounds(serializedExprs.asJava) - rowBuilder.build() - }) - partitioning.addAllBoundaryRows(boundaryRows.asJava) - } - - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setRangePartition(partitioning).build()) + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setRangePartition(partitioning).build()) - case _: 
RoundRobinPartitioning => - val partitioning = PartitioningOuterClass.RoundRobinPartition.newBuilder() - partitioning.setNumPartitions(outputPartitioning.numPartitions) - partitioning.setMaxHashColumns( - CometConf.COMET_EXEC_SHUFFLE_WITH_ROUND_ROBIN_PARTITIONING_MAX_HASH_COLUMNS.get()) + case _: RoundRobinPartitioning => + val partitioning = PartitioningOuterClass.RoundRobinPartition.newBuilder() + partitioning.setNumPartitions(outputPartitioning.numPartitions) + partitioning.setMaxHashColumns( + CometConf.COMET_EXEC_SHUFFLE_WITH_ROUND_ROBIN_PARTITIONING_MAX_HASH_COLUMNS.get()) - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setRoundRobinPartition(partitioning).build()) + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setRoundRobinPartition(partitioning).build()) - case _ => - throw new UnsupportedOperationException( - s"Partitioning $outputPartitioning is not supported.") - } + case _ => + throw new UnsupportedOperationException( + s"Partitioning $outputPartitioning is not supported.") + } - shuffleWriterBuilder.setTracingEnabled(CometConf.COMET_TRACING_ENABLED.get()) + shuffleWriterBuilder.setTracingEnabled(CometConf.COMET_TRACING_ENABLED.get()) - val shuffleWriterOpBuilder = OperatorOuterClass.Operator.newBuilder() - shuffleWriterOpBuilder - .setShuffleWriter(shuffleWriterBuilder) - .addChildren(opBuilder.setScan(scanBuilder).build()) - .build() - } else { - // There are unsupported scan type - throw new UnsupportedOperationException( - s"$outputAttributes contains unsupported data types for CometShuffleExchangeExec.") - } + val shuffleWriterOpBuilder = OperatorOuterClass.Operator.newBuilder() + shuffleWriterOpBuilder + .setShuffleWriter(shuffleWriterBuilder) + .addChildren(inputOperator) + .build() } override def stop(success: Boolean): Option[MapStatus] = { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala index 2b74e5a168..d0a42d48ce 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala @@ -31,6 +31,8 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.StructType +import org.apache.comet.serde.OperatorOuterClass.Operator + /** * A [[ShuffleDependency]] that allows us to identify the shuffle dependency as a Comet shuffle. 
*/ @@ -49,7 +51,11 @@ class CometShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val outputAttributes: Seq[Attribute] = Seq.empty, val shuffleWriteMetrics: Map[String, SQLMetric] = Map.empty, val numParts: Int = 0, - val rangePartitionBounds: Option[Seq[InternalRow]] = None) + val rangePartitionBounds: Option[Seq[InternalRow]] = None, + // For direct native execution: the child's native plan to compose with ShuffleWriter + val childNativePlan: Option[Operator] = None, + val commonByKey: Map[String, Array[Byte]] = Map.empty, + val perPartitionByKey: Map[String, Array[Array[Byte]]] = Map.empty) extends ShuffleDependency[K, V, C]( _rdd, partitioner, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index d4ee4e4ccf..41ab3ac61a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -34,8 +34,9 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Exp import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.comet.{CometMetricNode, CometNativeExec, CometPlan, CometSinkPlaceHolder} +import org.apache.spark.sql.comet.{CometIcebergNativeScanExec, CometMetricNode, CometNativeExec, CometNativeScanExec, CometPlan, CometSinkPlaceHolder} import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.ScalarSubquery import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} @@ -52,6 +53,7 @@ import org.apache.comet.{CometConf, CometExplainInfo} import org.apache.comet.CometConf.{COMET_EXEC_SHUFFLE_ENABLED, COMET_SHUFFLE_MODE} import org.apache.comet.CometSparkSessionExtensions.{hasExplainInfo, isCometShuffleManagerEnabled, withInfos} import org.apache.comet.serde.{Compatible, OperatorOuterClass, QueryPlanSerde, SupportLevel, Unsupported} +import org.apache.comet.serde.OperatorOuterClass.Operator import org.apache.comet.serde.operator.CometSink import org.apache.comet.shims.{CometTypeShim, ShimCometShuffleExchangeExec} @@ -101,9 +103,113 @@ case class CometShuffleExchangeExec( private lazy val serializer: Serializer = new UnsafeRowSerializer(child.output.size, longMetric("dataSize")) + /** + * Information about the direct native execution optimization. When every input source of the + * child native plan is a native scan (CometNativeScanExec, CometIcebergNativeScanExec), we can + * pass the child's native plan to the shuffle writer and execute, e.g., Scan -> Filter -> + * Project -> ShuffleWriter entirely in native code, avoiding the JNI round-trip for + * intermediate batches. + * + * Only fully native scans that read files directly via DataFusion qualify. JVM scan wrappers + * (CometScanExec, CometBatchScanExec) still require JNI input and are not optimized.
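+   *
+   * For example, repartitioning a Parquet table read by CometNativeScanExec serializes to
+   * ShuffleWriter(Scan(files)): each shuffle map task then runs the scan natively for its own
+   * partition, using the per-partition file data injected by PlanDataInjector.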
+ */ + @transient private lazy val directNativeExecutionInfo: Option[DirectNativeExecutionInfo] = { + if (!CometConf.COMET_SHUFFLE_DIRECT_NATIVE_ENABLED.get()) { + None + } else if (shuffleType != CometNativeShuffle) { + None + } else { + // Check if direct native execution is possible + outputPartitioning match { + case _: RangePartitioning => + // RangePartitioning requires sampling the data to compute bounds, + // which requires executing the child plan. Fall back to current behavior. + None + case _ => + child match { + case nativeChild: CometNativeExec => + // Find input sources using foreachUntilCometInput + val inputSources = scala.collection.mutable.ArrayBuffer.empty[SparkPlan] + nativeChild.foreachUntilCometInput(nativeChild)(inputSources += _) + + // Optimize when all input sources are native scans + // (CometNativeScanExec, CometIcebergNativeScanExec). + // JVM scan wrappers (CometScanExec, CometBatchScanExec) still need JNI input, + // so we don't optimize those. + // Check if the plan contains subqueries (e.g., bloom filters with might_contain). + // Subqueries are registered with the parent execution context ID, but direct + // native shuffle creates a new execution context, so subquery lookup would fail. + val containsSubquery = nativeChild.exists { p => + p.expressions.exists(_.exists(_.isInstanceOf[ScalarSubquery])) + } + if (containsSubquery) { + // Fall back to avoid subquery lookup failures + None + } else { + // Check that ALL input sources are native scans (file-reading, no JNI) + val allNativeScans = inputSources.nonEmpty && inputSources.forall { + case _: CometNativeScanExec => true + case _: CometIcebergNativeScanExec => true + case _ => false + } + if (allNativeScans) { + // Collect per-partition plan data from all native scans + val (commonByKey, perPartitionByKey) = + nativeChild.findAllPlanData(nativeChild) + // All scans must have the same partition count + val partitionCounts = perPartitionByKey.values.map(_.length).toSet + if (partitionCounts.size <= 1) { + val numPartitions = partitionCounts.headOption.getOrElse(0) + if (numPartitions == 0) { + // Empty table (no data files) - fall back to normal execution + None + } else { + Some( + DirectNativeExecutionInfo( + nativeChild.nativeOp, + numPartitions, + commonByKey, + perPartitionByKey)) + } + } else { + None // Partition count mismatch across scans + } + } else { + None + } + } + case _ => + None + } + } + } + } + + /** + * Returns true if direct native execution optimization is being used for this shuffle. This is + * primarily intended for testing to verify the optimization is applied correctly. + */ + def isDirectNativeExecution: Boolean = directNativeExecutionInfo.isDefined + + /** + * Creates an RDD that provides empty iterators for each partition. Used when direct native + * execution is enabled - the shuffle writer will execute the full native plan which reads data + * directly (no JNI input needed). + */ + private def createEmptyPartitionRDD(numPartitions: Int): RDD[ColumnarBatch] = { + sparkContext.parallelize(Seq.empty[ColumnarBatch], numPartitions) + } + @transient lazy val inputRDD: RDD[_] = if (shuffleType == CometNativeShuffle) { - // CometNativeShuffle assumes that the input plan is Comet plan. - child.executeColumnar() + directNativeExecutionInfo match { + case Some(info) => + // Direct native execution: create an RDD with empty partitions. + // The shuffle writer will execute the full native plan which reads data directly. 
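+          // The empty RDD only drives task scheduling: it yields one task per partition, and
+          // each task's partitionId selects the per-partition plan data that the shuffle
+          // writer injects into the serialized plan.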
+ createEmptyPartitionRDD(info.numPartitions) + case None => + // Fall back to current behavior: execute child and pass intermediate batches + child.executeColumnar() + } } else if (shuffleType == CometColumnarShuffle) { // CometColumnarShuffle uses Spark's row-based execute() API. For Spark row-based plans, // rows flow directly. For Comet native plans, their doExecute() wraps with ColumnarToRowExec @@ -154,7 +260,10 @@ case class CometShuffleExchangeExec( child.output, outputPartitioning, serializer, - metrics) + metrics, + directNativeExecutionInfo.map(_.childNativePlan), + directNativeExecutionInfo.map(_.commonByKey).getOrElse(Map.empty), + directNativeExecutionInfo.map(_.perPartitionByKey).getOrElse(Map.empty)) metrics("numPartitions").set(dep.partitioner.numPartitions) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) SQLMetrics.postDriverMetricUpdates( @@ -631,7 +740,11 @@ object CometShuffleExchangeExec outputAttributes: Seq[Attribute], outputPartitioning: Partitioning, serializer: Serializer, - metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { + metrics: Map[String, SQLMetric], + childNativePlan: Option[Operator] = None, + commonByKey: Map[String, Array[Byte]] = Map.empty, + perPartitionByKey: Map[String, Array[Array[Byte]]] = Map.empty) + : ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { val numParts = rdd.getNumPartitions // The code block below is mostly brought over from @@ -698,7 +811,10 @@ object CometShuffleExchangeExec outputAttributes = outputAttributes, shuffleWriteMetrics = metrics, numParts = numParts, - rangePartitionBounds = rangePartitionBounds) + rangePartitionBounds = rangePartitionBounds, + childNativePlan = childNativePlan, + commonByKey = commonByKey, + perPartitionByKey = perPartitionByKey) dependency } @@ -904,3 +1020,17 @@ object CometShuffleExchangeExec dependency } } + +/** + * Information needed for direct native execution optimization. 
+ * + * @param childNativePlan + * The child's native operator plan to compose with ShuffleWriter + * @param numPartitions + * The number of partitions (from the underlying scan) + */ +private[shuffle] case class DirectNativeExecutionInfo( + childNativePlan: Operator, + numPartitions: Int, + commonByKey: Map[String, Array[Byte]], + perPartitionByKey: Map[String, Array[Array[Byte]]]) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala index c8f2199d53..4cf20a1243 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala @@ -239,7 +239,10 @@ class CometShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { mapId, context, metrics, - dep.rangePartitionBounds) + dep.rangePartitionBounds, + dep.childNativePlan, + dep.commonByKey, + dep.perPartitionByKey) case bypassMergeSortHandle: CometBypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => val bypassDep = bypassMergeSortHandle.dependency.asInstanceOf[CometShuffleDependency[_, _, _]] diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index a4c8d178c7..2c3c66b6e8 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -661,7 +661,7 @@ abstract class CometNativeExec extends CometExec { * @return * (commonByKey, perPartitionByKey) - common data is shared, per-partition varies */ - private def findAllPlanData( + private[comet] def findAllPlanData( plan: SparkPlan): (Map[String, Array[Byte]], Map[String, Array[Array[Byte]]]) = { plan match { case iceberg: CometIcebergNativeScanExec => diff --git a/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala new file mode 100644 index 0000000000..c08ada38a9 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.exec + +import org.scalactic.source.Position +import org.scalatest.Tag + +import org.apache.spark.sql.{CometTestBase, DataFrame} +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.functions.col + +import org.apache.comet.CometConf + +/** + * Test suite for the direct native shuffle execution optimization. + * + * This optimization allows the native shuffle writer to directly execute the child native plan + * instead of reading intermediate batches via JNI. This avoids the JNI round-trip for + * single-source native plans (e.g., Scan -> Filter -> Project -> Shuffle). + */ +class CometDirectNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper { + + override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit + pos: Position): Unit = { + super.test(testName, testTags: _*) { + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_MODE.key -> "native", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion", + CometConf.COMET_SHUFFLE_DIRECT_NATIVE_ENABLED.key -> "true") { + testFun + } + } + } + + import testImplicits._ + + test("direct native execution: simple scan with hash partitioning") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(10, $"_1") + + // Verify the optimization is applied + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1, "Expected exactly one shuffle") + assert( + shuffles.head.isDirectNativeExecution, + "Direct native execution should be enabled for single-source native scan") + + // Verify correctness + checkSparkAnswer(df) + } + } + + test("direct native execution: scan with filter and project") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong, i.toString)), "tbl") { + val df = sql("SELECT _1, _2 * 2 as doubled FROM tbl WHERE _1 > 10") + .repartition(10, $"_1") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert( + shuffles.head.isDirectNativeExecution, + "Direct native execution should work with filter and project") + + checkSparkAnswer(df) + } + } + + test("direct native execution: single partition") { + withParquetTable((0 until 50).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(1) + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert( + shuffles.head.isDirectNativeExecution, + "Direct native execution should work with single partition") + + checkSparkAnswer(df) + } + } + + test("direct native execution: multiple hash columns") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong, i.toString)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(10, $"_1", $"_2") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert( + shuffles.head.isDirectNativeExecution, + "Direct native execution should work with multiple hash columns") + + checkSparkAnswer(df) + } + } + + test("direct native execution: aggregation before shuffle") { + withParquetTable((0 until 100).map(i => (i % 10, (i + 1).toLong)), "tbl") { + val df = sql("SELECT _1, SUM(_2) as total FROM tbl GROUP BY _1") + .repartition(5, col("_1")) + + // This involves partial aggregation -> shuffle -> final aggregation + // The direct native execution applies to the shuffle that 
+      checkSparkAnswer(df)
+    }
+  }
+
+  test("direct native execution disabled: config is false") {
+    withSQLConf(CometConf.COMET_SHUFFLE_DIRECT_NATIVE_ENABLED.key -> "false") {
+      withParquetTable((0 until 50).map(i => (i, (i + 1).toLong)), "tbl") {
+        val df = sql("SELECT * FROM tbl").repartition(10, $"_1")
+
+        val shuffles = findShuffleExchanges(df)
+        assert(shuffles.length == 1)
+        assert(
+          !shuffles.head.isDirectNativeExecution,
+          "Direct native execution should be disabled when config is false")
+
+        checkSparkAnswer(df)
+      }
+    }
+  }
+
+  test("direct native execution disabled: range partitioning") {
+    withParquetTable((0 until 100).map(i => (i, (i + 1).toLong)), "tbl") {
+      val df = sql("SELECT * FROM tbl").repartitionByRange(10, $"_1")
+
+      val shuffles = findShuffleExchanges(df)
+      assert(shuffles.length == 1)
+      assert(
+        !shuffles.head.isDirectNativeExecution,
+        "Direct native execution should not be used for range partitioning")
+
+      checkSparkAnswer(df)
+    }
+  }
+
+  test("direct native execution disabled: JVM columnar shuffle mode") {
+    withSQLConf(CometConf.COMET_SHUFFLE_MODE.key -> "jvm") {
+      withParquetTable((0 until 50).map(i => (i, (i + 1).toLong)), "tbl") {
+        val df = sql("SELECT * FROM tbl").repartition(10, $"_1")
+
+        // JVM shuffle mode uses CometColumnarShuffle, not CometNativeShuffle
+        val shuffles = findShuffleExchanges(df)
+        shuffles.foreach { shuffle =>
+          assert(
+            !shuffle.isDirectNativeExecution,
+            "Direct native execution should not be used with JVM shuffle mode")
+        }
+
+        checkSparkAnswer(df)
+      }
+    }
+  }
+
+  test("direct native execution: multiple shuffles in same query") {
+    withParquetTable((0 until 100).map(i => (i, (i + 1).toLong)), "tbl") {
+      val df = sql("SELECT * FROM tbl")
+        .repartition(10, $"_1")
+        .select($"_1", ($"_2" + 1).as("_2_plus"))
+        .repartition(5, $"_2_plus")
+
+      // The first shuffle reads from the scan; the second reads from the first shuffle's
+      // output, so only the first can use direct native execution. AQE might combine some
+      // shuffles, so just verify the results are correct.
+      checkSparkAnswer(df)
+    }
+  }
+
+  test("direct native execution: various data types") {
+    withParquetTable(
+      (0 until 50).map(i =>
+        (i, i.toLong, i.toFloat, i.toDouble, i.toString, i % 2 == 0, BigDecimal(i))),
+      "tbl") {
+      val df = sql("SELECT * FROM tbl").repartition(10, $"_1")
+
+      val shuffles = findShuffleExchanges(df)
+      assert(shuffles.length == 1)
+      assert(shuffles.head.isDirectNativeExecution)
+
+      checkSparkAnswer(df)
+    }
+  }
+
+  test("direct native execution: complex filter and multiple projections") {
+    withParquetTable((0 until 100).map(i => (i, (i + 1).toLong, i % 5)), "tbl") {
+      val df = sql("""
+                     |SELECT _1 * 2 as doubled,
+                     |       _2 + _3 as sum_col,
+                     |       _1 + _2 as combined
+                     |FROM tbl
+                     |WHERE _1 > 20 AND _3 < 3
+                     |""".stripMargin)
+        .repartition(10, col("doubled"))
+
+      // Note: Native shuffle might fall back depending on expression support.
+      // Just verify correctness - the optimization is best-effort.
+      checkSparkAnswer(df)
+    }
+  }
+
+  test("direct native execution: results match non-optimized path") {
+    withParquetTable((0 until 100).map(i => (i, (i + 1).toLong, i.toString)), "tbl") {
+      // Run with the optimization enabled
+      val dfOptimized = sql("SELECT _1, _2 FROM tbl WHERE _1 > 50").repartition(10, $"_1")
+      val optimizedResult = dfOptimized.collect().sortBy(_.getInt(0))
+
+      // Re-run the same query with the optimization disabled; withSQLConf returns Unit,
+      // so collect the rows into a var assigned inside the block.
+      var nonOptimizedResult: Array[org.apache.spark.sql.Row] = Array.empty
+      withSQLConf(CometConf.COMET_SHUFFLE_DIRECT_NATIVE_ENABLED.key -> "false") {
+        val dfNonOptimized = sql("SELECT _1, _2 FROM tbl WHERE _1 > 50").repartition(10, $"_1")
+        nonOptimizedResult = dfNonOptimized.collect().sortBy(_.getInt(0))
+      }
+
+      // Results should match
+      assert(optimizedResult.length == nonOptimizedResult.length, "Row counts should match")
+      optimizedResult.zip(nonOptimizedResult).foreach { case (opt, nonOpt) =>
+        assert(opt == nonOpt, s"Rows should match: $opt vs $nonOpt")
+      }
+    }
+  }
+
+  test("direct native execution: large number of partitions") {
+    withParquetTable((0 until 1000).map(i => (i, (i + 1).toLong)), "tbl") {
+      val df = sql("SELECT * FROM tbl").repartition(201, $"_1")
+
+      val shuffles = findShuffleExchanges(df)
+      assert(shuffles.length == 1)
+      assert(shuffles.head.isDirectNativeExecution)
+
+      checkSparkAnswer(df)
+    }
+  }
+
+  test("direct native execution: empty table") {
+    withParquetTable(Seq.empty[(Int, Long)], "tbl") {
+      val df = sql("SELECT * FROM tbl").repartition(10, $"_1")
+
+      // Should handle empty tables gracefully
+      val result = df.collect()
+      assert(result.isEmpty)
+    }
+  }
+
+  test("direct native execution: all rows filtered out") {
+    withParquetTable((0 until 100).map(i => (i, (i + 1).toLong)), "tbl") {
+      val df = sql("SELECT * FROM tbl WHERE _1 > 1000").repartition(10, $"_1")
+
+      val shuffles = findShuffleExchanges(df)
+      assert(shuffles.length == 1)
+      assert(shuffles.head.isDirectNativeExecution)
+
+      val result = df.collect()
+      assert(result.isEmpty, "Result should be empty when all rows are filtered")
+    }
+  }
+
+  // TODO: Add Iceberg native scan test when Iceberg test infrastructure is available
+  // in this suite. CometIcebergNativeScanExec is supported by the optimization but
+  // requires SparkCatalog setup. See CometIcebergSuite for patterns.
+
+  test("direct native execution: join of two native scans") {
+    withParquetTable((0 until 100).map(i => (i, s"left_$i")), "left_tbl") {
+      withParquetTable((0 until 100).map(i => (i, s"right_$i")), "right_tbl") {
+        // Broadcast join with two native scans. The join itself may or may not use
+        // direct native execution, depending on whether the broadcast introduces a
+        // non-native-scan input, but the query should execute correctly regardless.
+        // Alias the value columns so the joined output has unique column names.
+        val df = sql("""
+                       |SELECT l._1, l._2 AS left_val, r._2 AS right_val
+                       |FROM left_tbl l JOIN right_tbl r ON l._1 = r._1
+                       |WHERE l._1 > 50
+                       |""".stripMargin)
+          .repartition(10, col("_1"))
+
+        checkSparkAnswer(df)
+      }
+    }
+  }
+
+  test("direct native execution disabled: shuffle input source") {
+    withParquetTable((0 until 100).map(i => (i, (i + 1).toLong)), "tbl") {
+      // Force a shuffle before the final shuffle by using a repartition + aggregation.
+      // The second shuffle reads from the first shuffle's output (not a native scan).
+      val df = sql("SELECT _1, SUM(_2) as s FROM tbl GROUP BY _1")
+        .repartition(5, col("_1"))
+        .filter(col("s") > 10)
+        .repartition(3, col("_1"))
+
+      // The final shuffle should NOT use direct native execution because its input
+      // comes from a shuffle read, not a native scan
+      checkSparkAnswer(df)
+    }
+  }
+
+  /**
+   * Helper method to find CometShuffleExchangeExec nodes in a DataFrame's execution plan.
+   */
+  private def findShuffleExchanges(df: DataFrame): Seq[CometShuffleExchangeExec] = {
+    val plan = stripAQEPlan(df.queryExecution.executedPlan)
+    plan.collect { case s: CometShuffleExchangeExec => s }
+  }
+}
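
As context for reviewers, here is a minimal sketch of how an application might opt in to this experimental path once the patch lands. Only the CometConf entries and isDirectNativeExecution are taken from this change; the plugin and shuffle-manager settings follow the standard Comet setup, while the session-builder boilerplate, input path, and column name are illustrative assumptions. AQE is disabled here only so the exchange node is easy to find in the executed plan (the suite above uses stripAQEPlan instead).

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec

import org.apache.comet.CometConf

object DirectNativeShuffleExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[*]") // illustrative deployment; Comet must be on the classpath
      .config("spark.plugins", "org.apache.spark.CometPlugin")
      .config(
        "spark.shuffle.manager",
        "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
      .config("spark.sql.adaptive.enabled", "false") // simplifies plan inspection below
      .config(CometConf.COMET_EXEC_ENABLED.key, "true")
      .config(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true")
      .config(CometConf.COMET_SHUFFLE_MODE.key, "native")
      .config(CometConf.COMET_NATIVE_SCAN_IMPL.key, "native_datafusion")
      // Experimental, internal, and off by default; see the config doc in CometConf.
      .config(CometConf.COMET_SHUFFLE_DIRECT_NATIVE_ENABLED.key, "true")
      .getOrCreate()
    import spark.implicits._

    // A single-source shape the writer can execute directly: scan -> filter -> shuffle.
    val df = spark.read
      .parquet("/tmp/example_table") // hypothetical input with an integer column `_1`
      .filter($"_1" > 10)
      .repartition(10, $"_1")
    df.collect()

    // The shuffle node reports whether the direct path was taken.
    df.queryExecution.executedPlan
      .collect { case s: CometShuffleExchangeExec => s }
      .foreach(s => println(s"direct native execution: ${s.isDirectNativeExecution}"))

    spark.stop()
  }
}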