diff --git a/rust/cubestore/.gitignore b/rust/cubestore/.gitignore index 66d0436bd8f34..64f8a320fdced 100644 --- a/rust/cubestore/.gitignore +++ b/rust/cubestore/.gitignore @@ -10,3 +10,6 @@ cubestore/target cubesql/target cubestore-sql-tests/data/** cubestore/db-tmp +# RocksDB scratch dirs left by metastore unit tests (run from the crate root) +/cubestore/test-*-local/ +/cubestore/test-*-upstream/ diff --git a/rust/cubestore/Cargo.lock b/rust/cubestore/Cargo.lock index 8d36a7d26c016..c5785ec686690 100644 --- a/rust/cubestore/Cargo.lock +++ b/rust/cubestore/Cargo.lock @@ -1758,7 +1758,7 @@ checksum = "c2e66c9d817f1720209181c316d28635c050fa304f9c79e47a520882661b7308" [[package]] name = "datafusion" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "arrow-ipc", @@ -1811,7 +1811,7 @@ dependencies = [ [[package]] name = "datafusion-catalog" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "async-trait", @@ -1830,7 +1830,7 @@ dependencies = [ [[package]] name = "datafusion-catalog-listing" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "async-trait", @@ -1851,7 +1851,7 @@ dependencies = [ [[package]] name = "datafusion-common" version = "46.0.1" -source = 
"git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "ahash 0.8.11", "arrow", @@ -1874,7 +1874,7 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "log", "tokio", @@ -1883,7 +1883,7 @@ dependencies = [ [[package]] name = "datafusion-datasource" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "async-compression 0.4.17", @@ -1916,12 +1916,12 @@ dependencies = [ [[package]] name = "datafusion-doc" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" [[package]] name = "datafusion-execution" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "dashmap", @@ -1941,7 +1941,7 @@ dependencies = [ [[package]] name = "datafusion-expr" version = "46.0.1" -source = 
"git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "chrono", @@ -1961,7 +1961,7 @@ dependencies = [ [[package]] name = "datafusion-expr-common" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "datafusion-common", @@ -1973,7 +1973,7 @@ dependencies = [ [[package]] name = "datafusion-functions" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "arrow-buffer", @@ -2001,7 +2001,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "ahash 0.8.11", "arrow", @@ -2021,7 +2021,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate-common" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "ahash 0.8.11", "arrow", @@ -2033,7 +2033,7 @@ dependencies = [ [[package]] name = 
"datafusion-functions-nested" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "arrow-ord", @@ -2053,7 +2053,7 @@ dependencies = [ [[package]] name = "datafusion-functions-table" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "async-trait", @@ -2068,7 +2068,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "datafusion-common", "datafusion-doc", @@ -2084,7 +2084,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window-common" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "datafusion-common", "datafusion-physical-expr-common", @@ -2093,7 +2093,7 @@ dependencies = [ [[package]] name = "datafusion-macros" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ 
"datafusion-expr", "quote", @@ -2103,7 +2103,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "chrono", @@ -2121,7 +2121,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "ahash 0.8.11", "arrow", @@ -2142,7 +2142,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "ahash 0.8.11", "arrow", @@ -2155,7 +2155,7 @@ dependencies = [ [[package]] name = "datafusion-physical-optimizer" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "datafusion-common", @@ -2173,7 +2173,7 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = 
"git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "ahash 0.8.11", "arrow", @@ -2205,7 +2205,7 @@ dependencies = [ [[package]] name = "datafusion-proto" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "chrono", @@ -2220,7 +2220,7 @@ dependencies = [ [[package]] name = "datafusion-proto-common" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "datafusion-common", @@ -2230,7 +2230,7 @@ dependencies = [ [[package]] name = "datafusion-sql" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "bigdecimal 0.4.8", diff --git a/rust/cubestore/cubestore-sql-tests/src/tests.rs b/rust/cubestore/cubestore-sql-tests/src/tests.rs index 68925496d1721..b2cd81474808d 100644 --- a/rust/cubestore/cubestore-sql-tests/src/tests.rs +++ b/rust/cubestore/cubestore-sql-tests/src/tests.rs @@ -154,6 +154,8 @@ pub fn sql_tests(prefix: &str) -> Vec<(&'static str, TestFn)> { t("planning_inplace_aggregate", planning_inplace_aggregate), t("planning_hints", planning_hints), t("planning_inplace_aggregate2", planning_inplace_aggregate2), + t("planning_topk_hash_aggregate", planning_topk_hash_aggregate), + t("topk_hash_aggregate_trim", 
topk_hash_aggregate_trim), t("topk_large_inputs", topk_large_inputs), t("partitioned_index", partitioned_index), t( @@ -424,6 +426,8 @@ lazy_static::lazy_static! { "limit_pushdown_group_having", "limit_pushdown_group_nonprefix_order", "prefilter_chunks_shared_scan", + "planning_topk_hash_aggregate", + "topk_hash_aggregate_trim", ].into_iter().map(ToOwned::to_owned).collect(); } @@ -3200,6 +3204,194 @@ async fn planning_inplace_aggregate(service: Box) -> Result<(), C Ok(()) } +async fn planning_topk_hash_aggregate(service: Box) -> Result<(), CubeError> { + service.exec_query("CREATE SCHEMA s").await?; + service + .exec_query("CREATE TABLE s.Data(url text, day int, hits int)") + .await?; + service + .exec_query("CREATE TABLE s.D3(a int, b int, c int, h int)") + .await?; + + // GROUP BY a non-indexed column -> hash (Linear) partial aggregate; ORDER BY the group + // column with a LIMIT -> the worker partial aggregate is replaced by GroupByLimitAggregate. + let p = service + .plan_query("SELECT day, SUM(hits) FROM s.Data GROUP BY 1 ORDER BY 1 LIMIT 10") + .await?; + let pp = pp_phys_plan_ext(p.worker.as_ref(), &PPOptions::none()); + assert!( + pp.contains("GroupByLimitAggregate, k: 10, factor: 2,"), + "expected GroupByLimitAggregate on the worker, got:\n{}", + pp + ); + + // LIMIT + OFFSET -> k = limit + offset. + let p = service + .plan_query("SELECT day, SUM(hits) FROM s.Data GROUP BY 1 ORDER BY 1 LIMIT 10 OFFSET 5") + .await?; + let pp = pp_phys_plan_ext(p.worker.as_ref(), &PPOptions::none()); + assert!( + pp.contains("GroupByLimitAggregate, k: 15, factor: 2,"), + "expected k=15 (limit+offset), got:\n{}", + pp + ); + + // ORDER BY an aggregate (not a group-by column) -> no trim. 
+ let p = service + .plan_query("SELECT day, SUM(hits) FROM s.Data GROUP BY 1 ORDER BY 2 DESC LIMIT 10") + .await?; + let pp = pp_phys_plan_ext(p.worker.as_ref(), &PPOptions::none()); + assert!( + !pp.contains("GroupByLimitAggregate"), + "did not expect GroupByLimitAggregate when ordering by an aggregate, got:\n{}", + pp + ); + + // No LIMIT -> no trim. + let p = service + .plan_query("SELECT day, SUM(hits) FROM s.Data GROUP BY 1 ORDER BY 1") + .await?; + let pp = pp_phys_plan_ext(p.worker.as_ref(), &PPOptions::none()); + assert!( + !pp.contains("GroupByLimitAggregate"), + "did not expect GroupByLimitAggregate without a limit, got:\n{}", + pp + ); + + // ORDER BY a proper SUBSET of GROUP BY (b out of b, c). The worker cut and the router sort must + // both use the total order T = [b, c]: the worker trim order carries the tie-break column c, and + // the router's global Sort is extended with c so its top-k matches the global top-k by T. + let p = service + .plan_query("SELECT b, c, SUM(h) FROM s.D3 GROUP BY 1, 2 ORDER BY 1 LIMIT 3") + .await?; + let worker_pp = pp_phys_plan_ext(p.worker.as_ref(), &PPOptions::none()); + assert!( + worker_pp.contains("GroupByLimitAggregate, k: 3, factor: 2,") + && worker_pp.contains("(0, SortOptions { descending: false, nulls_first: false })") + && worker_pp.contains("(1, SortOptions { descending: false, nulls_first: true })"), + "expected worker trim order [b, c] totalized, got:\n{}", + worker_pp + ); + let router_pp = pp_phys_plan_ext( + p.router.as_ref(), + &PPOptions { + show_sort_by: true, + ..PPOptions::none() + }, + ); + assert!( + router_pp.contains("b@0") && router_pp.contains("c@1"), + "expected router Sort extended with the tie-break column c, got:\n{}", + router_pp + ); + + // Bare LIMIT (no ORDER BY) on a non-indexed group column: the limit can't ride the index, so the + // worker still trims to the smallest groups by the full group key -- "any k" made deterministic. 
+ let p = service + .plan_query("SELECT day, SUM(hits) FROM s.Data GROUP BY 1 LIMIT 10") + .await?; + let pp = pp_phys_plan_ext(p.worker.as_ref(), &PPOptions::none()); + assert!( + pp.contains("GroupByLimitAggregate, k: 10, factor: 2,") + && pp.contains("(0, SortOptions { descending: false, nulls_first: false })"), + "expected GroupByLimitAggregate on a bare LIMIT, got:\n{}", + pp + ); + + // UNION ALL + bare LIMIT: the per-branch trim descriptor must survive the cluster-send pull-up + // over the union so the worker still trims above the union. + service + .exec_query("CREATE TABLE s.Data2(url text, day int, hits int)") + .await?; + let p = service + .plan_query( + "SELECT day, SUM(hits) FROM \ + (SELECT * FROM s.Data UNION ALL SELECT * FROM s.Data2) u GROUP BY 1 LIMIT 10", + ) + .await?; + let pp = pp_phys_plan_ext(p.worker.as_ref(), &PPOptions::none()); + assert!( + pp.contains("GroupByLimitAggregate, k: 10, factor: 2,") && pp.contains("Union"), + "expected GroupByLimitAggregate over the Union, got:\n{}", + pp + ); + + Ok(()) +} + +async fn topk_hash_aggregate_trim(service: Box) -> Result<(), CubeError> { + service.exec_query("CREATE SCHEMA s").await?; + service + .exec_query("CREATE TABLE s.Data(a int, b int, hits int)") + .await?; + // 12 distinct (a, b) groups, each with two rows so partial aggregation actually groups. + // With k=3 and factor=2 the trim activates (g=12 > 6) but the result must match a full + // top-k. ORDER BY a (a proper subset of GROUP BY a, b) exercises totalization: the worker + // breaks ties on a by b so the router still receives every needed partial state. + service + .exec_query( + "INSERT INTO s.Data(a, b, hits) VALUES \ + (1,1,10),(1,1,5),(1,2,1),(1,2,2),\ + (2,1,7),(2,1,3),(2,2,4),(2,2,6),\ + (3,1,8),(3,1,2),(3,2,9),(3,2,1),\ + (4,1,1),(4,1,1),(4,2,1),(4,2,1),\ + (5,1,1),(5,1,1),(5,2,1),(5,2,1),\ + (6,1,1),(6,1,1),(6,2,1),(6,2,1)", + ) + .await?; + + // ORDER BY a, b LIMIT 3 (ascending): smallest three groups by (a, b). 
+ let r = service + .exec_query("SELECT a, b, SUM(hits) FROM s.Data GROUP BY 1, 2 ORDER BY 1, 2 LIMIT 3") + .await?; + assert_eq!(to_rows(&r), rows(&[(1, 1, 15), (1, 2, 3), (2, 1, 10)])); + + // ORDER BY a DESC, b DESC LIMIT 3: largest three groups by (a, b). + let r = service + .exec_query( + "SELECT a, b, SUM(hits) FROM s.Data GROUP BY 1, 2 ORDER BY 1 DESC, 2 DESC LIMIT 3", + ) + .await?; + assert_eq!(to_rows(&r), rows(&[(6, 2, 2), (6, 1, 2), (5, 2, 2)])); + + // ORDER BY a only (a proper subset of GROUP BY a, b), LIMIT 2. The selected group SET is + // deterministic (both groups of a=1), but the intra-tie row order is not, so assert as a set. + // Each returned group must carry its complete sum regardless of cross-worker tie-breaking, + // which is what totalization (append b to the cut order) guarantees. + let r = service + .exec_query("SELECT a, b, SUM(hits) FROM s.Data GROUP BY 1, 2 ORDER BY 1 LIMIT 2") + .await?; + let got = to_rows(&r); + assert_eq!(got.len(), 2, "expected 2 rows, got: {:?}", got); + for expected in rows(&[(1, 1, 15), (1, 2, 3)]) { + assert!( + got.contains(&expected), + "missing {:?} in {:?}", + expected, + got + ); + } + + // Bare LIMIT 3 (no ORDER BY): the trim orders by the full group key, so "any 3" resolves to the + // 3 smallest by (a, b). The result order is unspecified, but the group SET and each group's full + // sum must be exact -- the latter guards against undercounting a group split across workers. 
+ let r = service + .exec_query("SELECT a, b, SUM(hits) FROM s.Data GROUP BY 1, 2 LIMIT 3") + .await?; + let got = to_rows(&r); + assert_eq!(got.len(), 3, "expected 3 rows, got: {:?}", got); + for expected in rows(&[(1, 1, 15), (1, 2, 3), (2, 1, 10)]) { + assert!( + got.contains(&expected), + "missing {:?} in {:?}", + expected, + got + ); + } + + Ok(()) +} + async fn planning_hints(service: Box) -> Result<(), CubeError> { service.exec_query("CREATE SCHEMA s").await?; service diff --git a/rust/cubestore/cubestore/Cargo.toml b/rust/cubestore/cubestore/Cargo.toml index effab752735b8..167ca2a37690e 100644 --- a/rust/cubestore/cubestore/Cargo.toml +++ b/rust/cubestore/cubestore/Cargo.toml @@ -28,10 +28,10 @@ cubezetasketch = { path = "../cubezetasketch" } cubedatasketches = { path = "../cubedatasketches" } cubeshared = { path = "../../cube/cubeshared" } cuberpc = { path = "../cuberpc" } -datafusion = { git = "https://github.com/cube-js/arrow-datafusion", branch = "cube-46.0.1", features = ["serde"] } -datafusion-datasource = { git = "https://github.com/cube-js/arrow-datafusion", branch = "cube-46.0.1" } -datafusion-proto = { git = "https://github.com/cube-js/arrow-datafusion", branch = "cube-46.0.1" } -datafusion-proto-common = { git = "https://github.com/cube-js/arrow-datafusion", branch = "cube-46.0.1" } +datafusion = { git = "https://github.com/cube-js/arrow-datafusion", branch = "cubestore-hash-aggregate-limit", features = ["serde"] } +datafusion-datasource = { git = "https://github.com/cube-js/arrow-datafusion", branch = "cubestore-hash-aggregate-limit" } +datafusion-proto = { git = "https://github.com/cube-js/arrow-datafusion", branch = "cubestore-hash-aggregate-limit" } +datafusion-proto-common = { git = "https://github.com/cube-js/arrow-datafusion", branch = "cubestore-hash-aggregate-limit" } csv = "1.1.3" bytes = "1.6.0" serde_json = "1.0.56" diff --git a/rust/cubestore/cubestore/src/config/mod.rs b/rust/cubestore/cubestore/src/config/mod.rs index 
5cba4c1222959..9a69ffd6df6b0 100644 --- a/rust/cubestore/cubestore/src/config/mod.rs +++ b/rust/cubestore/cubestore/src/config/mod.rs @@ -585,6 +585,14 @@ pub trait ConfigObj: DIService { /// streaming split never drops rows (the first child is the low catch-all, the last /// the high one), and the legacy per-chunk path performed no such metadata check. fn repartition_check_overlapping_children(&self) -> bool; + /// Master gate for string dictionary-encoding work: reading string columns as + /// `DictionaryArray` and the dictionary-aware inline aggregate paths. Off by default while + /// the feature is built up incrementally behind this flag. + fn dictionary_encoding_enabled(&self) -> bool; + /// Factor `f` controlling when the worker-side partial hash aggregate trims its output to the + /// top-k groups. Trimming happens only when the number of local groups exceeds `f * k`, where + /// `k = limit + offset`. `0` disables the optimization. + fn group_by_limit_factor(&self) -> usize; fn allow_decimal128(&self) -> bool; @@ -745,6 +753,8 @@ pub struct ConfigObjImpl { pub repartition_merge_max_input_files: usize, pub repartition_merge_max_rows: u64, pub repartition_check_overlapping_children: bool, + pub dictionary_encoding_enabled: bool, + pub group_by_limit_factor: usize, pub allow_decimal128: bool, pub enable_remove_orphaned_remote_files: bool, pub enable_startup_warmup: bool, @@ -1086,6 +1096,13 @@ impl ConfigObj for ConfigObjImpl { fn repartition_check_overlapping_children(&self) -> bool { self.repartition_check_overlapping_children } + fn dictionary_encoding_enabled(&self) -> bool { + self.dictionary_encoding_enabled + } + + fn group_by_limit_factor(&self) -> usize { + self.group_by_limit_factor + } fn allow_decimal128(&self) -> bool { self.allow_decimal128 @@ -1783,6 +1800,11 @@ impl Config { "CUBESTORE_REPARTITION_CHECK_OVERLAPPING_CHILDREN", false, ), + dictionary_encoding_enabled: env_bool("CUBESTORE_DICTIONARY_ENCODING", false), + group_by_limit_factor: 
env_parse( + "CUBESTORE_GROUP_BY_LIMIT_FACTOR", + 0, + ), allow_decimal128: env_bool("CUBESTORE_ALLOW_DECIMAL128", false), enable_remove_orphaned_remote_files: env_bool( "CUBESTORE_ENABLE_REMOVE_ORPHANED_REMOTE_FILES", @@ -2039,6 +2061,8 @@ impl Config { repartition_merge_max_input_files: 50, repartition_merge_max_rows: 4_000_000, repartition_check_overlapping_children: false, + dictionary_encoding_enabled: false, + group_by_limit_factor: 2, allow_decimal128: false, enable_remove_orphaned_remote_files: false, enable_startup_warmup: true, @@ -2734,10 +2758,6 @@ impl Config { self.injector .register_typed_with_default::(async move |i| { - let push_partial_aggregate_below_merge = i - .get_service_typed::() - .await - .push_partial_aggregate_below_merge_enabled(); QueryExecutorImpl::new( i.get_service_typed::() .await @@ -2745,7 +2765,7 @@ impl Config { .clone(), i.get_service_typed().await, i.get_service_typed().await, - push_partial_aggregate_below_merge, + i.get_service_typed::().await, ) }) .await; diff --git a/rust/cubestore/cubestore/src/metastore/mod.rs b/rust/cubestore/cubestore/src/metastore/mod.rs index 5d6fadd650638..fb0aa7be43702 100644 --- a/rust/cubestore/cubestore/src/metastore/mod.rs +++ b/rust/cubestore/cubestore/src/metastore/mod.rs @@ -558,28 +558,40 @@ impl Into for Column { } } +impl Column { + /// Arrow field for this column. When `dictionary_encoding` is set, `String` columns are + /// exposed as `Dictionary(Int32, Utf8)` so they flow dictionary-encoded through the plan; + /// otherwise they are plain `Utf8`. All other types are unaffected. 
+ pub fn as_arrow_field(&self, dictionary_encoding: bool) -> Field { + let data_type = match self.column_type { + ColumnType::String => { + if dictionary_encoding { + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) + } else { + DataType::Utf8 + } + } + ColumnType::Int => DataType::Int64, + ColumnType::Int96 => DataType::Decimal128(38, 0), + ColumnType::Timestamp => DataType::Timestamp(Microsecond, None), + ColumnType::Boolean => DataType::Boolean, + ColumnType::Decimal { scale, precision } => { + DataType::Decimal128(precision as u8, scale as i8) + } + ColumnType::Decimal96 { scale, precision } => { + DataType::Decimal128(precision as u8, scale as i8) + } + ColumnType::Bytes => DataType::Binary, + ColumnType::HyperLogLog(_) => DataType::Binary, + ColumnType::Float => DataType::Float64, + }; + Field::new(self.name.as_str(), data_type, true) + } +} + impl<'a> Into for &'a Column { fn into(self) -> Field { - Field::new( - self.name.as_str(), - match self.column_type { - ColumnType::String => DataType::Utf8, - ColumnType::Int => DataType::Int64, - ColumnType::Int96 => DataType::Decimal128(38, 0), - ColumnType::Timestamp => DataType::Timestamp(Microsecond, None), - ColumnType::Boolean => DataType::Boolean, - ColumnType::Decimal { scale, precision } => { - DataType::Decimal128(precision as u8, scale as i8) - } - ColumnType::Decimal96 { scale, precision } => { - DataType::Decimal128(precision as u8, scale as i8) - } - ColumnType::Bytes => DataType::Binary, - ColumnType::HyperLogLog(_) => DataType::Binary, - ColumnType::Float => DataType::Float64, - }, - true, - ) + self.as_arrow_field(false) } } diff --git a/rust/cubestore/cubestore/src/queryplanner/group_by_limit_aggregate/dict_remap.rs b/rust/cubestore/cubestore/src/queryplanner/group_by_limit_aggregate/dict_remap.rs new file mode 100644 index 0000000000000..e8ff9b79fb577 --- /dev/null +++ b/rust/cubestore/cubestore/src/queryplanner/group_by_limit_aggregate/dict_remap.rs @@ -0,0 +1,170 @@ +use 
std::collections::HashMap; +use std::sync::Arc; + +use datafusion::arrow::array::{ + Array, ArrayRef, DictionaryArray, Int32Array, Int32Builder, StringArray, +}; +use datafusion::arrow::compute::take; +use datafusion::arrow::datatypes::{DataType, Int32Type}; +use datafusion::error::{DataFusionError, Result as DFResult}; + +/// True for the only dictionary layout CubeStore produces for string group keys. +pub(crate) fn is_int32_utf8_dict(dt: &DataType) -> bool { + matches!(dt, DataType::Dictionary(k, v) + if k.as_ref() == &DataType::Int32 && v.as_ref() == &DataType::Utf8) +} + +/// Accumulates a global `String -> id` mapping across batches so a dictionary-encoded group column +/// can be grouped as `Int32` global ids on DataFusion's fast primitive path, instead of +/// materializing the string on every row. The per-batch string work is proportional to the batch's +/// distinct dictionary values, not its row count. Null dictionary entries and null keys stay null. +pub(crate) struct GlobalDict { + value_to_id: HashMap, i32>, + values: Vec>, +} + +impl GlobalDict { + pub fn new() -> Self { + Self { + value_to_id: HashMap::new(), + values: Vec::new(), + } + } + + fn intern_value(&mut self, v: &str) -> i32 { + if let Some(id) = self.value_to_id.get(v) { + return *id; + } + let id = self.values.len() as i32; + // One allocation shared between the map key and the values vec. + let key: Arc = Arc::from(v); + self.values.push(key.clone()); + self.value_to_id.insert(key, id); + id + } + + /// Remap a `Dictionary(Int32, Utf8)` array to an `Int32Array` of global ids. 
+ pub fn remap(&mut self, array: &ArrayRef) -> DFResult { + let dict = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal( + "GlobalDict::remap expected Dictionary(Int32)".to_string(), + ) + })?; + let local_values = dict + .values() + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal("GlobalDict::remap expected Utf8 values".to_string()) + })?; + + // local id -> global id, interning each distinct value once; a null dictionary entry is a + // null in this map. Built once per batch (O(distinct values)). + let mut builder = Int32Builder::with_capacity(local_values.len()); + for i in 0..local_values.len() { + if local_values.is_null(i) { + builder.append_null(); + } else { + builder.append_value(self.intern_value(local_values.value(i))); + } + } + let local_to_global = builder.finish(); + + // Gather the global id per row via a vectorized take: null keys and null dictionary entries + // both propagate to null, matching how the string path groups nulls. + Ok(take(&local_to_global, dict.keys(), None)?) + } + + /// Rebuild a `Dictionary(Int32, Utf8)` array from an `Int32Array` of global ids emitted by the + /// group table; the values are the full accumulated global dictionary. 
+    pub fn rebuild(&self, ids: &ArrayRef) -> DFResult<ArrayRef> {
+        let ids = ids.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
+            DataFusionError::Internal("GlobalDict::rebuild expected Int32 ids".to_string())
+        })?;
+        // The dictionary values are every string interned so far, in `self.values` order, so the
+        // ids can be used as dictionary keys unchanged.
+        let values = StringArray::from_iter_values(self.values.iter());
+        let dict = DictionaryArray::<Int32Type>::try_new(ids.clone(), Arc::new(values))?;
+        Ok(Arc::new(dict))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Build a `Dictionary(Int32, Utf8)` array from explicit dictionary values and row keys.
+    fn dict(values: Vec<Option<&str>>, keys: Vec<Option<i32>>) -> ArrayRef {
+        let values = StringArray::from(values);
+        let keys = Int32Array::from(keys);
+        Arc::new(DictionaryArray::<Int32Type>::try_new(keys, Arc::new(values)).unwrap())
+    }
+
+    /// Downcast a remapped column to its `Int32Array` of global ids.
+    fn ids(a: &ArrayRef) -> Int32Array {
+        a.as_any().downcast_ref::<Int32Array>().unwrap().clone()
+    }
+
+    /// Decode a rebuilt dictionary array into one optional string per row.
+    fn rebuilt_strings(a: &ArrayRef) -> Vec<Option<String>> {
+        let d = a
+            .as_any()
+            .downcast_ref::<DictionaryArray<Int32Type>>()
+            .unwrap();
+        let v = d.values().as_any().downcast_ref::<StringArray>().unwrap();
+        d.keys()
+            .iter()
+            .map(|k| k.map(|k| v.value(k as usize).to_string()))
+            .collect()
+    }
+
+    #[test]
+    fn remaps_to_consistent_global_ids_across_batches() {
+        let mut gd = GlobalDict::new();
+        // batch 1: local dict ["b", "a"], rows b, a, b
+        let b1 = ids(&gd
+            .remap(&dict(
+                vec![Some("b"), Some("a")],
+                vec![Some(0), Some(1), Some(0)],
+            ))
+            .unwrap());
+        // batch 2: a DIFFERENT local dict ["a", "c"], rows c, a -- "a" must reuse its global id
+        let b2 = ids(&gd
+            .remap(&dict(vec![Some("a"), Some("c")], vec![Some(1), Some(0)]))
+            .unwrap());
+
+        assert_eq!(b1.values(), &[0, 1, 0]); // b=0, a=1 (first-seen)
+        assert_eq!(b2.value(1), b1.value(1)); // same string "a" -> same global id across batches
+        assert_ne!(b2.value(0), b1.value(0)); // "c" is a new id
+
+        // rebuild over the accumulated global ids yields the original strings
+        let all: ArrayRef = Arc::new(Int32Array::from(vec![
+            b1.value(0),
+            b1.value(1),
+            b2.value(0),
+        ]));
+        assert_eq!(
+            rebuilt_strings(&gd.rebuild(&all).unwrap()),
+            vec![
+                Some("b".to_string()),
+                Some("a".to_string()),
+                Some("c".to_string())
+            ]
+        );
+    }
+
+    #[test]
+    fn null_keys_and_null_entries_stay_null() {
+        let mut gd = GlobalDict::new();
+        // local dict ["x", null]; rows: x, null-key, points-to-null-entry
+        let r = gd
+            .remap(&dict(vec![Some("x"), None], vec![Some(0), None, Some(1)]))
+            .unwrap();
+        let r = ids(&r);
+        assert!(r.is_valid(0));
+        assert!(r.is_null(1));
+        assert!(r.is_null(2));
+        assert_eq!(
+            rebuilt_strings(&gd.rebuild(&(Arc::new(r) as ArrayRef)).unwrap()),
+            vec![Some("x".to_string()), None, None]
+        );
+    }
+}
diff --git a/rust/cubestore/cubestore/src/queryplanner/group_by_limit_aggregate/group_by_limit_aggregate_stream.rs b/rust/cubestore/cubestore/src/queryplanner/group_by_limit_aggregate/group_by_limit_aggregate_stream.rs
new file mode 100644
index 0000000000000..74c7d8ed4c72b
--- /dev/null
+++ b/rust/cubestore/cubestore/src/queryplanner/group_by_limit_aggregate/group_by_limit_aggregate_stream.rs
@@ -0,0 +1,322 @@
+use datafusion::arrow::array::{ArrayRef, AsArray, RecordBatch};
+use datafusion::arrow::compute::{lexsort_to_indices, take, SortColumn, SortOptions};
+use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use datafusion::dfschema::internal_err;
+use datafusion::error::Result as DFResult;
+use datafusion::execution::{RecordBatchStream, TaskContext};
+use datafusion::logical_expr::{EmitTo, GroupsAccumulator};
+use datafusion::physical_expr::GroupsAccumulatorAdapter;
+use datafusion::physical_plan::aggregates::group_values::{new_group_values, GroupValues};
+use datafusion::physical_plan::aggregates::order::GroupOrdering;
+use datafusion::physical_plan::aggregates::PhysicalGroupBy;
+use datafusion::physical_plan::udaf::AggregateFunctionExpr;
+use datafusion::physical_plan::{ExecutionPlan, PhysicalExpr, SendableRecordBatchStream};
+use futures::ready;
+use futures::stream::{Stream, StreamExt};
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use super::dict_remap::{is_int32_utf8_dict, GlobalDict};
+use super::GroupByLimitAggregateExec;
+
+/// State machine driven by `poll_next`.
+enum ExecutionState {
+    /// Input batches are still being pulled and folded into the group table.
+    ReadingInput,
+    /// Input is exhausted; the payload holds the rows not yet handed out, which are emitted in
+    /// slices of at most `batch_size` rows.
+    ProducingOutput(RecordBatch),
+    /// Nothing left to emit.
+    Done,
+}
+
+/// Stream that aggregates its whole input into a group table and, once the input ends, emits
+/// the (optionally top-k trimmed) partial states.
+pub(crate) struct GroupByLimitAggregateStream {
+    /// Schema of the emitted batches; also returned by `RecordBatchStream::schema`.
+    schema: SchemaRef,
+    input: SendableRecordBatchStream,
+    /// Argument expressions evaluated per input batch, one vec per aggregate.
+    aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+    /// Optional filter expression per aggregate, parallel to `accumulators`.
+    filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,
+    group_by: PhysicalGroupBy,
+    /// Maximum number of rows per emitted output slice.
+    batch_size: usize,
+    exec_state: ExecutionState,
+    /// Set once the input stream has returned `None`.
+    input_done: bool,
+    accumulators: Vec<Box<dyn GroupsAccumulator>>,
+    group_values: Box<dyn GroupValues>,
+    /// One slot per group column: `Some` for a `Dictionary(Int32, Utf8)` key grouped as Int32
+    /// global ids, `None` for a column passed through to `group_values` unchanged.
+    dict_remaps: Vec<Option<GlobalDict>>,
+    /// Scratch buffer receiving the group index of every row of the current batch.
+    current_group_indices: Vec<usize>,
+    k: usize,
+    factor: usize,
+    order: Vec<(usize, SortOptions)>,
+}
+
+impl GroupByLimitAggregateStream {
+    /// Create the stream for one `partition` of `agg`: start executing the input, build one
+    /// accumulator per aggregate expression and set up the group table.
+    ///
+    /// # Errors
+    /// Propagates failures from executing the input, creating accumulators, deriving the group
+    /// schema or constructing the group table.
+    pub fn new(
+        agg: &GroupByLimitAggregateExec,
+        context: Arc<TaskContext>,
+        partition: usize,
+    ) -> DFResult<Self> {
+        let agg_schema = Arc::clone(&agg.schema());
+        let agg_group_by = agg.group_expr().clone();
+        let agg_filter_expr = agg.filter_expr().to_vec();
+
+        let batch_size = context.session_config().batch_size();
+        let input = agg.input().execute(partition, Arc::clone(&context))?;
+
+        let aggregate_arguments = aggregate_expressions(agg.aggr_expr())?;
+
+        let accumulators: Vec<_> = agg
+            .aggr_expr()
+            .iter()
+            .map(create_group_accumulator)
+            .collect::<DFResult<_>>()?;
+
+        let group_schema = agg_group_by.group_schema(&agg.input().schema())?;
+        // Expose `Dictionary(Int32, Utf8)` group keys to `group_values` as plain `Int32` global ids
+        // (df's fast primitive path); other columns are passed through unchanged.
+        let mut int_fields = Vec::with_capacity(group_schema.fields().len());
+        let mut dict_remaps = Vec::with_capacity(group_schema.fields().len());
+        for field in group_schema.fields() {
+            if is_int32_utf8_dict(field.data_type()) {
+                int_fields.push(Arc::new(Field::new(field.name(), DataType::Int32, true)));
+                dict_remaps.push(Some(GlobalDict::new()));
+            } else {
+                int_fields.push(field.clone());
+                dict_remaps.push(None);
+            }
+        }
+        let int_group_schema = Arc::new(Schema::new(int_fields));
+        let group_values = new_group_values(int_group_schema, &GroupOrdering::None)?;
+
+        Ok(GroupByLimitAggregateStream {
+            schema: agg_schema,
+            input,
+            aggregate_arguments,
+            filter_expressions: agg_filter_expr,
+            group_by: agg_group_by,
+            batch_size,
+            exec_state: ExecutionState::ReadingInput,
+            input_done: false,
+            accumulators,
+            group_values,
+            dict_remaps,
+            current_group_indices: Vec::with_capacity(batch_size),
+            k: agg.k(),
+            factor: agg.factor(),
+            order: agg.order().to_vec(),
+        })
+    }
+}
+
+impl Stream for GroupByLimitAggregateStream {
+    type Item = DFResult<RecordBatch>;
+
+    /// Consume the whole input first, then hand out the trimmed result in `batch_size` slices.
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        loop {
+            match &self.exec_state {
+                ExecutionState::ReadingInput => match ready!(self.input.poll_next_unpin(cx)) {
+                    Some(Ok(batch)) => {
+                        if let Err(e) = self.group_aggregate_batch(batch) {
+                            return Poll::Ready(Some(Err(e)));
+                        }
+                    }
+                    Some(Err(e)) => return Poll::Ready(Some(Err(e))),
+                    // Input exhausted: emit the whole group table at once, then trim to top-k.
+                    None => {
+                        self.input_done = true;
+                        match self.emit_all_trimmed() {
+                            Ok(Some(batch)) => {
+                                self.exec_state = ExecutionState::ProducingOutput(batch)
+                            }
+                            Ok(None) => self.exec_state = ExecutionState::Done,
+                            Err(e) => {
+                                // The input has already returned `None`; finish the stream so a
+                                // later poll does not poll the exhausted input again.
+                                self.exec_state = ExecutionState::Done;
+                                return Poll::Ready(Some(Err(e)));
+                            }
+                        }
+                    }
+                },
+
+                ExecutionState::ProducingOutput(batch) => {
+                    let batch = batch.clone();
+                    // Guard against a zero batch size, which would otherwise emit empty slices
+                    // forever without making progress.
+                    let size = self.batch_size.max(1);
+                    let (next_state, output) = if batch.num_rows() <= size {
+                        (ExecutionState::Done, batch)
+                    } else {
+                        let remaining = batch.slice(size, batch.num_rows() - size);
+                        let output = batch.slice(0, size);
+                        (ExecutionState::ProducingOutput(remaining), output)
+                    };
+                    self.exec_state = next_state;
+                    return Poll::Ready(Some(Ok(output)));
+                }
+
+                ExecutionState::Done => return Poll::Ready(None),
+            }
+        }
+    }
+}
+
+impl RecordBatchStream for GroupByLimitAggregateStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+}
+
+impl GroupByLimitAggregateStream {
+    /// Remap dictionary group columns to `Int32` global ids; pass other columns through unchanged.
+    fn remap_group_cols(&mut self, cols: &[ArrayRef]) -> DFResult<Vec<ArrayRef>> {
+        let mut out = Vec::with_capacity(cols.len());
+        for (i, col) in cols.iter().enumerate() {
+            match &mut self.dict_remaps[i] {
+                Some(gd) => out.push(gd.remap(col)?),
+                None => out.push(Arc::clone(col)),
+            }
+        }
+        Ok(out)
+    }
+
+    /// Fold one input batch into the group table: evaluate the group keys, aggregate arguments
+    /// and filters, intern the (remapped) keys to obtain a group index per row, then update every
+    /// accumulator with its arguments and optional filter.
+    fn group_aggregate_batch(&mut self, batch: RecordBatch) -> DFResult<()> {
+        let group_by_values = evaluate_group_by(&self.group_by, &batch)?;
+        let input_values = evaluate_many(&self.aggregate_arguments, &batch)?;
+        let filter_values = evaluate_optional(&self.filter_expressions, &batch)?;
+
+        // Report a broken invariant as an error instead of panicking inside `poll_next`.
+        if group_by_values.len() != 1 {
+            return internal_err!("Exactly 1 group value required");
+        }
+        let group_cols = self.remap_group_cols(&group_by_values[0])?;
+        self.group_values
+            .intern(&group_cols, &mut self.current_group_indices)?;
+        let group_indices = &self.current_group_indices;
+        let total_num_groups = self.group_values.len();
+
+        for ((acc, values), opt_filter) in self
+            .accumulators
+            .iter_mut()
+            .zip(input_values.iter())
+            .zip(filter_values.iter())
+        {
+            let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean());
+            acc.update_batch(values, group_indices, opt_filter, total_num_groups)?;
+        }
+        Ok(())
+    }
+
+    /// Build the partial-state batch for all groups, then keep only the `k` smallest by the total
+    /// order when the number of groups exceeds `factor * k`.
+    ///
+    /// Returns `Ok(None)` when no group was ever interned.
+    fn emit_all_trimmed(&mut self) -> DFResult<Option<RecordBatch>> {
+        if self.group_values.is_empty() {
+            return Ok(None);
+        }
+        let mut columns = self.group_values.emit(EmitTo::All)?;
+        // Convert the Int32 global ids of dictionary keys back to `Dictionary(Int32, Utf8)` so the
+        // emitted columns match the partial-aggregate output schema.
+        for (i, remap) in self.dict_remaps.iter().enumerate() {
+            if let Some(gd) = remap {
+                columns[i] = gd.rebuild(&columns[i])?;
+            }
+        }
+        // Accumulator state columns follow the group columns.
+        for acc in &mut self.accumulators {
+            columns.extend(acc.state(EmitTo::All)?);
+        }
+        let batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?;
+        Ok(Some(self.trim_top_k(batch)?))
+    }
+
+    /// Keep the first `k` rows of `batch` under the lexicographic `order`.
+    ///
+    /// The batch is returned untouched when `k == 0` or when it has at most `factor * k` rows.
+    fn trim_top_k(&self, batch: RecordBatch) -> DFResult<RecordBatch> {
+        let g = batch.num_rows();
+        if self.k == 0 || g <= self.factor.saturating_mul(self.k) {
+            return Ok(batch);
+        }
+        let sort_columns: Vec<SortColumn> = self
+            .order
+            .iter()
+            .map(|(idx, options)| SortColumn {
+                values: Arc::clone(batch.column(*idx)),
+                options: Some(*options),
+            })
+            .collect();
+        // The limit lets the sort kernel return only the indices of the first `k` rows.
+        let indices = lexsort_to_indices(&sort_columns, Some(self.k))?;
+        let columns = batch
+            .columns()
+            .iter()
+            .map(|c| take(c.as_ref(), &indices, None))
+            .collect::<Result<Vec<_>, _>>()?;
+        Ok(RecordBatch::try_new(batch.schema(), columns)?)
+    }
+}
+
+/// Partial-aggregate argument expressions, one vec per aggregate. Mirrors DataFusion's private
+/// `aggregate_expressions` for `AggregateMode::Partial` only — the Final-mode column offset that
+/// DataFusion's version takes is not needed here, so it is omitted.
+fn aggregate_expressions(
+    aggr_expr: &[Arc<AggregateFunctionExpr>],
+) -> DFResult<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
+    Ok(aggr_expr
+        .iter()
+        .map(|agg| {
+            let mut result = agg.expressions();
+            // ORDER BY expressions of the aggregate are appended after its arguments.
+            if let Some(ordering_req) = agg.order_bys() {
+                result.extend(ordering_req.iter().map(|item| Arc::clone(&item.expr)));
+            }
+            result
+        })
+        .collect())
+}
+
+/// Create the grouped accumulator for `agg_expr`, wrapping the row-at-a-time accumulator in a
+/// `GroupsAccumulatorAdapter` when the aggregate has no native groups accumulator.
+fn create_group_accumulator(
+    agg_expr: &Arc<AggregateFunctionExpr>,
+) -> DFResult<Box<dyn GroupsAccumulator>> {
+    if agg_expr.groups_accumulator_supported() {
+        agg_expr.create_groups_accumulator()
+    } else {
+        let agg_expr_captured = Arc::clone(agg_expr);
+        let factory = move || agg_expr_captured.create_accumulator();
+        Ok(Box::new(GroupsAccumulatorAdapter::new(factory)))
+    }
+}
+
+/// Evaluate every expression against `batch`, expanding scalar results to full-length arrays.
+fn evaluate(expr: &[Arc<dyn PhysicalExpr>], batch: &RecordBatch) -> DFResult<Vec<ArrayRef>> {
+    expr.iter()
+        .map(|expr| {
+            expr.evaluate(batch)
+                .and_then(|v| v.into_array(batch.num_rows()))
+        })
+        .collect()
+}
+
+/// Evaluate one list of expressions per aggregate against `batch`.
+fn evaluate_many(
+    expr: &[Vec<Arc<dyn PhysicalExpr>>],
+    batch: &RecordBatch,
+) -> DFResult<Vec<Vec<ArrayRef>>> {
+    expr.iter().map(|expr| evaluate(expr, batch)).collect()
+}
+
+/// Evaluate the optional per-aggregate filter expressions; `None` entries stay `None`.
+fn evaluate_optional(
+    expr: &[Option<Arc<dyn PhysicalExpr>>],
+    batch: &RecordBatch,
+) -> DFResult<Vec<Option<ArrayRef>>> {
+    expr.iter()
+        .map(|expr| {
+            expr.as_ref()
+                .map(|expr| {
+                    expr.evaluate(batch)
+                        .and_then(|v| v.into_array(batch.num_rows()))
+                })
+                .transpose()
+        })
+        .collect()
+}
+
+/// Evaluate the group-by expressions against `batch`.
+///
+/// Only a single grouping is supported, so the result always holds exactly one inner vec with
+/// one array per group expression.
+///
+/// # Errors
+/// Returns an internal error for grouping sets, or the first expression evaluation failure.
+fn evaluate_group_by(
+    group_by: &PhysicalGroupBy,
+    batch: &RecordBatch,
+) -> DFResult<Vec<Vec<ArrayRef>>> {
+    // Reject unsupported grouping sets before spending any work on expression evaluation.
+    if !group_by.is_single() {
+        return internal_err!("GroupByLimitAggregate does not support grouping sets");
+    }
+
+    let exprs: Vec<ArrayRef> = group_by
+        .expr()
+        .iter()
+        .map(|(expr, _)| {
+            let value = expr.evaluate(batch)?;
+            value.into_array(batch.num_rows())
+        })
+        .collect::<DFResult<Vec<_>>>()?;
+
+    Ok(vec![exprs])
+}
diff --git a/rust/cubestore/cubestore/src/queryplanner/group_by_limit_aggregate/mod.rs b/rust/cubestore/cubestore/src/queryplanner/group_by_limit_aggregate/mod.rs
new file mode 100644
index 0000000000000..b8ab43d6ee618
--- /dev/null
+++ b/rust/cubestore/cubestore/src/queryplanner/group_by_limit_aggregate/mod.rs
@@ -0,0 +1,220 @@
+mod dict_remap;
+mod group_by_limit_aggregate_stream;
+
+use datafusion::arrow::compute::SortOptions;
+use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::common::stats::Precision;
+use datafusion::common::Statistics;
+use datafusion::error::Result as DFResult;
+use datafusion::execution::TaskContext;
+use datafusion::physical_expr::aggregate::AggregateFunctionExpr;
+use datafusion::physical_expr::{Distribution, LexRequirement};
+use datafusion::physical_plan::execution_plan::CardinalityEffect;
+use datafusion::physical_plan::metrics::MetricsSet;
+use datafusion::physical_plan::{aggregates::*, InputOrderMode};
+use datafusion::physical_plan::{
+    DisplayAs, DisplayFormatType, ExecutionPlan, PhysicalExpr, PlanProperties,
+    SendableRecordBatchStream,
+};
+use std::any::Any;
+use std::fmt::Debug;
+use std::sync::Arc;
+
+/// Worker-side partial hash aggregate that trims its output to the top-k groups by a total order,
+/// so far fewer partial-state rows cross the network to the router's Final aggregate.
+///
+/// This is a custom copy of DataFusion's partial hash aggregate (it reuses DF's `GroupValues` and
+/// `GroupsAccumulator` building blocks but owns the consume/emit loop), so the only change required
+/// in the DataFusion fork is making `new_group_values` public. The aggregation builds the whole
+/// group table and, at the single final emit, keeps only the `k` smallest groups by `order` when
+/// the number of groups exceeds `factor * k`; otherwise it emits all groups unchanged.
+///
+/// `order` is a TOTAL order over groups (ORDER BY columns followed by the remaining group-by
+/// columns), expressed as `(partial-output column index, sort options)`. A total order is required
+/// for correctness: the same group key can live on multiple workers, and a consistent cut across
+/// workers guarantees every partial state the router selects reaches it.
+#[derive(Debug, Clone)]
+pub struct GroupByLimitAggregateExec {
+    group_by: PhysicalGroupBy,
+    aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
+    filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
+    pub input: Arc<dyn ExecutionPlan>,
+    /// Partial-aggregate output schema (group columns followed by accumulator state columns).
+    schema: SchemaRef,
+    input_schema: SchemaRef,
+    cache: PlanProperties,
+    /// Fetch count, `k = limit + offset`.
+    k: usize,
+    /// Only trim when the number of local groups exceeds `factor * k`.
+    factor: usize,
+    /// Total order over the partial output columns.
+    order: Vec<(usize, SortOptions)>,
+}
+
+impl GroupByLimitAggregateExec {
+    /// Build a `GroupByLimitAggregateExec` from a partial hash `AggregateExec`, or `None` if it is not a
+    /// single-group-by partial aggregate (grouping sets and non-partial modes are not supported).
+    pub fn try_new_from_partial(
+        aggregate: &AggregateExec,
+        k: usize,
+        factor: usize,
+        order: Vec<(usize, SortOptions)>,
+    ) -> Option<Self> {
+        if *aggregate.mode() != AggregateMode::Partial {
+            return None;
+        }
+        // Sorted-prefix aggregates are handled by InlineAggregateExec; this targets the hash path.
+        if matches!(aggregate.input_order_mode(), InputOrderMode::Sorted) {
+            return None;
+        }
+        let group_by = aggregate.group_expr().clone();
+        if !group_by.is_single() {
+            return None;
+        }
+        // Schema and plan properties are copied from the aggregate being replaced.
+        Some(Self {
+            group_by,
+            aggr_expr: aggregate.aggr_expr().to_vec(),
+            filter_expr: aggregate.filter_expr().to_vec(),
+            input: aggregate.input().clone(),
+            schema: aggregate.schema().clone(),
+            input_schema: aggregate.input_schema().clone(),
+            cache: aggregate.cache().clone(),
+            k,
+            factor,
+            order,
+        })
+    }
+
+    /// Fetch count `k`.
+    pub fn k(&self) -> usize {
+        self.k
+    }
+
+    /// Trim threshold multiplier: trimming happens only above `factor * k` groups.
+    pub fn factor(&self) -> usize {
+        self.factor
+    }
+
+    /// Total order as `(partial-output column index, sort options)` pairs.
+    pub fn order(&self) -> &[(usize, SortOptions)] {
+        &self.order
+    }
+
+    /// Aggregate expressions, in output order.
+    pub fn aggr_expr(&self) -> &[Arc<AggregateFunctionExpr>] {
+        &self.aggr_expr
+    }
+
+    /// Optional filter per aggregate expression.
+    pub fn filter_expr(&self) -> &[Option<Arc<dyn PhysicalExpr>>] {
+        &self.filter_expr
+    }
+
+    /// Input plan being aggregated.
+    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
+        &self.input
+    }
+
+    /// Group-by expressions.
+    pub fn group_expr(&self) -> &PhysicalGroupBy {
+        &self.group_by
+    }
+}
+
+impl DisplayAs for GroupByLimitAggregateExec {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default | DisplayFormatType::Verbose => {
+                write!(
+                    f,
+                    "GroupByLimitAggregateExec: k={}, factor={}, order={:?}",
+                    self.k, self.factor, self.order
+                )?;
+            }
+        }
+        Ok(())
+    }
+}
+
+impl ExecutionPlan for GroupByLimitAggregateExec {
+    fn name(&self) -> &'static str {
+        "GroupByLimitAggregateExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.cache
+    }
+
+    fn required_input_distribution(&self) -> Vec<Distribution> {
+        vec![Distribution::UnspecifiedDistribution]
+    }
+
+    fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> {
+        vec![None]
+    }
+
+    fn maintains_input_order(&self) -> Vec<bool> {
+        vec![false]
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![&self.input]
+    }
+
+    /// Rebuild the node over a new input; every other field, including the cached plan
+    /// properties, is copied unchanged. Indexes `children[0]`, so exactly one child is expected.
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> DFResult<Arc<dyn ExecutionPlan>> {
+        Ok(Arc::new(Self {
+            group_by: self.group_by.clone(),
+            aggr_expr: self.aggr_expr.clone(),
+            filter_expr: self.filter_expr.clone(),
+            input: children[0].clone(),
+            schema: self.schema.clone(),
+            input_schema: self.input_schema.clone(),
+            cache: self.cache.clone(),
+            k: self.k,
+            factor: self.factor,
+            order: self.order.clone(),
+        }))
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> DFResult<SendableRecordBatchStream> {
+        let stream = group_by_limit_aggregate_stream::GroupByLimitAggregateStream::new(
+            self, context, partition,
+        )?;
+        Ok(Box::pin(stream))
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        None
+    }
+
+    fn statistics(&self) -> DFResult<Statistics> {
+        // Per partition the stream emits either every group (when there are at most `factor * k`
+        // of them) or only the first `k` by `order`, so for `k > 0` the output is bounded by
+        // `max(factor, 1) * k` rows per partition as well as by the input row count. With
+        // `k == 0` the stream never trims, so only the input row count bounds the output.
+        // The result is always inexact: an aggregate's output row count is never known exactly
+        // from its input row count. Reporting a bound instead of Absent matters because Absent
+        // makes downstream planners bail.
+        let input_rows = self.input.statistics()?.num_rows;
+        let cap = if self.k == 0 {
+            None
+        } else {
+            let parts = self.cache.output_partitioning().partition_count().max(1);
+            Some(
+                self.factor
+                    .max(1)
+                    .saturating_mul(self.k)
+                    .saturating_mul(parts),
+            )
+        };
+        let num_rows = match (input_rows, cap) {
+            (Precision::Exact(n) | Precision::Inexact(n), Some(cap)) => {
+                Precision::Inexact(n.min(cap))
+            }
+            (Precision::Exact(n) | Precision::Inexact(n), None) => Precision::Inexact(n),
+            (Precision::Absent, Some(cap)) => Precision::Inexact(cap),
+            (Precision::Absent, None) => Precision::Absent,
+        };
+        Ok(Statistics {
+            num_rows,
+            column_statistics: Statistics::unknown_column(&self.schema),
+            total_byte_size: Precision::Absent,
+        })
+    }
+
+    fn cardinality_effect(&self) -> CardinalityEffect {
+        CardinalityEffect::LowerEqual
+    }
+}
diff --git a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/column_comparator.rs b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/column_comparator.rs
index e085381ed2736..691f75b5a1816 100644
--- a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/column_comparator.rs
+++ b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/column_comparator.rs
@@ -1,5 +1,7 @@
 use datafusion::arrow::array::*;
+use
datafusion::arrow::compute::SortOptions; use datafusion::arrow::datatypes::*; +use std::cmp::Ordering; use std::marker::PhantomData; /// Trait for comparing adjacent rows in an array to detect group boundaries. @@ -193,6 +195,82 @@ where } } +/// Comparator for dictionary-encoded columns (e.g. `Dictionary(Int32, Utf8)`). +/// +/// The hot path compares dictionary keys (small integers) instead of the underlying +/// values. Within a single batch all rows share one dictionary, so key equality implies +/// value equality. The reverse does not hold when a dictionary carries duplicate values +/// (e.g. after a merge unions several local dictionaries), so when adjacent keys differ we +/// fall back to comparing the actual values to avoid splitting a group incorrectly. That +/// fallback only fires on group boundaries, which are rare in a sorted stream. +pub struct DictionaryComparator { + _phantom: PhantomData K>, +} + +impl DictionaryComparator { + pub fn new() -> Self { + Self { + _phantom: PhantomData, + } + } +} + +impl ColumnComparator + for DictionaryComparator +{ + #[inline] + fn compare_adjacent(&self, col: &ArrayRef, equal_results: &mut [bool]) { + let array = col + .as_any() + .downcast_ref::>() + .expect("DictionaryComparator got non-dictionary array"); + let keys = array.keys(); + let values = array.values(); + + if !NULLABLE { + // A non-nullable field must not carry null keys; the loop below skips null checks. + debug_assert_eq!( + keys.null_count(), + 0, + "DictionaryComparator<_, false> received null keys" + ); + } + + // Built lazily, only when adjacent keys actually differ. The values array is always one + // of the types accepted by `new_dictionary_group_column`, all of which `make_comparator` + // supports. 
+ let mut value_cmp: Option = None; + + for i in 0..equal_results.len() { + if !equal_results[i] { + continue; + } + + if NULLABLE { + let null1 = keys.is_null(i); + let null2 = keys.is_null(i + 1); + if null1 || null2 { + // Both null => same group; one null => boundary. + equal_results[i] = null1 && null2; + continue; + } + } + + let k1 = keys.value(i).as_usize(); + let k2 = keys.value(i + 1).as_usize(); + if k1 == k2 { + continue; + } + + let cmp = value_cmp.get_or_insert_with(|| { + make_comparator(values.as_ref(), values.as_ref(), SortOptions::default()) + .expect("make_comparator for dictionary values") + }); + equal_results[i] = cmp(k1, k2) == Ordering::Equal; + } + } +} + /// Instantiate a primitive comparator and push it into the vector. /// /// Handles const generic NULLABLE parameter based on field nullability. @@ -260,3 +338,84 @@ macro_rules! instantiate_byte_view_comparator { } }; } + +/// Instantiate a dictionary comparator and push it into the vector. +#[macro_export] +macro_rules! 
instantiate_dictionary_comparator { + ($v:expr, $nullable:expr, $k:ty) => { + if $nullable { + $v.push(Box::new( + $crate::queryplanner::inline_aggregate::column_comparator::DictionaryComparator::< + $k, + true, + >::new(), + ) as _) + } else { + $v.push(Box::new( + $crate::queryplanner::inline_aggregate::column_comparator::DictionaryComparator::< + $k, + false, + >::new(), + ) as _) + } + }; +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + fn run(comparator: &dyn ColumnComparator, col: &ArrayRef) -> Vec { + let n = col.len(); + let mut eq = vec![true; n - 1]; + comparator.compare_adjacent(col, &mut eq); + eq + } + + #[test] + fn dict_compare_same_dictionary_sorted() { + // values [a,b,c], keys [0,0,1,1,2] => rows a,a,b,b,c + let dict: DictionaryArray = vec!["a", "a", "b", "b", "c"].into_iter().collect(); + let col: ArrayRef = Arc::new(dict); + let cmp = DictionaryComparator::::new(); + assert_eq!(run(&cmp, &col), vec![true, false, true, false]); + } + + #[test] + fn dict_compare_duplicate_values_fallback() { + // Dictionary with duplicate values: keys 0 and 1 both map to "a". + // Adjacent keys differ but values are equal -> must NOT be a boundary. 
+ let keys = Int32Array::from(vec![0, 1, 2]); + let values = Arc::new(StringArray::from(vec!["a", "a", "b"])); + let dict = DictionaryArray::::new(keys, values); + let col: ArrayRef = Arc::new(dict); + let cmp = DictionaryComparator::::new(); + // rows: a, a, b => (a,a) equal via fallback, (a,b) boundary + assert_eq!(run(&cmp, &col), vec![true, false]); + } + + #[test] + fn dict_compare_nulls() { + // rows: null, null, "a", "a", null + let dict: DictionaryArray = vec![None, None, Some("a"), Some("a"), None] + .into_iter() + .collect(); + let col: ArrayRef = Arc::new(dict); + let cmp = DictionaryComparator::::new(); + // (null,null) equal, (null,a) boundary, (a,a) equal, (a,null) boundary + assert_eq!(run(&cmp, &col), vec![true, false, true, false]); + } + + #[test] + fn dict_compare_respects_short_circuit() { + // values [a,b], keys [0,0,1]; pre-mark first pair as already-false. + let dict: DictionaryArray = vec!["a", "a", "b"].into_iter().collect(); + let col: ArrayRef = Arc::new(dict); + let cmp = DictionaryComparator::::new(); + let mut eq = vec![false, true]; + cmp.compare_adjacent(&col, &mut eq); + // first stays false (short-circuit), second is a real boundary a->b + assert_eq!(eq, vec![false, false]); + } +} diff --git a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/dictionary_group_column.rs b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/dictionary_group_column.rs new file mode 100644 index 0000000000000..c39c8b7c49d94 --- /dev/null +++ b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/dictionary_group_column.rs @@ -0,0 +1,142 @@ +use std::marker::PhantomData; + +use datafusion::arrow::array::{new_null_array, Array, ArrayRef, DictionaryArray}; +use datafusion::arrow::datatypes::{ + ArrowDictionaryKeyType, ArrowNativeType, DataType, Int16Type, Int32Type, Int64Type, Int8Type, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use datafusion::dfschema::not_impl_err; +use datafusion::error::Result as DFResult; +use 
datafusion::physical_expr::binary_map::OutputType; +use datafusion::physical_plan::aggregates::group_values::multi_group_by::{ + ByteGroupValueBuilder, GroupColumn, +}; + +/// A [`GroupColumn`] for dictionary-encoded columns that stores the group values in their +/// decoded form (delegating to an inner byte-array builder) while accepting dictionary input. +/// +/// Group storage operations (`append_val`/`equal_to`) only happen on group boundaries, so they +/// resolve the dictionary value on demand: a non-null row delegates to the inner builder using +/// `(dict.values(), dict.key(row))`, and a null row delegates against a cached single-null array. +/// The per-row hot path stays in `DictionaryComparator`, which never touches this builder. +pub struct DictionaryGroupColumn { + inner: Box, + /// One-element null array of the dictionary's value type, used to append/compare null keys. + null_row: ArrayRef, + _k: PhantomData K>, +} + +impl DictionaryGroupColumn { + fn new(inner: Box, null_row: ArrayRef) -> Self { + Self { + inner, + null_row, + _k: PhantomData, + } + } + + #[inline] + fn dict(column: &ArrayRef) -> &DictionaryArray { + column + .as_any() + .downcast_ref::>() + .expect("DictionaryGroupColumn got non-dictionary array") + } +} + +impl GroupColumn for DictionaryGroupColumn { + fn equal_to(&self, lhs_row: usize, column: &ArrayRef, rhs_row: usize) -> bool { + let dict = Self::dict(column); + if dict.is_null(rhs_row) { + self.inner.equal_to(lhs_row, &self.null_row, 0) + } else { + let key = dict.keys().value(rhs_row).as_usize(); + self.inner.equal_to(lhs_row, dict.values(), key) + } + } + + fn append_val(&mut self, column: &ArrayRef, row: usize) { + let dict = Self::dict(column); + if dict.is_null(row) { + self.inner.append_val(&self.null_row, 0); + } else { + let key = dict.keys().value(row).as_usize(); + self.inner.append_val(dict.values(), key); + } + } + + // Scalar fallbacks, not a fast path: on the sorted/inline path the column comparator does the + // 
hot row-by-row work, so these per-row loops are not on the critical path. A vectorized + // implementation would only matter if this column were used by the hash aggregate. + fn vectorized_equal_to( + &self, + lhs_rows: &[usize], + array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut [bool], + ) { + for i in 0..lhs_rows.len() { + if equal_to_results[i] { + equal_to_results[i] = self.equal_to(lhs_rows[i], array, rhs_rows[i]); + } + } + } + + fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) { + for &row in rows { + self.append_val(array, row); + } + } + + fn len(&self) -> usize { + self.inner.len() + } + + fn size(&self) -> usize { + self.inner.size() + self.null_row.get_array_memory_size() + } + + fn build(self: Box) -> ArrayRef { + (*self).inner.build() + } + + fn take_n(&mut self, n: usize) -> ArrayRef { + self.inner.take_n(n) + } +} + +/// Builds a [`DictionaryGroupColumn`] for the given dictionary key/value types. +/// +/// The inner builder stores the decoded value (Utf8/Binary); the wrapper is generic over the +/// key type so it can read keys without decoding the whole batch. 
+pub fn new_dictionary_group_column( + key_type: &DataType, + value_type: &DataType, +) -> DFResult> { + let inner: Box = match value_type { + DataType::Utf8 => Box::new(ByteGroupValueBuilder::::new(OutputType::Utf8)), + DataType::LargeUtf8 => Box::new(ByteGroupValueBuilder::::new(OutputType::Utf8)), + DataType::Binary => Box::new(ByteGroupValueBuilder::::new(OutputType::Binary)), + DataType::LargeBinary => Box::new(ByteGroupValueBuilder::::new(OutputType::Binary)), + other => { + return not_impl_err!( + "dictionary value type {other} not supported in SortedGroupValues" + ) + } + }; + let null_row = new_null_array(value_type, 1); + + Ok(match key_type { + DataType::Int8 => Box::new(DictionaryGroupColumn::::new(inner, null_row)), + DataType::Int16 => Box::new(DictionaryGroupColumn::::new(inner, null_row)), + DataType::Int32 => Box::new(DictionaryGroupColumn::::new(inner, null_row)), + DataType::Int64 => Box::new(DictionaryGroupColumn::::new(inner, null_row)), + DataType::UInt8 => Box::new(DictionaryGroupColumn::::new(inner, null_row)), + DataType::UInt16 => Box::new(DictionaryGroupColumn::::new(inner, null_row)), + DataType::UInt32 => Box::new(DictionaryGroupColumn::::new(inner, null_row)), + DataType::UInt64 => Box::new(DictionaryGroupColumn::::new(inner, null_row)), + other => { + return not_impl_err!("dictionary key type {other} not supported in SortedGroupValues") + } + }) +} diff --git a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/mod.rs b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/mod.rs index 8a58d1a8c0dba..3ed078cb3683c 100644 --- a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/mod.rs +++ b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/mod.rs @@ -1,4 +1,5 @@ mod column_comparator; +mod dictionary_group_column; mod inline_aggregate_stream; mod sorted_group_values; mod sorted_group_values_rows; @@ -279,31 +280,48 @@ fn supported_schema(schema: &datafusion::arrow::datatypes::Schema) -> bool { /// /// 
Types not in this list will use the row-based [`SortedGroupValuesRows`] implementation fn supported_type(data_type: &DataType) -> bool { - matches!( - *data_type, + match data_type { DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 - | DataType::Decimal128(_, _) - | DataType::Utf8 - | DataType::LargeUtf8 - | DataType::Binary - | DataType::LargeBinary - | DataType::Date32 - | DataType::Date64 - | DataType::Time32(_) - | DataType::Time64(_) - | DataType::Timestamp(_, _) - | DataType::Utf8View - | DataType::BinaryView - ) + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Utf8View + | DataType::BinaryView => true, + // Dictionary group columns handled by DictionaryGroupColumn + DictionaryComparator. 
+ DataType::Dictionary(key_type, value_type) => { + matches!( + key_type.as_ref(), + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + ) && matches!( + value_type.as_ref(), + DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary + ) + } + _ => false, + } } #[cfg(test)] diff --git a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/sorted_group_values.rs b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/sorted_group_values.rs index e7c0e82b2f7cb..08c49a93a7fd6 100644 --- a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/sorted_group_values.rs +++ b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/sorted_group_values.rs @@ -22,9 +22,10 @@ use datafusion::physical_plan::aggregates::group_values::multi_group_by::{ use datafusion::physical_plan::aggregates::group_values::GroupValues; use crate::queryplanner::inline_aggregate::column_comparator::ColumnComparator; +use crate::queryplanner::inline_aggregate::dictionary_group_column::new_dictionary_group_column; use crate::{ instantiate_byte_array_comparator, instantiate_byte_view_comparator, - instantiate_primitive_comparator, + instantiate_dictionary_comparator, instantiate_primitive_comparator, }; pub struct SortedGroupValues { @@ -319,6 +320,52 @@ impl GroupValues for SortedGroupValues { v.push(Box::new(b) as _); instantiate_byte_view_comparator!(comparators, nullable, BinaryViewType); } + &DataType::Dictionary(ref key_type, ref value_type) => { + v.push(new_dictionary_group_column(key_type, value_type)?); + match key_type.as_ref() { + DataType::Int8 => { + instantiate_dictionary_comparator!(comparators, nullable, Int8Type) + } + DataType::Int16 => { + instantiate_dictionary_comparator!(comparators, nullable, Int16Type) + } + DataType::Int32 => { + instantiate_dictionary_comparator!(comparators, nullable, Int32Type) + } + DataType::Int64 => { + 
instantiate_dictionary_comparator!(comparators, nullable, Int64Type) + } + DataType::UInt8 => { + instantiate_dictionary_comparator!(comparators, nullable, UInt8Type) + } + DataType::UInt16 => { + instantiate_dictionary_comparator!( + comparators, + nullable, + UInt16Type + ) + } + DataType::UInt32 => { + instantiate_dictionary_comparator!( + comparators, + nullable, + UInt32Type + ) + } + DataType::UInt64 => { + instantiate_dictionary_comparator!( + comparators, + nullable, + UInt64Type + ) + } + dt => { + return not_impl_err!( + "dictionary key type {dt} not supported in SortedGroupValues" + ) + } + } + } dt => return not_impl_err!("{dt} not supported in SortedGroupValues"), } } @@ -390,3 +437,187 @@ impl GroupValues for SortedGroupValues { self.equal_to_results.clear(); } } + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::arrow::array::{Int32Array, StringArray}; + use datafusion::arrow::datatypes::{Field, Schema}; + use std::sync::Arc; + + fn dict_schema(nullable: bool) -> SchemaRef { + Arc::new(Schema::new(vec![Field::new( + "g", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + nullable, + )])) + } + + fn decode(dict: &ArrayRef) -> Vec> { + let dict = dict + .as_any() + .downcast_ref::>() + .unwrap(); + let values = dict + .values() + .as_any() + .downcast_ref::() + .unwrap(); + (0..dict.len()) + .map(|i| { + if dict.is_null(i) { + None + } else { + Some(values.value(dict.keys().value(i) as usize).to_string()) + } + }) + .collect() + } + + /// Groups must continue across a batch boundary even when the two batches carry different + /// local dictionaries (the same string is encoded with different keys per batch). 
+ #[test] + fn sorted_group_values_dictionary_cross_batch() { + let mut gv = SortedGroupValues::try_new(dict_schema(false)).unwrap(); + + // Batch 1: a, a, b (values [a,b], keys [0,0,1]) + let b1 = datafusion::arrow::array::DictionaryArray::::new( + Int32Array::from(vec![0, 0, 1]), + Arc::new(StringArray::from(vec!["a", "b"])), + ); + let mut groups = vec![]; + gv.intern(&[Arc::new(b1) as ArrayRef], &mut groups).unwrap(); + assert_eq!(groups, vec![0, 0, 1]); + + // Batch 2: b, c with a DIFFERENT local dictionary (values [b,c], keys [0,1]). + let b2 = datafusion::arrow::array::DictionaryArray::::new( + Int32Array::from(vec![0, 1]), + Arc::new(StringArray::from(vec!["b", "c"])), + ); + gv.intern(&[Arc::new(b2) as ArrayRef], &mut groups).unwrap(); + // "b" continues the last group (idx 1), "c" opens group 2. + assert_eq!(groups, vec![1, 2]); + + assert_eq!(gv.len(), 3); + let out = gv.emit(EmitTo::All).unwrap(); + assert_eq!( + decode(&out[0]), + vec![ + Some("a".to_string()), + Some("b".to_string()), + Some("c".to_string()) + ] + ); + } + + /// Isolated timing: dictionary vs Utf8 group keys over a sorted 10-column stream. + /// Run with: cargo test -p cubestore --lib sorted_group_values_dict_vs_utf8_bench -- --ignored --nocapture + #[test] + #[ignore] + fn sorted_group_values_dict_vs_utf8_bench() { + use std::time::Instant; + const NCOLS: usize = 10; + const ROWS: usize = 2_000_000; + const BATCH: usize = 8192; + const ROWS_PER_GROUP: usize = 20; // ~100k groups, low per-column cardinality + + // tuple value for column j of group g, c0 most significant -> stream is sorted ascending + let val = |g: usize, j: usize| -> String { + let digit = (g / 4usize.pow((NCOLS - 1 - j) as u32)) % 4; + format!("c{j}_{digit}") + }; + + // Build Utf8 batches and Dictionary batches for the same sorted data. 
+ let mut utf8_batches: Vec> = vec![]; + let mut dict_batches: Vec> = vec![]; + let mut row = 0usize; + while row < ROWS { + let n = BATCH.min(ROWS - row); + let mut utf8_cols: Vec = Vec::with_capacity(NCOLS); + let mut dict_cols: Vec = Vec::with_capacity(NCOLS); + for j in 0..NCOLS { + let vals: Vec = + (0..n).map(|i| val((row + i) / ROWS_PER_GROUP, j)).collect(); + let strs: Vec<&str> = vals.iter().map(|s| s.as_str()).collect(); + utf8_cols.push(Arc::new(StringArray::from(strs.clone())) as ArrayRef); + let dict: datafusion::arrow::array::DictionaryArray = + strs.into_iter().collect(); + dict_cols.push(Arc::new(dict) as ArrayRef); + } + utf8_batches.push(utf8_cols); + dict_batches.push(dict_cols); + row += n; + } + + let utf8_schema = Arc::new(Schema::new( + (0..NCOLS) + .map(|j| Field::new(format!("c{j}"), DataType::Utf8, false)) + .collect::>(), + )); + let dict_schema = Arc::new(Schema::new( + (0..NCOLS) + .map(|j| { + Field::new( + format!("c{j}"), + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + ) + }) + .collect::>(), + )); + + let run = |schema: SchemaRef, batches: &Vec>| -> (u128, usize) { + let mut gv = SortedGroupValues::try_new(schema).unwrap(); + let mut groups = vec![]; + let t0 = Instant::now(); + for cols in batches { + gv.intern(cols, &mut groups).unwrap(); + } + (t0.elapsed().as_micros(), gv.len()) + }; + + // warm + measure (best of 3) + let mut utf8_us = u128::MAX; + let mut dict_us = u128::MAX; + let mut ngroups = 0; + for _ in 0..3 { + let (u, gu) = run(utf8_schema.clone(), &utf8_batches); + let (d, gd) = run(dict_schema.clone(), &dict_batches); + assert_eq!(gu, gd, "group counts must match"); + ngroups = gu; + utf8_us = utf8_us.min(u); + dict_us = dict_us.min(d); + } + println!( + "intern over {ROWS} rows x {NCOLS} cols, {ngroups} groups:\n Utf8: {:.1} ms\n Dict: {:.1} ms\n speedup: {:.2}x", + utf8_us as f64 / 1000.0, + dict_us as f64 / 1000.0, + utf8_us as f64 / dict_us as f64, + ); + } + + /// Null 
keys form their own group and continue across batches. + #[test] + fn sorted_group_values_dictionary_nulls() { + let mut gv = SortedGroupValues::try_new(dict_schema(true)).unwrap(); + + // rows: null, null, a + let b1: datafusion::arrow::array::DictionaryArray = + vec![None, None, Some("a")].into_iter().collect(); + let mut groups = vec![]; + gv.intern(&[Arc::new(b1) as ArrayRef], &mut groups).unwrap(); + assert_eq!(groups, vec![0, 0, 1]); + + // rows: a, b -> "a" continues group 1, "b" new + let b2: datafusion::arrow::array::DictionaryArray = + vec![Some("a"), Some("b")].into_iter().collect(); + gv.intern(&[Arc::new(b2) as ArrayRef], &mut groups).unwrap(); + assert_eq!(groups, vec![1, 2]); + + let out = gv.emit(EmitTo::All).unwrap(); + assert_eq!( + decode(&out[0]), + vec![None, Some("a".to_string()), Some("b".to_string())] + ); + } +} diff --git a/rust/cubestore/cubestore/src/queryplanner/mod.rs b/rust/cubestore/cubestore/src/queryplanner/mod.rs index c4d4742312e4d..0a2536a34a3f7 100644 --- a/rust/cubestore/cubestore/src/queryplanner/mod.rs +++ b/rust/cubestore/cubestore/src/queryplanner/mod.rs @@ -10,6 +10,7 @@ use datafusion_datasource::memory::MemorySourceConfig; use datafusion_datasource::source::DataSourceExec; pub use planning::PlanningMeta; mod check_memory; +mod group_by_limit_aggregate; pub mod physical_plan_flags; pub mod pretty_printers; mod projection_above_limit; @@ -161,6 +162,7 @@ impl QueryPlanner for QueryPlannerImpl { inline_tables, self.cache.clone(), state.clone(), + self.config.dictionary_encoding_enabled(), ); let query_planner = SqlToRel::new_with_options(&schema_provider, sql_to_rel_options()); @@ -231,6 +233,7 @@ impl QueryPlanner for QueryPlannerImpl { logical_plan, &self.meta_store.as_ref(), self.config.enable_topk(), + self.config.dictionary_encoding_enabled(), ) .await?; let workers = compute_workers( @@ -372,6 +375,7 @@ struct MetaStoreSchemaProvider { inline_tables: InlineTables, cache: Arc, config_options: ConfigOptions, + 
dictionary_encoding: bool, expr_planners: Vec>, // session_state.expr_planners clone session_state: Arc, } @@ -408,6 +412,7 @@ impl MetaStoreSchemaProvider { inline_tables: &InlineTables, cache: Arc, session_state: Arc, + dictionary_encoding: bool, ) -> Self { let by_name = tables.iter().map(|t| TableKey(t)).collect(); Self { @@ -418,6 +423,7 @@ impl MetaStoreSchemaProvider { cache, inline_tables: (*inline_tables).clone(), config_options: ConfigOptions::new(), + dictionary_encoding, expr_planners: datafusion::execution::FunctionRegistry::expr_planners( session_state.as_ref(), ), @@ -486,13 +492,14 @@ impl ContextProvider for MetaStoreSchemaProvider { .get(&TableKey(&table_path)) .map(|table| -> Arc { let table = unsafe { &*table.0 }; + let dictionary_encoding = self.dictionary_encoding; let schema = Arc::new(Schema::new( table .table .get_row() .get_columns() .iter() - .map(|c| c.clone().into()) + .map(|c| c.as_arrow_field(dictionary_encoding)) .collect::>(), )); Arc::new(CubeTableLogical { @@ -1095,6 +1102,7 @@ pub mod tests { &vec![], Arc::new(SqlResultCache::new(1 << 20, None, 10000, None)), Arc::new(SessionContext::new().state()), + false, ) } diff --git a/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs b/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs index 5246b1878f132..07904c0fc67a0 100644 --- a/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs +++ b/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs @@ -1,5 +1,6 @@ use crate::cluster::WorkerPlanningParams; use crate::queryplanner::check_memory::CheckMemoryExec; +use crate::queryplanner::group_by_limit_aggregate::GroupByLimitAggregateExec; use crate::queryplanner::inline_aggregate::{InlineAggregateExec, InlineAggregateMode}; use crate::queryplanner::planning::WorkerExec; use crate::queryplanner::query_executor::ClusterSendExec; @@ -286,13 +287,19 @@ 
pub fn push_sorted_partial_aggregate_below_merge( /// key) would drop partial states of tied groups. pub fn push_worker_sort_and_limit( p: Arc, + group_by_limit_factor: usize, ) -> Result, DataFusionError> { - // Worker side: wrap the partial aggregate with a per-partition bounded sort. + // Worker side: bound the partial aggregate's output. The sorted (inline) aggregate is always + // bounded with a per-partition Sort(fetch) -- the group count is unknown, so we can't trim, we + // just sort. The hash aggregate uses the trimming top-k, engaged only when factor > 0. + // `resort_worker_subtree` returns None when nothing applies (hash with factor == 0), leaving the + // plan as planned. if let Some(w) = p.as_any().downcast_ref::() { let Some((cols, fetch)) = w.worker_sort_and_limit.clone() else { return Ok(p); }; - let Some(new_input) = resort_worker_subtree(&w.input, &cols, fetch) else { + let Some(new_input) = resort_worker_subtree(&w.input, &cols, fetch, group_by_limit_factor) + else { return Ok(p); }; return Ok(Arc::new(WorkerExec::new( @@ -307,10 +314,8 @@ pub fn push_worker_sort_and_limit( ))); } - // Router side: rebuild the final aggregate over a sort-preserving merge in `worker_order`, and - // reorder the (optimization-only) worker subtree to match. Same keys are adjacent in - // `worker_order`, so the sorted final combines them; its output is `worker_order`-sorted (whose - // prefix is the query's ORDER BY), so the limit above stays correct. + // Router side: combine the workers' top-k with a hash final aggregate, then re-apply the top-k + // sort by `worker_order` (the total order T). 
let Some(final_agg) = FinalAggregateInfo::extract(&p) else { return Ok(p); }; @@ -324,23 +329,56 @@ pub fn push_worker_sort_and_limit( let Some((cols, fetch)) = cs.worker_sort_and_limit.clone() else { return Ok(p); }; - let Some(new_worker_subtree) = resort_worker_subtree(&cs.input_for_optimizations, &cols, fetch) - else { + // The hash aggregate emits unordered trimmed groups (combined by a hash final + re-sort); the + // sorted/inline aggregate emits bounded sorted streams (combined by a sort-preserving merge + + // sorted final). Decide from the worker's partial aggregate before rebuilding. + let is_hash = locate_partial_aggregate(&cs.input_for_optimizations) + .map_or(false, |partial| partial.as_any().is::()); + let Some(new_worker_subtree) = resort_worker_subtree( + &cs.input_for_optimizations, + &cols, + fetch, + group_by_limit_factor, + ) else { return Ok(p); }; let new_cs: Arc = Arc::new(cs.with_changed_schema(new_worker_subtree, cs.required_input_ordering.clone())); let worker_order = worker_ordering(&final_agg.group_expr, &cols)?; - let merged: Arc = - Arc::new(SortPreservingMergeExec::new(worker_order, new_cs)); - Ok(Arc::new(AggregateExec::try_new( - AggregateMode::Final, - final_agg.group_expr, - final_agg.aggr_expr, - final_agg.filter_expr, - merged, - final_agg.input_schema, - )?)) + + if is_hash { + // Hash final over coalesced (unordered) streams, then re-apply the top-k sort by the total + // order T. The Sort(fetch) is required even for a bare LIMIT: it keeps the k smallest by T -- + // exactly the groups every worker kept and fully combined here -- where a plain limit could + // take a group only one worker kept (undercounted). The coalesce drains the workers in + // parallel. 
+ let coalesced: Arc = Arc::new(CoalescePartitionsExec::new(new_cs)); + let final_hash: Arc = Arc::new(AggregateExec::try_new( + AggregateMode::Final, + final_agg.group_expr, + final_agg.aggr_expr, + final_agg.filter_expr, + coalesced, + final_agg.input_schema, + )?); + Ok(Arc::new( + SortExec::new(worker_order, final_hash).with_fetch(Some(fetch)), + )) + } else { + // Sorted final over a sort-preserving merge in worker_order: equal keys are adjacent so the + // sorted final combines them, and its output stays worker_order-sorted (whose prefix is the + // query's ORDER BY), so the query's limit above stays correct -- no extra Sort needed. + let merged: Arc = + Arc::new(SortPreservingMergeExec::new(worker_order, new_cs)); + Ok(Arc::new(AggregateExec::try_new( + AggregateMode::Final, + final_agg.group_expr, + final_agg.aggr_expr, + final_agg.filter_expr, + merged, + final_agg.input_schema, + )?)) + } } /// Builds the `worker_order` LexOrdering over an aggregate's group columns from the descriptor. @@ -365,33 +403,84 @@ fn worker_ordering( Ok(LexOrdering::new(exprs)) } -/// Rebuilds a worker subtree as `SortPreservingMerge(worker_order) <- Sort(worker_order, fetch, per -/// partition) <- partial`. Returns `None` for an unrecognized or already-rewritten subtree, which -/// keeps [push_worker_sort_and_limit] idempotent. +/// Rebuilds a worker subtree to bound its output to the top `fetch` groups by the total order in +/// `cols`. Two shapes, by partial aggregate kind: +/// - hash (`AggregateExec`): `CoalescePartitions <- GroupByLimitAggregate` -- trim during +/// aggregation, emitted unsorted for the router's hash final. Only when `factor > 0`; returns +/// `None` otherwise (trimming disabled, leave the plan as planned). +/// - sorted/inline: `SortPreservingMerge(T) <- Sort(T, fetch, per partition) <- PartialAggregate` -- +/// we can't trim a sorted aggregate and don't know the group count, so always bound with a sort. 
/// -/// The per-partition `Sort` does the bounding (a bounded heap, O(fetch) memory); the merge above it -/// carries no fetch. Because this pass runs last, `replace_suboptimal_merge_sorts` has already run -/// and won't push the query's row limit into the merge -- which would cut the merged stream of -/// (still uncombined) partial rows by rows and undercount groups split across partitions. +/// Returns `None` for an unrecognized subtree (no locatable partial aggregate). fn resort_worker_subtree( worker_subtree: &Arc, cols: &[(usize, bool, bool)], fetch: usize, + group_by_limit_factor: usize, ) -> Option> { let partial = locate_partial_aggregate(worker_subtree)?; - let schema = partial.schema(); - let mut exprs = Vec::with_capacity(cols.len()); - for (idx, asc, nulls_first) in cols { - let field = schema.fields().get(*idx)?; - exprs.push(PhysicalSortExpr { - expr: Arc::new(Column::new(field.name(), *idx)), - options: SortOptions { - descending: !asc, - nulls_first: *nulls_first, + + // Hash path: trim during aggregation, emit unsorted for the router's hash final. The factor + // gates whether trimming applies; with it off, nothing applies on this path. + if partial.as_any().is::() { + if group_by_limit_factor == 0 { + return None; + } + + // Per-partition (CUBESTORE_GROUP_BY_LIMIT_PER_PARTITION): drop the merge below the aggregate + // so it runs over every raw partition; the CoalescePartitions below then parallelizes all of + // them (N-way) instead of one stream per union branch. 
+ let partial = if per_partition_enabled() { + partial + .as_any() + .downcast_ref::() + .and_then(|agg| { + let new_input = strip_leading_coalesce_partitions(agg.input()); + partial.clone().with_new_children(vec![new_input]).ok() + }) + .unwrap_or(partial) + } else { + partial + }; + + let order: Vec<(usize, SortOptions)> = cols + .iter() + .map(|(idx, asc, nulls_first)| { + ( + *idx, + SortOptions { + descending: !asc, + nulls_first: *nulls_first, + }, + ) + }) + .collect(); + let trimmed: Arc = match partial.as_any().downcast_ref::() + { + Some(agg) => match GroupByLimitAggregateExec::try_new_from_partial( + agg, + fetch, + group_by_limit_factor, + order, + ) { + Some(e) => Arc::new(e), + None => partial, }, - }); + None => partial, + }; + + // Emit the trimmed top-k unsorted, coalesced to one stream. The router hash-combines and + // re-applies the top-k sort, so no worker-side sort/merge holds the whole result and the + // per-partition aggregates run in parallel (CoalescePartitions spawns a task per input) + // instead of being drained one at a time by a sort-preserving merge. + return Some(Arc::new(CoalescePartitionsExec::new(trimmed))); } - let worker_order = LexOrdering::new(exprs); + + // Sorted/inline path: bound each partition with Sort(fetch) and merge in the total order. The + // per-partition `fetch` is sound because the key is the full group key: a globally top-`fetch` + // group stays within every partition's first `fetch`, so the router's sorted final still sees + // all its partial states (see this function's doc and the module note on the total order). + let worker_order = lex_ordering_from_cols(cols, &partial.schema())?; let per_partition_sort: Arc = Arc::new( SortExec::new(worker_order.clone(), partial) .with_fetch(Some(fetch)) @@ -403,6 +492,52 @@ fn resort_worker_subtree( ))) } +/// Build a `LexOrdering` over the partial aggregate's group columns from the descriptor (indices +/// into the partial output schema). 
Returns `None` if a column index is out of range. +fn lex_ordering_from_cols( + cols: &[(usize, bool, bool)], + schema: &datafusion::arrow::datatypes::SchemaRef, +) -> Option { + let mut exprs = Vec::with_capacity(cols.len()); + for (idx, asc, nulls_first) in cols { + let field = schema.fields().get(*idx)?; + exprs.push(PhysicalSortExpr { + expr: Arc::new(Column::new(field.name(), *idx)), + options: SortOptions { + descending: !asc, + nulls_first: *nulls_first, + }, + }); + } + Some(LexOrdering::new(exprs)) +} + +/// Toggle (CUBESTORE_GROUP_BY_LIMIT_PER_PARTITION): drop the merge below the trimmed aggregate so it +/// runs per partition (N-way) instead of per union branch. The `CoalescePartitions` then +/// parallelizes all partitions. +/// +/// Trades memory for parallelism: one group table per DF partition (= per parquet file and chunk), +/// so peak memory is bounded by the partition count, not by k. It stays low only while partitions +/// are key-local (group keys do not span partitions) -- true for index-sorted pre-aggregations. On +/// a group key spread across many partitions each table approaches full cardinality (~N x memory). +fn per_partition_enabled() -> bool { + std::env::var("CUBESTORE_GROUP_BY_LIMIT_PER_PARTITION") + .map(|v| v == "true" || v == "1") + .unwrap_or(false) +} + +/// Peel the leading `CoalescePartitionsExec` chain directly feeding the aggregate, exposing the +/// underlying multi-partition streams. Only the immediate coalesce(s) are removed; any +/// `CoalescePartitionsExec` deeper in the subtree (e.g. one a child inserted to satisfy its own +/// single-partition input requirement, as in a UNION branch) is left intact so plan semantics are +/// preserved. +fn strip_leading_coalesce_partitions(p: &Arc) -> Arc { + if p.as_any().is::() { + return strip_leading_coalesce_partitions(&p.children()[0]); + } + p.clone() +} + /// The group/aggregate state of either an `InlineAggregateExec` or a plain `AggregateExec`, when in /// Final mode. 
struct FinalAggregateInfo { diff --git a/rust/cubestore/cubestore/src/queryplanner/optimizations/group_by_limit_rewriter.rs b/rust/cubestore/cubestore/src/queryplanner/optimizations/group_by_limit_rewriter.rs new file mode 100644 index 0000000000000..0b17807bf873b --- /dev/null +++ b/rust/cubestore/cubestore/src/queryplanner/optimizations/group_by_limit_rewriter.rs @@ -0,0 +1,252 @@ +use crate::queryplanner::group_by_limit_aggregate::GroupByLimitAggregateExec; +use crate::queryplanner::planning::WorkerExec; +use crate::queryplanner::query_executor::ClusterSendExec; +use datafusion::arrow::compute::SortOptions; +use datafusion::error::DataFusionError; +use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode}; +use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion::physical_plan::expressions::Column; +use datafusion::physical_plan::limit::GlobalLimitExec; +use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use datafusion::physical_plan::{ExecutionPlan, InputOrderMode}; +use std::sync::Arc; + +/// Trim the worker-side partial hash aggregate to the top-k groups when the plan is +/// `LIMIT k` over `ORDER BY ` over a distributed hash aggregate. +/// +/// Correctness requires a TOTAL order over groups (`T = ORDER BY ++ remaining group-by columns`, +/// in group-by order) applied in TWO places that must agree: +/// - the worker cut: each worker keeps its local top-k by `T`; +/// - the router select: the global Sort + Limit must also order by `T`. +/// Under `T` the router's top-k equals the global top-k by `T`, and every worker that holds a +/// partial state for such a group keeps it (its local rank can only be smaller), so every needed +/// partial state reaches the router. 
Ordering the router by `T` instead of the bare `ORDER BY` does +/// not change the query contract: `ORDER BY` is a prefix of `T`, so the output stays validly +/// ordered and the previously-unspecified tie order just becomes deterministic. +/// +/// We only rewrite when the plan matches exactly `Sort(/Limit) -> [passthrough] -> Final aggregate +/// -> [passthrough/cluster boundary] -> Partial hash aggregate`; anything else on the path (a +/// HAVING filter, a nested aggregate, a computed projection) makes us bail, so we never trim a plan +/// where the limit does not directly govern this aggregate. +/// +/// `factor` gates trimming at runtime (only when local groups exceed `factor * k`); `0` disables. +pub fn replace_with_group_by_limit_aggregate( + plan: Arc, + factor: usize, +) -> Result, DataFusionError> { + if factor == 0 { + return Ok(plan); + } + let Some(target) = analyze(&plan) else { + return Ok(plan); + }; + apply(plan, &target, factor) +} + +struct Target { + /// The router `SortExec` whose ordering must be extended to the total order. + sort: Arc, + /// The worker-side partial hash `AggregateExec` to replace with a trimming exec. + partial: Arc, + /// Tail of the total order to append to the router sort (over the sort's input schema). + router_tail: Vec, + /// Full total order over the partial output schema for the worker cut. + trim_order: Vec<(usize, SortOptions)>, + /// `k = limit + offset`. + k: usize, +} + +fn analyze(root: &Arc) -> Option { + // Peel an optional top GlobalLimit (carries the offset), then require a SortExec. + let (skip, extra_fetch, sort_node) = + if let Some(gl) = root.as_any().downcast_ref::() { + (gl.skip(), gl.fetch(), child(root)?) + } else { + (0, None, root.clone()) + }; + let sort = sort_node.as_any().downcast_ref::()?; + let order: Vec = sort.expr().iter().cloned().collect(); + if order.is_empty() { + return None; + } + // The worker must keep enough groups to cover `limit + offset`. 
When a top GlobalLimit carries + // the offset, DataFusion already folds `skip + limit` into the sort's fetch, so prefer it; + // otherwise fall back to the GlobalLimit's own `skip + fetch`. + let k = sort + .fetch() + .or_else(|| extra_fetch.map(|fetch| skip + fetch))?; + + // Sort -> [passthrough] -> Final aggregate. + let final_agg_node = descend_to_final_aggregate(sort.input().clone())?; + let final_agg = final_agg_node.as_any().downcast_ref::()?; + + // Final aggregate -> [passthrough/boundary] -> Partial hash aggregate. + let partial_node = descend_to_worker_partial(final_agg.input().clone())?; + let partial = partial_node.as_any().downcast_ref::()?; + if !partial.group_expr().is_single() + || matches!(partial.input_order_mode(), InputOrderMode::Sorted) + { + return None; + } + + let num_group_cols = partial.group_expr().output_exprs().len(); + if num_group_cols == 0 { + return None; + } + let partial_schema = partial.schema(); + let group_names: Vec = partial_schema + .fields() + .iter() + .take(num_group_cols) + .map(|f| f.name().clone()) + .collect(); + + // Map ORDER BY columns onto group-by columns (by name; robust to projections). + let mut used = vec![false; num_group_cols]; + let mut trim_order: Vec<(usize, SortOptions)> = Vec::with_capacity(num_group_cols); + for e in &order { + let column = e.expr.as_any().downcast_ref::()?; + let idx = group_names.iter().position(|n| n == column.name())?; + // A repeated ORDER BY column adds nothing to the total order; skip it. + if used[idx] { + continue; + } + used[idx] = true; + trim_order.push((idx, e.options)); + } + if trim_order.is_empty() { + return None; + } + + // Totalize: append the remaining group-by columns in group-by order. Build the matching tail + // for the router sort over its own (Final-output) schema, resolved by name. 
+ let sort_input_schema = sort.input().schema(); + let mut router_tail: Vec = Vec::new(); + for (idx, is_used) in used.into_iter().enumerate() { + if is_used { + continue; + } + let name = &group_names[idx]; + let options = SortOptions::default(); + let sort_col_idx = sort_input_schema.index_of(name).ok()?; + router_tail.push(PhysicalSortExpr { + expr: Arc::new(Column::new(name, sort_col_idx)), + options, + }); + trim_order.push((idx, options)); + } + + Some(Target { + sort: sort_node, + partial: partial_node, + router_tail, + trim_order, + k, + }) +} + +fn apply( + node: Arc, + target: &Target, + factor: usize, +) -> Result, DataFusionError> { + let is_sort = Arc::ptr_eq(&node, &target.sort); + let is_partial = Arc::ptr_eq(&node, &target.partial); + + let new_children = node + .children() + .into_iter() + .map(|c| apply(c.clone(), target, factor)) + .collect::, _>>()?; + let node = node.with_new_children(new_children)?; + + if is_partial { + if let Some(agg) = node.as_any().downcast_ref::() { + if let Some(exec) = GroupByLimitAggregateExec::try_new_from_partial( + agg, + target.k, + factor, + target.trim_order.clone(), + ) { + return Ok(Arc::new(exec)); + } + } + // Leaving the full aggregate in place stays correct; the router still sorts by the total + // order, it just receives every group instead of the trimmed top-k. + return Ok(node); + } + + if is_sort { + if let Some(sort) = node.as_any().downcast_ref::() { + let mut exprs: Vec = sort.expr().iter().cloned().collect(); + exprs.extend(target.router_tail.iter().cloned()); + let new_sort = SortExec::new(LexOrdering::new(exprs), sort.input().clone()) + .with_preserve_partitioning(sort.preserve_partitioning()) + .with_fetch(sort.fetch()); + return Ok(Arc::new(new_sort)); + } + } + + Ok(node) +} + +/// Walk down single-child passthrough nodes (which preserve rows and grouping) until the first +/// `Final`/`FinalPartitioned` `AggregateExec`. Returns `None` if a non-passthrough node is hit +/// first (e.g. 
a filter or a computed projection). +fn descend_to_final_aggregate(mut node: Arc) -> Option> { + loop { + if let Some(agg) = node.as_any().downcast_ref::() { + return matches!( + agg.mode(), + AggregateMode::Final | AggregateMode::FinalPartitioned + ) + .then_some(node.clone()); + } + if is_row_passthrough(&node) { + node = child(&node)?; + } else { + return None; + } + } +} + +/// Walk down passthrough nodes from a `Final` aggregate's input to the worker-side `Partial` +/// aggregate, requiring that exactly one `ClusterSend`/`Worker` boundary is crossed. Returns `None` +/// if anything unexpected (a second aggregate, a filter, ...) is on the path. +fn descend_to_worker_partial(mut node: Arc) -> Option> { + let mut crossed_boundary = false; + loop { + if let Some(agg) = node.as_any().downcast_ref::() { + return (crossed_boundary && *agg.mode() == AggregateMode::Partial) + .then_some(node.clone()); + } + if node.as_any().is::() || node.as_any().is::() { + crossed_boundary = true; + node = child(&node)?; + } else if is_row_passthrough(&node) { + node = child(&node)?; + } else { + return None; + } + } +} + +/// Single-child nodes that pass rows through unchanged (preserving grouping), so a limit/sort above +/// them governs the aggregate below them. 
+fn is_row_passthrough(node: &Arc) -> bool { + let any = node.as_any(); + any.is::() + || any.is::() + || any.is::() + || any.is::() +} + +fn child(node: &Arc) -> Option> { + let children = node.children(); + if children.len() != 1 { + return None; + } + Some(children[0].clone()) +} diff --git a/rust/cubestore/cubestore/src/queryplanner/optimizations/mod.rs b/rust/cubestore/cubestore/src/queryplanner/optimizations/mod.rs index edf902b44d3e5..2df406a1a62af 100644 --- a/rust/cubestore/cubestore/src/queryplanner/optimizations/mod.rs +++ b/rust/cubestore/cubestore/src/queryplanner/optimizations/mod.rs @@ -1,5 +1,6 @@ mod check_memory; mod distributed_partial_aggregate; +mod group_by_limit_rewriter; mod inline_aggregate_rewriter; pub mod is_not_distinct_from_join_keys; pub mod rewrite_plan; @@ -13,6 +14,7 @@ use crate::queryplanner::optimizations::distributed_partial_aggregate::{ push_aggregate_to_workers, push_sorted_partial_aggregate_below_merge, push_worker_sort_and_limit, replace_suboptimal_merge_sorts, }; +use crate::queryplanner::optimizations::group_by_limit_rewriter::replace_with_group_by_limit_aggregate; use crate::queryplanner::optimizations::inline_aggregate_rewriter::replace_with_inline_aggregate; use crate::queryplanner::planning::CubeExtensionPlanner; use crate::queryplanner::pretty_printers::{pp_phys_plan_ext, PPOptions}; @@ -43,6 +45,7 @@ pub struct CubeQueryPlanner { serialized_plan: Arc, memory_handler: Arc, data_loaded_size: Option>, + group_by_limit_factor: usize, } impl CubeQueryPlanner { @@ -50,6 +53,7 @@ impl CubeQueryPlanner { cluster: Arc, serialized_plan: Arc, memory_handler: Arc, + group_by_limit_factor: usize, ) -> CubeQueryPlanner { CubeQueryPlanner { cluster: Some(cluster), @@ -57,6 +61,7 @@ impl CubeQueryPlanner { serialized_plan, memory_handler, data_loaded_size: None, + group_by_limit_factor, } } @@ -65,6 +70,7 @@ impl CubeQueryPlanner { worker_planning_params: WorkerPlanningParams, memory_handler: Arc, data_loaded_size: Option>, + 
group_by_limit_factor: usize, ) -> CubeQueryPlanner { CubeQueryPlanner { serialized_plan, @@ -72,6 +78,7 @@ impl CubeQueryPlanner { worker_partition_count: Some(worker_planning_params), memory_handler, data_loaded_size, + group_by_limit_factor, } } } @@ -104,6 +111,7 @@ impl QueryPlanner for CubeQueryPlanner { self.memory_handler.clone(), self.data_loaded_size.clone(), ctx_state.config().options(), + self.group_by_limit_factor, ); result } @@ -112,12 +120,14 @@ impl QueryPlanner for CubeQueryPlanner { #[derive(Debug)] pub struct PreOptimizeRule { push_partial_aggregate_below_merge: bool, + group_by_limit_factor: usize, } impl PreOptimizeRule { - pub fn new(push_partial_aggregate_below_merge: bool) -> Self { + pub fn new(push_partial_aggregate_below_merge: bool, group_by_limit_factor: usize) -> Self { Self { push_partial_aggregate_below_merge, + group_by_limit_factor, } } } @@ -128,7 +138,11 @@ impl PhysicalOptimizerRule for PreOptimizeRule { plan: Arc, _config: &ConfigOptions, ) -> datafusion::common::Result> { - pre_optimize_physical_plan(plan, self.push_partial_aggregate_below_merge) + pre_optimize_physical_plan( + plan, + self.push_partial_aggregate_below_merge, + self.group_by_limit_factor, + ) } fn name(&self) -> &str { @@ -143,6 +157,7 @@ impl PhysicalOptimizerRule for PreOptimizeRule { fn pre_optimize_physical_plan( p: Arc, push_partial_aggregate_below_merge: bool, + group_by_limit_factor: usize, ) -> Result, DataFusionError> { let p = rewrite_physical_plan(p, &mut |p| push_aggregate_to_workers(p))?; @@ -164,6 +179,11 @@ fn pre_optimize_physical_plan( // Replace sorted AggregateExec with InlineAggregateExec for better performance let p = rewrite_physical_plan(p, &mut |p| replace_with_inline_aggregate(p))?; + // Trim the worker-side partial hash aggregate to the top-k groups when the query orders by a + // subset of group-by columns and has a limit. Runs after inline-aggregate replacement so it + // only sees the remaining (hash) partial aggregates. 
+ let p = replace_with_group_by_limit_aggregate(p, group_by_limit_factor)?; + Ok(p) } @@ -173,6 +193,7 @@ fn finalize_physical_plan( memory_handler: Arc, data_loaded_size: Option>, config: &ConfigOptions, + group_by_limit_factor: usize, ) -> Result, DataFusionError> { let p = rewrite_physical_plan(p, &mut |p| add_check_memory_exec(p, memory_handler.clone()))?; log::trace!( @@ -201,7 +222,9 @@ fn finalize_physical_plan( // Last: bound worker memory for ORDER BY LIMIT that isn't an index prefix. Runs // after replace_suboptimal_merge_sorts so it doesn't push the query's row limit into the // worker merge we add (which would cut uncombined partial rows and undercount). - let p = rewrite_physical_plan(p, &mut |p| push_worker_sort_and_limit(p))?; + let p = rewrite_physical_plan(p, &mut |p| { + push_worker_sort_and_limit(p, group_by_limit_factor) + })?; log::trace!( "Rewrote physical plan by push_worker_sort_and_limit:\n{}", pp_phys_plan_ext(p.as_ref(), &PPOptions::show_nonmeta()) diff --git a/rust/cubestore/cubestore/src/queryplanner/planning.rs b/rust/cubestore/cubestore/src/queryplanner/planning.rs index 9c9cd7b352df6..c406eb596e10c 100644 --- a/rust/cubestore/cubestore/src/queryplanner/planning.rs +++ b/rust/cubestore/cubestore/src/queryplanner/planning.rs @@ -85,7 +85,7 @@ pub async fn choose_index( p: LogicalPlan, metastore: &dyn PlanIndexStore, ) -> Result<(LogicalPlan, PlanningMeta), DataFusionError> { - choose_index_ext(p, metastore, true).await + choose_index_ext(p, metastore, true, false).await } /// Information required to distribute the logical plan into multiple workers. @@ -123,6 +123,7 @@ pub async fn choose_index_ext( p: LogicalPlan, metastore: &dyn PlanIndexStore, enable_topk: bool, + dictionary_encoding: bool, ) -> Result<(LogicalPlan, PlanningMeta), DataFusionError> { // Prepare information to choose the index. 
let mut collector = CollectConstraints::default(); @@ -238,6 +239,7 @@ pub async fn choose_index_ext( chosen_indices: &indices, next_index: 0, enable_topk, + dictionary_encoding, can_pushdown_limit: true, cluster_send_next_id: 1, }; @@ -831,6 +833,7 @@ struct ChooseIndex<'a> { next_index: usize, chosen_indices: &'a [IndexSnapshot], enable_topk: bool, + dictionary_encoding: bool, can_pushdown_limit: bool, cluster_send_next_id: usize, } @@ -1057,6 +1060,7 @@ impl ChooseIndex<'_> { HashMap::new(), Vec::new(), NoopParquetMetadataCache::new(), + self.dictionary_encoding, )?))); let index_schema = source.schema(); @@ -1143,19 +1147,22 @@ impl ChooseIndex<'_> { return None; } let limit = ctx.limit?; - let sort = ctx.sort.as_ref().filter(|s| !s.is_empty())?; let group_by = ctx.group_by.as_ref().filter(|g| !g.is_empty())?; - // Every ORDER BY column must be a group-by column; map it to its group-key position. + // Every ORDER BY column must be a group-by column; map it to its group-key position. A bare + // LIMIT (no ORDER BY) leaves the prefix empty, so the total order is the full group key in + // group-by order -- "any n groups" becomes "the n smallest by group key", still valid. let mut cols: Vec<(usize, bool, bool)> = Vec::with_capacity(group_by.len()); let mut used = vec![false; group_by.len()]; - for name in sort { - let idx = group_by.iter().position(|g| g == name)?; - if used[idx] { - continue; + if let Some(sort) = ctx.sort.as_ref().filter(|s| !s.is_empty()) { + for name in sort { + let idx = group_by.iter().position(|g| g == name)?; + if used[idx] { + continue; + } + used[idx] = true; + cols.push((idx, ctx.sort_is_asc, !ctx.sort_is_asc)); } - used[idx] = true; - cols.push((idx, ctx.sort_is_asc, !ctx.sort_is_asc)); } // Extend with the remaining group keys to make it a total order on the full group key. 
for (idx, is_used) in used.iter().enumerate() { @@ -1833,6 +1840,7 @@ fn pull_up_cluster_send(mut p: LogicalPlan) -> Result Result Result { let lsend; diff --git a/rust/cubestore/cubestore/src/queryplanner/pretty_printers.rs b/rust/cubestore/cubestore/src/queryplanner/pretty_printers.rs index 63d386add4951..3b50620742de7 100644 --- a/rust/cubestore/cubestore/src/queryplanner/pretty_printers.rs +++ b/rust/cubestore/cubestore/src/queryplanner/pretty_printers.rs @@ -28,6 +28,7 @@ use std::sync::Arc; use crate::queryplanner::check_memory::CheckMemoryExec; use crate::queryplanner::filter_by_key_range::FilterByKeyRangeExec; +use crate::queryplanner::group_by_limit_aggregate::GroupByLimitAggregateExec; use crate::queryplanner::inline_aggregate::{InlineAggregateExec, InlineAggregateMode}; use crate::queryplanner::merge_sort::LastRowByUniqueKeyExec; use crate::queryplanner::panic::{PanicWorkerExec, PanicWorkerNode}; @@ -617,6 +618,16 @@ fn pp_phys_plan_indented(p: &dyn ExecutionPlan, indent: usize, o: &PPOptions, ou if let Some(limit) = agg.limit() { *out += &format!(", limit: {}", limit) } + } else if let Some(agg) = a.downcast_ref::() { + *out += &format!( + "GroupByLimitAggregate, k: {}, factor: {}, order: {:?}", + agg.k(), + agg.factor(), + agg.order() + ); + if o.show_aggregations { + *out += &format!(", aggs: {:?}", agg.aggr_expr()) + } } else if let Some(l) = a.downcast_ref::() { *out += &format!("LocalLimit, n: {}", l.fetch()); } else if let Some(l) = a.downcast_ref::() { diff --git a/rust/cubestore/cubestore/src/queryplanner/query_executor.rs b/rust/cubestore/cubestore/src/queryplanner/query_executor.rs index 528f268f7048f..093b0a95fbaf2 100644 --- a/rust/cubestore/cubestore/src/queryplanner/query_executor.rs +++ b/rust/cubestore/cubestore/src/queryplanner/query_executor.rs @@ -29,7 +29,7 @@ use datafusion::arrow::array::{ Float64Array, Int16Array, Int32Array, Int64Array, MutableArrayData, NullArray, StringArray, TimestampMicrosecondArray, TimestampNanosecondArray, 
UInt16Array, UInt32Array, UInt64Array, }; -use datafusion::arrow::compute::{filter_record_batch, SortOptions}; +use datafusion::arrow::compute::{cast, filter_record_batch, SortOptions}; use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; use datafusion::arrow::ipc::reader::StreamReader; use datafusion::arrow::ipc::writer::StreamWriter; @@ -39,7 +39,9 @@ use datafusion::common::ToDFSchema; use datafusion::config::TableParquetOptions; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::object_store::ObjectStoreUrl; -use datafusion::datasource::physical_plan::parquet::get_reader_options_customizer; +use datafusion::datasource::physical_plan::parquet::{ + get_reader_options_customizer, ReaderOptionsCustomizer, +}; use datafusion::datasource::physical_plan::{ FileScanConfig, ParquetFileReaderFactory, ParquetSource, }; @@ -50,6 +52,7 @@ use datafusion::execution::memory_pool::{MemoryPool, MemoryReservation}; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::execution::TaskContext; use datafusion::logical_expr::{Expr, LogicalPlan}; +use datafusion::parquet::arrow::arrow_reader::ArrowReaderOptions; use datafusion::physical_expr; use datafusion::physical_expr::LexOrdering; use datafusion::physical_expr::{ @@ -226,7 +229,7 @@ pub struct QueryExecutorImpl { metadata_cache_factory: Arc, parquet_metadata_cache: Arc, memory_handler: Arc, - push_partial_aggregate_below_merge: bool, + config: Arc, } crate::di_service!(QueryExecutorImpl, [QueryExecutor]); @@ -547,13 +550,13 @@ impl QueryExecutorImpl { metadata_cache_factory: Arc, parquet_metadata_cache: Arc, memory_handler: Arc, - push_partial_aggregate_below_merge: bool, + config: Arc, ) -> Arc { Arc::new(QueryExecutorImpl { metadata_cache_factory, parquet_metadata_cache, memory_handler, - push_partial_aggregate_below_merge, + config, }) } @@ -567,6 +570,7 @@ impl QueryExecutorImpl { cluster, serialized_plan, self.memory_handler.clone(), + 
self.config.group_by_limit_factor(), )) } @@ -582,6 +586,7 @@ impl QueryExecutorImpl { worker_planning_params, self.memory_handler.clone(), data_loaded_size.clone(), + self.config.group_by_limit_factor(), )) } @@ -603,7 +608,8 @@ impl QueryExecutorImpl { vec![ // Cube rules Arc::new(PreOptimizeRule::new( - self.push_partial_aggregate_below_merge, + self.config.push_partial_aggregate_below_merge_enabled(), + self.config.group_by_limit_factor(), )), // DF rules without EnforceDistribution. We do need to keep EnforceSorting. Arc::new(OutputRequirements::new_add_mode()), @@ -650,6 +656,26 @@ impl QueryExecutorImpl { } } +/// Forces the parquet reader to produce arrays of the supplied schema directly, so +/// `Dictionary(Int32, Utf8)` string columns are read natively from the on-disk dictionary pages +/// instead of being materialized as `Utf8` and cast to dictionary by the schema adapter. +#[derive(Debug)] +struct SuppliedSchemaReaderCustomizer { + schema: SchemaRef, + inner: Arc, +} + +impl ReaderOptionsCustomizer for SuppliedSchemaReaderCustomizer { + fn adjust_reader_options( + &self, + options: ArrowReaderOptions, + ) -> Result { + // Compose with the configured customizer so its adjustments are kept, then pin the schema. 
+ let options = self.inner.adjust_reader_options(options)?; + Ok(options.with_schema(Arc::clone(&self.schema))) + } +} + #[derive(Clone, Serialize, Deserialize)] pub struct CubeTable { index_snapshot: IndexSnapshot, @@ -679,6 +705,7 @@ impl CubeTable { remote_to_local_names: HashMap, worker_partition_ids: Vec<(u64, RowFilter)>, parquet_metadata_cache: Arc, + dictionary_encoding: bool, ) -> Result { let schema = Arc::new(Schema::new( // Tables are always exposed only using table columns order instead of index one because @@ -690,7 +717,7 @@ impl CubeTable { .get_row() .get_columns() .iter() - .map(|c| c.clone().into()) + .map(|c| c.as_arrow_field(dictionary_encoding)) .collect::>(), )); Ok(Self { @@ -830,6 +857,22 @@ impl CubeTable { None }; + // With dictionary encoding the index schema carries `Dictionary` string columns; supply it to + // the parquet reader so they are read natively from the dictionary pages instead of being + // materialized as `Utf8` and cast to dictionary per batch. 
+ let reader_options_customizer: Arc = if index_schema + .fields() + .iter() + .any(|f| matches!(f.data_type(), DataType::Dictionary(_, _))) + { + Arc::new(SuppliedSchemaReaderCustomizer { + schema: index_schema.clone(), + inner: get_reader_options_customizer(state.config()), + }) + } else { + get_reader_options_customizer(state.config()) + }; + let unique_key_columns = self .index_snapshot .table_path @@ -879,9 +922,8 @@ impl CubeTable { let mut options = TableParquetOptions::new(); options.global = state.config_options().execution.parquet.clone(); - let parquet_source = - ParquetSource::new(options, get_reader_options_customizer(state.config())) - .with_parquet_file_reader_factory(self.parquet_metadata_cache.clone()); + let parquet_source = ParquetSource::new(options, reader_options_customizer.clone()) + .with_parquet_file_reader_factory(self.parquet_metadata_cache.clone()); let parquet_source = if let Some(phys_pred) = &physical_predicate { parquet_source.with_predicate(index_schema.clone(), phys_pred.clone()) } else { @@ -928,18 +970,21 @@ impl CubeTable { "Record batch for in memory chunk {:?} is not provided", chunk )))?; - if let Some(batch) = record_batches.iter().next() { - if batch.schema() != index_schema { - return Err(CubeError::internal(format!( - "Index schema {:?} and in memory chunk schema {:?} mismatch", - index_schema, - record_batches[0].schema() - ))); - } - } + // In-memory chunks are written with plain column types. When dictionary encoding + // is on the index schema exposes string columns as Dictionary, so cast the chunk + // batches to the index schema (Utf8 -> Dictionary) instead of rejecting them -- + // this keeps the memory scan consistent with the dictionary parquet partitions + // and feeds the dict-aware aggregate. 
+ let record_batches = match record_batches.iter().next() { + Some(batch) if batch.schema() != index_schema => record_batches + .iter() + .map(|b| cast_record_batch_to_schema(b, &index_schema)) + .collect::, _>>()?, + _ => record_batches.clone(), + }; Arc::new(DataSourceExec::new(Arc::new( MemorySourceConfig::try_new( - &[record_batches.clone()], + &[record_batches], index_schema.clone(), index_projection_or_none_on_schema_match.clone(), )? @@ -960,7 +1005,7 @@ impl CubeTable { let mut options = TableParquetOptions::new(); options.global = state.config_options().execution.parquet.clone(); let parquet_source = - ParquetSource::new(options, get_reader_options_customizer(state.config())) + ParquetSource::new(options, reader_options_customizer.clone()) .with_parquet_file_reader_factory(self.parquet_metadata_cache.clone()); let parquet_source = if let Some(phys_pred) = &physical_predicate { parquet_source.with_predicate(index_schema.clone(), phys_pred.clone()) @@ -2148,11 +2193,68 @@ macro_rules! convert_array { }}; } +/// Cast a record batch's columns to `schema`'s field types. Used to bring in-memory chunk batches +/// (written with plain types) up to the index schema, whose string columns are `Dictionary` when +/// dictionary encoding is on. +fn cast_record_batch_to_schema( + batch: &RecordBatch, + schema: &SchemaRef, +) -> Result { + // Columns are cast positionally, so the batch must align with the schema. + debug_assert_eq!(batch.num_columns(), schema.fields().len()); + let mut columns = Vec::with_capacity(batch.num_columns()); + for (column, field) in batch.columns().iter().zip(schema.fields()) { + if column.data_type() == field.data_type() { + columns.push(Arc::clone(column)); + } else { + columns.push(cast(column, field.data_type())?); + } + } + Ok(RecordBatch::try_new(Arc::clone(schema), columns)?) +} + +/// Cast any `Dictionary` columns of a batch to their value type, leaving other columns untouched. 
+fn decode_dictionary_columns(batch: RecordBatch) -> Result { + let schema = batch.schema(); + if !schema + .fields() + .iter() + .any(|f| matches!(f.data_type(), DataType::Dictionary(_, _))) + { + return Ok(batch); + } + let mut fields = Vec::with_capacity(schema.fields().len()); + let mut columns = Vec::with_capacity(batch.num_columns()); + for (field, column) in schema.fields().iter().zip(batch.columns()) { + match field.data_type() { + DataType::Dictionary(_, value_type) => { + columns.push(cast(column, value_type)?); + fields.push(Arc::new(Field::new( + field.name(), + value_type.as_ref().clone(), + field.is_nullable(), + ))); + } + _ => { + columns.push(column.clone()); + fields.push(field.clone()); + } + } + } + Ok(RecordBatch::try_new( + Arc::new(Schema::new(fields)), + columns, + )?) +} + pub fn batches_to_dataframe(batches: Vec) -> Result { let mut cols = vec![]; let mut all_rows = vec![]; for batch in batches.into_iter() { + // Dictionary-encoded columns (string group keys) reach the result as `Dictionary(_, _)`; + // decode them to their value type for the row conversion below. + let batch = decode_dictionary_columns(batch)?; if cols.len() == 0 { for (i, field) in batch.schema().fields().iter().enumerate() { cols.push(Column::new( @@ -2523,7 +2625,75 @@ fn slice_copy(a: &dyn Array, start: usize, len: usize) -> ArrayRef { #[cfg(test)] mod tests { use super::*; - use datafusion::arrow::datatypes::Field; + use datafusion::arrow::array::DictionaryArray; + use datafusion::arrow::datatypes::{Field, Int32Type}; + + #[test] + fn test_cast_record_batch_to_dictionary_schema() -> Result<(), CubeError> { + // An in-memory chunk batch (plain Utf8) cast up to a dictionary-encoded index schema: the + // string column becomes Dictionary(Int32, Utf8) preserving values and nulls; other columns + // are untouched. 
+ let src = Arc::new(Schema::new(vec![ + Field::new("s", DataType::Utf8, true), + Field::new("n", DataType::Int64, true), + ])); + let batch = RecordBatch::try_new( + src, + vec![ + Arc::new(StringArray::from(vec![ + Some("b"), + Some("a"), + None, + Some("b"), + ])) as ArrayRef, + Arc::new(Int64Array::from(vec![1, 2, 3, 4])) as ArrayRef, + ], + )?; + let target = Arc::new(Schema::new(vec![ + Field::new( + "s", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), + Field::new("n", DataType::Int64, true), + ])); + + let out = cast_record_batch_to_schema(&batch, &target)?; + assert_eq!(out.schema(), target); + let dict = out + .column(0) + .as_any() + .downcast_ref::>() + .unwrap(); + let vals = dict + .values() + .as_any() + .downcast_ref::() + .unwrap(); + let got: Vec> = dict + .keys() + .iter() + .map(|k| k.map(|k| vals.value(k as usize).to_string())) + .collect(); + assert_eq!( + got, + vec![ + Some("b".to_string()), + Some("a".to_string()), + None, + Some("b".to_string()) + ] + ); + assert_eq!( + out.column(1) + .as_any() + .downcast_ref::() + .unwrap() + .values(), + &[1, 2, 3, 4] + ); + Ok(()) + } #[test] fn test_batch_to_dataframe() -> Result<(), CubeError> { diff --git a/rust/cubestore/cubestore/src/queryplanner/topk/util.rs b/rust/cubestore/cubestore/src/queryplanner/topk/util.rs index ed84d9a524e22..6a176713ebe73 100644 --- a/rust/cubestore/cubestore/src/queryplanner/topk/util.rs +++ b/rust/cubestore/cubestore/src/queryplanner/topk/util.rs @@ -1,4 +1,5 @@ -use datafusion::arrow::array::ArrayBuilder; +use datafusion::arrow::array::{ArrayBuilder, StringDictionaryBuilder}; +use datafusion::arrow::datatypes::{DataType, Int32Type}; use datafusion::error::DataFusionError; use datafusion::scalar::ScalarValue; @@ -50,8 +51,42 @@ macro_rules! 
cube_match_scalar { }}; } +/// Dictionary group keys (CubeStore dictionary encoding produces `Dictionary(Int32, Utf8)`) are not +/// handled by [cube_match_scalar], which works on plain scalar variants. Build/append into a string +/// dictionary builder so the top-k result columns keep the dictionary type of the schema. +fn create_dictionary_builder(key_type: &DataType, value: &ScalarValue) -> Box { + match (key_type, value) { + (DataType::Int32, ScalarValue::Utf8(_)) => { + Box::new(StringDictionaryBuilder::::new()) + } + _ => panic!( + "Unhandled dictionary topk type: key={:?} value={:?}", + key_type, value + ), + } +} + +fn append_dictionary_value( + b: &mut dyn ArrayBuilder, + value: &ScalarValue, +) -> Result<(), DataFusionError> { + let b = b + .as_any_mut() + .downcast_mut::>() + .expect("expected StringDictionaryBuilder"); + match value { + ScalarValue::Utf8(None) => b.append_null(), + ScalarValue::Utf8(Some(v)) => b.append_value(v), + other => panic!("Unhandled dictionary topk value: {:?}", other), + } + Ok(()) +} + #[allow(unused_variables)] pub fn create_builder(s: &ScalarValue) -> Box { + if let ScalarValue::Dictionary(key_type, value) = s { + return create_dictionary_builder(key_type, value); + } macro_rules! create_list_builder { ($v: expr, $inner_data_type: expr, ListBuilder $(, $rest: tt)*) => {{ panic!("nested lists not supported") @@ -84,6 +119,9 @@ pub(crate) fn append_value( b: &mut dyn ArrayBuilder, v: &ScalarValue, ) -> Result<(), DataFusionError> { + if let ScalarValue::Dictionary(_key_type, value) = v { + return append_dictionary_value(b, value); + } let b = b.as_any_mut(); macro_rules! 
append_list_value { ($list: expr, $dummy: expr, $inner_data_type: expr, ListBuilder $(, $rest: tt)*) => {{ diff --git a/rust/cubestore/cubestore/src/sql/mod.rs b/rust/cubestore/cubestore/src/sql/mod.rs index 472465ade9279..4a444123b6501 100644 --- a/rust/cubestore/cubestore/src/sql/mod.rs +++ b/rust/cubestore/cubestore/src/sql/mod.rs @@ -6487,7 +6487,9 @@ mod tests { } } - // Test 4: ORDER BY 1 DESC with LIMIT on non-prefix column + // Test 4: ORDER BY DESC + LIMIT on a non-prefix column, grouped by a + // non-prefix column (hash aggregate). The hash path bounds the worker + // output with the trimming aggregate, not a Sort. { let result = service .exec_query( @@ -6504,8 +6506,9 @@ mod tests { _ => panic!("expected string"), }; assert!( - worker_plan.contains("Sort, fetch: 2"), - "Worker should have Sort with fetch=2 for DESC. Plan: {}", + worker_plan.contains("GroupByLimitAggregate"), + "Hash-aggregate worker should bound output with \ + GroupByLimitAggregate. Plan: {}", worker_plan ); } diff --git a/rust/cubestore/cubestore/src/store/compaction.rs b/rust/cubestore/cubestore/src/store/compaction.rs index 9576f2ed0cdec..fc138a4a76527 100644 --- a/rust/cubestore/cubestore/src/store/compaction.rs +++ b/rust/cubestore/cubestore/src/store/compaction.rs @@ -2785,6 +2785,247 @@ mod tests { .await; } + #[tokio::test] + async fn dictionary_encoding_native_read_group_by() { + Config::test("dictionary_encoding_native_read_group_by") + .update_config(|mut c| { + c.dictionary_encoding_enabled = true; + c + }) + .start_test(async move |services| { + let service = services.sql_service; + let compaction_service = services + .injector + .get_service_typed::() + .await; + service + .exec_query("CREATE SCHEMA d") + .await + .unwrap() + .collect() + .await + .unwrap(); + // A string group key alongside Timestamp and Decimal columns: with dictionary encoding + // the parquet reader is handed the dictionary schema, whose strict per-field check must + // accept every column type (not 
just the dictionary one). + service + .exec_query("CREATE TABLE d.t (s text, ts timestamp, dec decimal(10,2), n int)") + .await + .unwrap() + .collect() + .await + .unwrap(); + service + .exec_query( + "INSERT INTO d.t (s, ts, dec, n) VALUES \ + ('b', '2020-01-01T00:00:00.000Z', 1.50, 10), \ + ('a', '2020-01-02T00:00:00.000Z', 2.50, 5), \ + ('b', '2020-01-03T00:00:00.000Z', 3.50, 7), \ + ('c', '2020-01-04T00:00:00.000Z', 4.50, 1), \ + ('a', '2020-01-05T00:00:00.000Z', 2.00, 2)", + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + // Persist the in-memory chunk into the partition's parquet main table so the scan goes + // through the native dictionary reader rather than the in-memory chunk. + compaction_service + .compact(1, DataLoadedSize::new()) + .await + .unwrap(); + let partitions = services + .meta_store + .get_active_partitions_by_index_id(1) + .await + .unwrap(); + assert_eq!(partitions.len(), 1); + assert_eq!(partitions[0].get_row().main_table_row_count(), 5); + + let result = service + .exec_query("SELECT s, sum(n) FROM d.t GROUP BY 1 ORDER BY 1") + .await + .unwrap() + .collect() + .await + .unwrap(); + assert_eq!( + result.get_rows(), + &vec![ + Row::new(vec![ + TableValue::String("a".to_string()), + TableValue::Int(7) + ]), + Row::new(vec![ + TableValue::String("b".to_string()), + TableValue::Int(17) + ]), + Row::new(vec![ + TableValue::String("c".to_string()), + TableValue::Int(1) + ]), + ] + ); + Ok::<(), CubeError>(()) + }) + .await; + } + + #[tokio::test] + async fn dictionary_encoding_topk_by_aggregate() { + // ORDER BY DESC LIMIT goes through the ClusterAggregateTopK path, whose scalar + // helpers (cube_match_scalar) had no dictionary arm and panicked when the group key was + // dictionary-encoded. The top-k result must materialize the dictionary group key correctly. 
+ Config::test("dictionary_encoding_topk_by_aggregate") + .update_config(|mut c| { + c.dictionary_encoding_enabled = true; + c + }) + .start_test(async move |services| { + let service = services.sql_service; + let compaction_service = services + .injector + .get_service_typed::() + .await; + service + .exec_query("CREATE SCHEMA d") + .await + .unwrap() + .collect() + .await + .unwrap(); + service + .exec_query("CREATE TABLE d.t (s text, n int)") + .await + .unwrap() + .collect() + .await + .unwrap(); + // Includes a NULL group key landing in the top-k, to exercise the dictionary + // builder's null path in the result. + service + .exec_query( + "INSERT INTO d.t (s, n) VALUES \ + ('a', 5), ('b', 10), ('b', 7), ('c', 1), ('a', 2), ('d', 100), (NULL, 50)", + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + compaction_service + .compact(1, DataLoadedSize::new()) + .await + .unwrap(); + + // groups: a=7, b=17, c=1, d=100, NULL=50 -> top 3 by sum desc: d=100, NULL=50, b=17 + let result = service + .exec_query("SELECT s, sum(n) FROM d.t GROUP BY 1 ORDER BY 2 DESC LIMIT 3") + .await + .unwrap() + .collect() + .await + .unwrap(); + assert_eq!( + result.get_rows(), + &vec![ + Row::new(vec![ + TableValue::String("d".to_string()), + TableValue::Int(100) + ]), + Row::new(vec![TableValue::Null, TableValue::Int(50)]), + Row::new(vec![ + TableValue::String("b".to_string()), + TableValue::Int(17) + ]), + ] + ); + Ok::<(), CubeError>(()) + }) + .await; + } + + #[tokio::test] + async fn dictionary_encoding_in_memory_group_by() { + Config::test("dictionary_encoding_in_memory_group_by") + .update_config(|mut c| { + c.dictionary_encoding_enabled = true; + c + }) + .start_test(async move |services| { + let service = services.sql_service; + service + .exec_query("CREATE SCHEMA d") + .await + .unwrap() + .collect() + .await + .unwrap(); + // A unique-key table routes inserts through the streaming in-memory chunk path, so + // the scan reads them via the memory source (not 
parquet). `s` is a non-key string + // dimension we group by. + service + .exec_query("CREATE TABLE d.t (id int, s text, n int) unique key (id)") + .await + .unwrap() + .collect() + .await + .unwrap(); + service + .exec_query( + "INSERT INTO d.t (id, s, n, __seq) VALUES \ + (1, 'b', 10, 1), (2, 'a', 5, 2), (3, 'b', 7, 3), \ + (4, 'c', 1, 4), (5, 'a', 2, 5)", + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + // No compaction: the data stays in an in-memory chunk (Utf8), while the dictionary + // index schema exposes `s` as Dictionary. The memory scan must cast it instead of + // rejecting the mismatch. + let chunks = services + .meta_store + .get_chunks_by_partition(1, false) + .await + .unwrap(); + assert!( + !chunks.is_empty() && chunks.iter().all(|c| c.get_row().in_memory()), + "expected in-memory chunks, got: {:?}", + chunks + ); + + let result = service + .exec_query("SELECT s, sum(n) FROM d.t GROUP BY 1 ORDER BY 1") + .await + .unwrap() + .collect() + .await + .unwrap(); + assert_eq!( + result.get_rows(), + &vec![ + Row::new(vec![ + TableValue::String("a".to_string()), + TableValue::Int(7) + ]), + Row::new(vec![ + TableValue::String("b".to_string()), + TableValue::Int(17) + ]), + Row::new(vec![ + TableValue::String("c".to_string()), + TableValue::Int(1) + ]), + ] + ); + Ok::<(), CubeError>(()) + }) + .await; + } + #[tokio::test] async fn compaction_wide_string_batches() { // Each chunk is read as a single sorted run whose batches keep their on-disk row-group