diff --git a/docs/measures-sql-paper-parity.md b/docs/measures-sql-paper-parity.md new file mode 100644 index 0000000..d27de94 --- /dev/null +++ b/docs/measures-sql-paper-parity.md @@ -0,0 +1,46 @@ +# Measures in SQL Paper Parity Matrix + +Source paper: `arXiv:2406.00251v2` ("Measures in SQL", Jan 10, 2025), local copy at `.context/measures_in_sql.txt`. + +## Scope + +This matrix tracks parity for the core language semantics described in sections 3-5 of the paper. + +- `Covered`: behavior is validated by existing automated tests. +- `Partial`: behavior is exercised, but not with an explicit parity assertion. +- `Gap`: no direct automated validation yet. + +## Matrix + +| Paper ref | Requirement | Status | Evidence | +|---|---|---|---| +| §3.2, Listing 3 | `AS MEASURE` in `CREATE VIEW`; no `GROUP BY` keeps base row cardinality | Covered | `test/sql/measures.test:20`, `test/sql/measures.test:1409` | +| §3.3 | `AGGREGATE(measure)` expansion/evaluation in grouped queries | Covered | `test/sql/measures.test:30`, `test/sql/measures.test:459` | +| Table 3 (`ALL`) | `AT (ALL)` removes all filters (grand total) | Covered | `test/sql/measures.test:152`, `test/sql/measures.test:444` | +| Table 3 (`ALL dim`) | `AT (ALL dim)` removes one dimension from context | Covered | `test/sql/measures.test:83`, `test/sql/measures.test:292` | +| Table 3 (`ALL dim1 dim2`) | single-clause multi-dimension `ALL` semantics | Covered | `test/sql/measures.test:1443` | +| Table 3 (modifier sequence) | chained modifiers execute right-to-left | Covered | `test/sql/measures.test:233`, `test/sql/measures.test:548` | +| Table 3 (`SET`) | `AT (SET dim = expr)` changes one dimension, correlates on others | Covered | `test/sql/measures.test:189`, `test/sql/measures.test:975` | +| Table 3 (`SET` + lost rows) | `SET` can reach rows removed by outer `WHERE` | Covered | `test/sql/measures.test:962` | +| Table 3 (`CURRENT`) | `CURRENT` resolves from single-valued context and returns `NULL` otherwise | Covered 
| `test/sql/measures.test:1665`, `test/sql/measures.test:1675` | +| Table 3 (`WHERE`) | `AT (WHERE predicate)` sets evaluation predicate | Covered | `test/sql/measures.test:165`, `test/sql/measures.test:341` | +| Table 3 (`WHERE`) | qualified refs and nested function predicates in `AT (WHERE ...)` | Covered | `test/sql/measures.test:178`, `test/sql/measures.test:1487` | +| Table 3 (`VISIBLE`) | `AT (VISIBLE)` respects current query visibility | Covered | `test/sql/measures.test:218`, `test/sql/measures.test:329` | +| §3.5 (ad hoc dims) | expression dimensions in `ALL`/`SET` (`MONTH(order_date)`, etc.) | Covered | `test/sql/measures.test:818`, `test/sql/measures.test:824`, `test/sql/measures.test:834` | +| Listing 8 | rollup query with `AGGREGATE`, plain measure ref, and `AT (VISIBLE)` | Covered | `test/sql/measures.test:1530` | +| Listing 9 | joins: weighted aggregate vs measure semantics vs `VISIBLE` | Covered | `test/sql/measures.test:1582`, `test/sql/measures.test:1458` | +| Listing 12 (queries 1-4) | correlated subquery, self-join, window, and measure forms return same rows | Covered | `test/sql/measures.test:1614`, `test/sql/measures.test:1624`, `test/sql/measures.test:1637`, `test/sql/measures.test:1652` | +| §5.1 claim | `AT` can access rows excluded by outer `WHERE` (more expressive than `OVER`) | Covered | `test/sql/measures.test:962` | +| §5.4 composability | derived measures referencing measures in same `SELECT` | Covered | `test/sql/measures.test:772`, `test/sql/measures.test:1499` | +| §5.3 wide-table safety direction | joins with measures avoid double counting in tested cases | Partial | `test/sql/measures.test:889`, `test/sql/measures.test:1473` | +| §5.5 security model | measure views preserve SQL security boundaries | Gap | no privilege-based test in suite | +| §3.4 call-site breadth | explicit use in `HAVING` parity path | Covered | `test/sql/measures.test:1548` | + +## Current Verdict + +- Core semantics used by the paper’s main language examples 
are covered, including listings `8`, `9`, and `12` (all four forms), `CURRENT`, rollup behavior, and modifier semantics. +- A strict "100% paper parity" claim is still not justified because of remaining `Gap` items above. + +## Minimal Remaining Work for a 100% Claim + +1. Add a security-behavior test plan (or explicit out-of-scope declaration if privileges are not testable in this harness). diff --git a/src/yardstick_extension.cpp b/src/yardstick_extension.cpp index b89b10d..059b69d 100644 --- a/src/yardstick_extension.cpp +++ b/src/yardstick_extension.cpp @@ -453,6 +453,52 @@ BoundStatement yardstick_bind(ClientContext &context, Binder &binder, } throw BinderException("Registered state not found"); } + + // Non-yardstick extension statements should not be rewritten by yardstick. + return {}; + } + case StatementType::SELECT_STATEMENT: { + auto sql_to_check = context.GetCurrentQuery(); + + if (yardstick_has_aggregate(sql_to_check.c_str())) { + YardstickAggregateResult result = yardstick_expand_aggregate(sql_to_check.c_str()); + if (result.error) { + string error_msg(result.error); + yardstick_free_aggregate_result(result); + throw BinderException("Failed to expand AGGREGATE: %s", error_msg); + } + + if (result.had_aggregate) { + string expanded_sql(result.expanded_sql); + yardstick_free_aggregate_result(result); + + // Escape single quotes for embedding in string literal + string escaped_sql; + for (char c : expanded_sql) { + if (c == '\'') { + escaped_sql += "''"; + } else { + escaped_sql += c; + } + } + + // Rebind through table function so rewritten SQL executes with normal planning + string wrapper_sql = "SELECT * FROM yardstick('" + escaped_sql + "')"; + Parser parser; + parser.ParseQuery(wrapper_sql); + auto statements = std::move(parser.statements); + + if (statements.empty()) { + throw BinderException("Table function wrapper produced no statements"); + } + + auto yardstick_binder = Binder::CreateBinder(context); + return 
yardstick_binder->Bind(*statements[0]); + } + + yardstick_free_aggregate_result(result); + } + return {}; } default: return {}; diff --git a/test/sql/measures.test b/test/sql/measures.test index 03728b6..a1c743e 100644 --- a/test/sql/measures.test +++ b/test/sql/measures.test @@ -79,6 +79,27 @@ GROUP BY year, region; 2023 EU 225.0 2023 US 225.0 +# Positional GROUP BY should resolve to SELECT dimensions (fallback from parser ordinals) +query IR rowsort +SEMANTIC SELECT year, AGGREGATE(revenue) +FROM sales_v +GROUP BY 1 +ORDER BY 1; +---- +2022 150.0 +2023 225.0 + +query IIR rowsort +SEMANTIC SELECT year, region, AGGREGATE(revenue) +FROM sales_v +GROUP BY 1, 2 +ORDER BY 1, 2; +---- +2022 EU 50.0 +2022 US 100.0 +2023 EU 75.0 +2023 US 150.0 + # ============================================================================= # Test: AT (ALL dimension) - remove dimension from context # ============================================================================= @@ -509,21 +530,21 @@ SEMANTIC SELECT year, region, AGGREGATE(min_sale), AGGREGATE(max_sale) FROM sale # ============================================================================= # Test: COUNT measure -# Note: COUNT(*) in a measure counts rows after the view's GROUP BY, not base table rows +# Note: COUNT(*) in a measure is evaluated against the base relation and re-aggregated # ============================================================================= query II rowsort SEMANTIC SELECT year, AGGREGATE(order_count) FROM orders_v; ---- -2022 1 -2023 1 +2022 2 +2023 2 # Count with AT (ALL) query II rowsort SEMANTIC SELECT year, AGGREGATE(order_count) AT (ALL) FROM orders_v; ---- -2022 2 -2023 2 +2022 4 +2023 4 # ============================================================================= # Test: Combining base AGGREGATE with multiple AT variants @@ -1401,3 +1422,281 @@ query I SEMANTIC SELECT AGGREGATE(revenue) FROM sales_nulls_v; ---- 390.0 + +# 
============================================================================= +# Paper regression matrix (consolidated) +# ============================================================================= + +# Case 1: AS MEASURE view cardinality should match base relation +statement ok +CREATE TABLE orders_cardinality (order_date DATE, prod TEXT, revenue INT, cost INT); + +statement ok +INSERT INTO orders_cardinality VALUES + ('2024-01-01', 'A', 100, 40), + ('2024-01-01', 'A', 50, 10); + +statement ok +CREATE VIEW orders_cardinality_v AS +SELECT + order_date, + prod, + (SUM(revenue) - SUM(cost))::DOUBLE / SUM(revenue) AS MEASURE profit_margin +FROM orders_cardinality; + +query II +SELECT + (SELECT COUNT(*) FROM orders_cardinality) AS base_rows, + (SELECT COUNT(*) FROM orders_cardinality_v) AS view_rows; +---- +2 2 + +# Case 2: implicit measure reference in grouped SELECT +query IR rowsort +SELECT year, revenue +FROM sales_v +GROUP BY year +ORDER BY year; +---- +2022 150.0 +2023 225.0 + +# Case 3: AT (ALL d1 d2) in a single modifier clause +query IIRR rowsort +SELECT + year, + region, + AGGREGATE(revenue) AT (ALL year region) AS single_all, + AGGREGATE(revenue) AT (ALL year) AT (ALL region) AS chained_all +FROM sales_v +ORDER BY year, region; +---- +2022 EU 375.0 375.0 +2022 US 375.0 375.0 +2023 EU 375.0 375.0 +2023 US 375.0 375.0 + +# Case 4: qualified measure in join +statement ok +CREATE TABLE customers_qualified (cust_id INT, age INT); + +statement ok +INSERT INTO customers_qualified VALUES (1, 20), (2, 40); + +statement ok +CREATE TABLE orders_qualified (cust_id INT, product TEXT); + +statement ok +INSERT INTO orders_qualified VALUES (1, 'X'), (1, 'X'), (2, 'X'); + +statement ok +CREATE VIEW customers_qualified_v AS +SELECT cust_id, AVG(age) AS MEASURE avg_age +FROM customers_qualified; + +query TRR +SELECT + o.product, + AGGREGATE(c.avg_age) AS measure_avg, + AGGREGATE(c.avg_age) AT (VISIBLE) AS visible_avg +FROM orders_qualified o +JOIN customers_qualified_v c ON 
o.cust_id = c.cust_id +GROUP BY o.product; +---- +X 30.0 30.0 + +# Case 5: nested function parentheses in AT (WHERE ...) +query IR rowsort +SELECT + YEAR(sale_date) AS y, + AGGREGATE(revenue) AT (WHERE YEAR(sale_date) = 2023) AS y2023 +FROM dated_sales_v +GROUP BY YEAR(sale_date) +ORDER BY y; +---- +2023 530.0 +2024 530.0 + +# Case 6: derived measure with AT (SET ...) +query IRR rowsort +SEMANTIC SELECT + year, + AGGREGATE(profit) AS current_profit, + AGGREGATE(profit) AT (SET year = year - 1) AS prior_profit +FROM financials_v +ORDER BY year; +---- +2022 110.0 NULL +2023 230.0 110.0 + +# ============================================================================= +# Paper parity: listings 8, 9, 12 and CURRENT null semantics +# ============================================================================= + +statement ok +CREATE TABLE paper_orders (prodName TEXT, custName TEXT, order_date DATE, revenue INT); + +statement ok +INSERT INTO paper_orders VALUES + ('Happy', 'Var Bob', '2024-01-01', 4), + ('Happy', 'Alice', '2024-01-02', 6), + ('Happy', 'Alice', '2024-01-03', 7), + ('Whizz', 'Alice', '2024-01-04', 3); + +statement ok +CREATE VIEW paper_orders_v AS +SELECT *, SUM(revenue) AS MEASURE sumRevenue +FROM paper_orders; + +# Listing 8 style: AGGREGATE() is visible, plain measure ref ignores WHERE, +# and measure AT(VISIBLE) is accepted without AGGREGATE wrapper. +query TIRRR rowsort +SELECT + o.prodName, + COUNT(*) AS c, + AGGREGATE(o.sumRevenue) AS rAgg, + o.sumRevenue AT (VISIBLE) AS rViz, + o.sumRevenue AS r +FROM paper_orders_v o +WHERE o.custName <> 'Var Bob' +GROUP BY ROLLUP(o.prodName) +ORDER BY o.prodName; +---- +Happy 2 13 13 17 +NULL 3 NULL NULL NULL +Whizz 1 3 3 3 + +# Measures/CSEs should be valid in HAVING, and respect context semantics. 
+query TI rowsort +SELECT + o.prodName, + AGGREGATE(o.sumRevenue) AT (VISIBLE) AS rViz +FROM paper_orders_v o +WHERE o.custName <> 'Var Bob' +GROUP BY o.prodName +HAVING AGGREGATE(o.sumRevenue) AT (VISIBLE) > 10 +ORDER BY o.prodName; +---- +Happy 13 + +statement ok +CREATE TABLE paper_customers (custName TEXT, custAge INT); + +statement ok +INSERT INTO paper_customers VALUES + ('Alice', 30), ('Var Bob', 16), ('Carol', 40); + +statement ok +CREATE TABLE paper_order_customers (prodName TEXT, custName TEXT); + +statement ok +INSERT INTO paper_order_customers VALUES + ('Happy', 'Alice'), + ('Happy', 'Var Bob'), + ('Whizz', 'Carol'); + +statement ok +CREATE VIEW enhanced_customers_paper AS +SELECT *, AVG(custAge) AS MEASURE avgAge +FROM paper_customers; + +# Listing 9 style: weighted average vs unweighted avgAge vs visibleAvgAge. +query TRRRR rowsort +SELECT + o.prodName, + COUNT(*) AS orderCount, + AVG(c.custAge) AS weightedAvgAge, + c.avgAge AS avgAge, + c.avgAge AT (VISIBLE) AS visibleAvgAge +FROM paper_order_customers o +JOIN enhanced_customers_paper c USING (custName) +WHERE c.custAge >= 18 +GROUP BY o.prodName +ORDER BY o.prodName; +---- +Happy 1 30.0 28.666666666666668 35.0 +Whizz 1 40.0 28.666666666666668 35.0 + +statement ok +CREATE TABLE paper_orders_l12 (prodName TEXT, orderDate DATE, revenue INT); + +statement ok +INSERT INTO paper_orders_l12 VALUES + ('Happy', '2024-01-01', 4), + ('Happy', '2024-01-02', 6), + ('Happy', '2024-01-03', 7), + ('Whizz', '2024-01-04', 3); + +statement ok +CREATE VIEW paper_orders_l12_v AS +SELECT prodName, orderDate, revenue, AVG(revenue) AS MEASURE avgRevenue +FROM paper_orders_l12; + +# Listing 12 style measure syntax without AGGREGATE wrapper. +query TT rowsort +SELECT o.prodName, o.orderDate +FROM paper_orders_l12_v o +WHERE o.revenue > o.avgRevenue AT (WHERE prodName = o.prodName) +ORDER BY o.prodName, o.orderDate; +---- +Happy 2024-01-02 +Happy 2024-01-03 + +# Listing 12 query 1: correlated subquery. 
+query TT rowsort +SELECT o.prodName, o.orderDate +FROM paper_orders_l12 o +WHERE o.revenue > + (SELECT AVG(revenue) + FROM paper_orders_l12 o1 + WHERE o1.prodName = o.prodName) +ORDER BY o.prodName, o.orderDate; +---- +Happy 2024-01-02 +Happy 2024-01-03 + +# Listing 12 query 2: self-join. +query TT rowsort +SELECT o.prodName, o.orderDate +FROM paper_orders_l12 o +LEFT JOIN + (SELECT prodName, AVG(revenue) AS avgRevenue + FROM paper_orders_l12 + GROUP BY prodName) o2 +ON o.prodName = o2.prodName +WHERE o.revenue > o2.avgRevenue +ORDER BY o.prodName, o.orderDate; +---- +Happy 2024-01-02 +Happy 2024-01-03 + +# Listing 12 query 3: window aggregate. +query TT rowsort +SELECT o.prodName, o.orderDate +FROM + (SELECT prodName, revenue, orderDate, + AVG(revenue) OVER (PARTITION BY prodName) AS avgRevenue + FROM paper_orders_l12) o +WHERE o.revenue > o.avgRevenue +ORDER BY o.prodName, o.orderDate; +---- +Happy 2024-01-02 +Happy 2024-01-03 + +# CURRENT should evaluate to NULL when dimension is not single-valued in context. +query TR rowsort +SEMANTIC SELECT region, AGGREGATE(revenue) AT (SET year = CURRENT year - 1) AS prior_from_current +FROM sales_v +GROUP BY region +ORDER BY region; +---- +EU NULL +US NULL + +# CURRENT can resolve from a single-valued WHERE context. 
+query R +SEMANTIC SELECT AGGREGATE(revenue) AT (SET year = CURRENT year - 1) +FROM sales_v +WHERE year = 2023; +---- +150.0 diff --git a/yardstick-rs/src/ffi.rs b/yardstick-rs/src/ffi.rs index 9ecaafe..389491d 100644 --- a/yardstick-rs/src/ffi.rs +++ b/yardstick-rs/src/ffi.rs @@ -14,6 +14,7 @@ use crate::sql::{ drop_measure_view_from_sql, expand_aggregate_with_at, expand_curly_braces, get_measure_aggregation, has_aggregate_function, has_as_measure, has_at_syntax, has_curly_brace_measure, process_create_view, + has_implicit_measure_refs, has_measure_at_refs, }; /// Result from processing CREATE VIEW with AS MEASURE @@ -57,7 +58,7 @@ pub extern "C" fn yardstick_has_as_measure(sql: *const c_char) -> bool { has_as_measure(sql_str) } -/// Check if SQL contains AGGREGATE() function +/// Check if SQL contains AGGREGATE() or implicit measure references #[no_mangle] pub extern "C" fn yardstick_has_aggregate(sql: *const c_char) -> bool { if sql.is_null() { @@ -72,6 +73,8 @@ pub extern "C" fn yardstick_has_aggregate(sql: *const c_char) -> bool { }; has_aggregate_function(sql_str) + || has_implicit_measure_refs(sql_str) + || has_measure_at_refs(sql_str) } /// Drop a measure view from the catalog if the SQL is a DROP VIEW statement diff --git a/yardstick-rs/src/sql/measures.rs b/yardstick-rs/src/sql/measures.rs index 4f30894..0a874ec 100644 --- a/yardstick-rs/src/sql/measures.rs +++ b/yardstick-rs/src/sql/measures.rs @@ -56,6 +56,8 @@ pub struct MeasureView { static MEASURE_VIEWS: Lazy>> = Lazy::new(|| Mutex::new(HashMap::new())); +const DEFAULT_CONTEXT_MARKER: &str = "/*YARDSTICK_DEFAULT*/"; + /// Result from processing CREATE VIEW with AS MEASURE #[derive(Debug)] pub struct CreateViewResult { @@ -337,6 +339,557 @@ pub fn has_aggregate_function(sql: &str) -> bool { false } +fn normalize_identifier_name(name: &str) -> String { + name.trim() + .trim_matches('"') + .trim_matches('`') + .trim_matches('[') + .trim_matches(']') + .to_ascii_lowercase() +} + +fn 
parse_simple_measure_ref(expr: &str) -> Option<(Option, String)> { + let trimmed = expr.trim(); + if trimmed.is_empty() { + return None; + } + + let allowed = |c: char| c.is_ascii_alphanumeric() || c == '_' || c == '.' || c == '"' || c == '`' || c == '[' || c == ']'; + if !trimmed.chars().all(allowed) { + return None; + } + + let parts: Vec<&str> = trimmed.split('.').collect(); + match parts.as_slice() { + [measure] => Some((None, normalize_identifier_name(measure))), + [qualifier, measure] => Some(( + Some(normalize_identifier_name(qualifier)), + normalize_identifier_name(measure), + )), + _ => None, + } +} + +fn measure_columns_for_query_tables( + info: &SelectInfo, +) -> ( + HashMap>, + HashSet, +) { + let views = MEASURE_VIEWS.lock().unwrap(); + let mut by_qualifier: HashMap> = HashMap::new(); + let mut any_measure: HashSet = HashSet::new(); + + for table in &info.tables { + let maybe_view = views.iter().find(|(name, _)| name.eq_ignore_ascii_case(&table.table_name)); + let Some((_, view)) = maybe_view else { + continue; + }; + + let measures: HashSet = view + .measures + .iter() + .map(|m| normalize_identifier_name(&m.column_name)) + .collect(); + if measures.is_empty() { + continue; + } + + any_measure.extend(measures.iter().cloned()); + + by_qualifier + .entry(normalize_identifier_name(&table.table_name)) + .or_default() + .extend(measures.iter().cloned()); + + if let Some(alias) = &table.alias { + by_qualifier + .entry(normalize_identifier_name(alias)) + .or_default() + .extend(measures.iter().cloned()); + } + } + + (by_qualifier, any_measure) +} + +fn select_item_is_implicit_measure_ref( + item_expr: &str, + by_qualifier: &HashMap>, + any_measure: &HashSet, +) -> bool { + let Some((qualifier, measure)) = parse_simple_measure_ref(item_expr) else { + return false; + }; + + if let Some(q) = qualifier { + return by_qualifier + .get(&q) + .map(|names| names.contains(&measure)) + .unwrap_or(false); + } + + any_measure.contains(&measure) +} + +pub fn 
has_implicit_measure_refs(sql: &str) -> bool { + let known_measures = known_measure_names(); + if known_measures.is_empty() { + return false; + } + + let info = match parser_ffi::parse_select(sql) { + Ok(info) => info, + Err(_) => return has_implicit_measure_refs_fallback(sql, &known_measures), + }; + + let (by_qualifier, any_measure) = measure_columns_for_query_tables(&info); + if any_measure.is_empty() { + return false; + } + + info.items + .iter() + .any(|item| { + !item.is_aggregate + && !item.is_star + && !item.is_measure_ref + && select_item_is_implicit_measure_ref(&item.expression_sql, &by_qualifier, &any_measure) + }) +} + +fn collect_top_level_select_item_ranges(sql: &str) -> Option> { + let query = sql; + let select_pos = find_top_level_keyword(query, "SELECT", 0)?; + let from_pos = find_top_level_keyword(query, "FROM", select_pos)?; + let select_start = select_pos + "SELECT".len(); + if select_start >= from_pos { + return None; + } + + let bytes = query.as_bytes(); + let mut ranges = Vec::new(); + let mut item_start = select_start; + let mut depth = 0i32; + let mut i = select_start; + let mut in_single = false; + let mut in_double = false; + + while i < from_pos { + let b = bytes[i]; + if in_single { + if b == b'\'' { + if i + 1 < from_pos && bytes[i + 1] == b'\'' { + i += 2; + continue; + } + in_single = false; + } + i += 1; + continue; + } + if in_double { + if b == b'"' { + if i + 1 < from_pos && bytes[i + 1] == b'"' { + i += 2; + continue; + } + in_double = false; + } + i += 1; + continue; + } + + match b { + b'\'' => in_single = true, + b'"' => in_double = true, + b'(' => depth += 1, + b')' => { + if depth > 0 { + depth -= 1; + } + } + b',' if depth == 0 => { + ranges.push((item_start, i)); + item_start = i + 1; + } + _ => {} + } + + i += 1; + } + + if item_start < from_pos { + ranges.push((item_start, from_pos)); + } + + Some(ranges) +} + +fn find_top_level_as_pos(item: &str) -> Option { + let upper = item.to_ascii_uppercase(); // ASCII-only uppercase preserves byte length, so offsets into `upper_bytes` match `bytes` + let bytes = 
item.as_bytes(); + let upper_bytes = upper.as_bytes(); + let mut depth = 0i32; + let mut i = 0usize; + let mut in_single = false; + let mut in_double = false; + + while i + 4 <= bytes.len() { + let b = bytes[i]; + if in_single { + if b == b'\'' { + if i + 1 < bytes.len() && bytes[i + 1] == b'\'' { + i += 2; + continue; + } + in_single = false; + } + i += 1; + continue; + } + if in_double { + if b == b'"' { + if i + 1 < bytes.len() && bytes[i + 1] == b'"' { + i += 2; + continue; + } + in_double = false; + } + i += 1; + continue; + } + + match b { + b'\'' => in_single = true, + b'"' => in_double = true, + b'(' => depth += 1, + b')' => { + if depth > 0 { + depth -= 1; + } + } + _ => {} + } + + if depth == 0 && upper_bytes[i..].starts_with(b" AS ") { + return Some(i); + } + i += 1; + } + + None +} + +fn split_item_expr_and_alias(item: &str) -> (&str, Option<&str>) { + if let Some(pos) = find_top_level_as_pos(item) { + let expr = item[..pos].trim(); + let alias = item[pos..].trim(); + (expr, Some(alias)) + } else { + (item.trim(), None) + } +} + +fn is_simple_known_measure_ref(expr: &str, known_measures: &HashSet) -> bool { + if has_aggregate_function(expr) { + return false; + } + if expr.to_uppercase().contains(" AT ") { + return false; + } + let Some((_, measure_name)) = parse_simple_measure_ref(expr) else { + return false; + }; + known_measures.contains(&measure_name) +} + +fn starts_with_top_level_select_or_with(sql: &str) -> bool { + let trimmed = sql.trim_start(); + let upper = trimmed.to_uppercase(); + upper.starts_with("SELECT ") || upper.starts_with("SELECT\n") || upper.starts_with("WITH ") +} + +fn has_implicit_measure_refs_fallback(sql: &str, known_measures: &HashSet) -> bool { + if !starts_with_top_level_select_or_with(sql) { + return false; + } + let Some(ranges) = collect_top_level_select_item_ranges(sql) else { + return false; + }; + ranges.into_iter().any(|(start, end)| { + let item = sql[start..end].trim(); + if item.is_empty() { + return false; + } + 
let (expr, _) = split_item_expr_and_alias(item); + is_simple_known_measure_ref(expr, known_measures) + }) +} + +fn rewrite_implicit_measure_refs_fallback(sql: &str, known_measures: &HashSet) -> String { + if !starts_with_top_level_select_or_with(sql) { + return sql.to_string(); + } + let Some(ranges) = collect_top_level_select_item_ranges(sql) else { + return sql.to_string(); + }; + + let mut replacements: Vec<(usize, usize, String)> = Vec::new(); + for (start, end) in ranges { + let raw_item = &sql[start..end]; + let item = raw_item.trim(); + if item.is_empty() { + continue; + } + + let (expr, alias_clause) = split_item_expr_and_alias(item); + if !is_simple_known_measure_ref(expr, known_measures) { + continue; + } + + let mut rewritten = format!("AGGREGATE({}) {}", expr.trim(), DEFAULT_CONTEXT_MARKER); + if let Some(alias) = alias_clause { + rewritten.push(' '); + rewritten.push_str(alias); + } + // Preserve token separation before the next clause keyword (e.g. FROM). + rewritten.push(' '); + replacements.push((start, end, rewritten)); + } + + if replacements.is_empty() { + return sql.to_string(); + } + + let mut result = sql.to_string(); + replacements.sort_by(|a, b| b.0.cmp(&a.0)); + for (start, end, rewritten) in replacements { + if start <= end && end <= result.len() { + result.replace_range(start..end, &rewritten); + } + } + result +} + +fn rewrite_implicit_measure_refs(sql: &str) -> String { + let known_measures = known_measure_names(); + if known_measures.is_empty() { + return sql.to_string(); + } + + let info = match parser_ffi::parse_select(sql) { + Ok(info) => info, + Err(_) => return rewrite_implicit_measure_refs_fallback(sql, &known_measures), + }; + + let (by_qualifier, any_measure) = measure_columns_for_query_tables(&info); + if any_measure.is_empty() { + return rewrite_implicit_measure_refs_fallback(sql, &known_measures); + } + + let mut replacements: Vec = Vec::new(); + for item in &info.items { + if item.is_aggregate || item.is_star || 
item.is_measure_ref { + continue; + } + if !select_item_is_implicit_measure_ref(&item.expression_sql, &by_qualifier, &any_measure) { + continue; + } + + // Plain measure refs follow paper default semantics (not VISIBLE). + // Keep an internal marker so expansion can apply default context rules. + let mut rewritten = format!( + "AGGREGATE({}) {}", + item.expression_sql.trim(), + DEFAULT_CONTEXT_MARKER + ); + if let Some(alias) = &item.alias { + rewritten.push_str(" AS "); + rewritten.push_str(alias); + } + // Ensure separator before the next token (e.g. FROM) when replacing final SELECT item. + rewritten.push(' '); + + replacements.push(parser_ffi::Replacement { + start_pos: item.start_pos, + end_pos: item.end_pos, + replacement: rewritten, + }); + } + + if replacements.is_empty() { + return sql.to_string(); + } + + parser_ffi::apply_replacements(sql, &replacements).unwrap_or_else(|_| sql.to_string()) +} + +fn known_measure_names() -> HashSet { + let views = MEASURE_VIEWS.lock().unwrap(); + views + .values() + .flat_map(|view| view.measures.iter().map(|m| normalize_identifier_name(&m.column_name))) + .collect() +} + +fn is_measure_ref_char(c: u8) -> bool { + c.is_ascii_alphanumeric() || matches!(c, b'_' | b'.' 
| b'"' | b'`' | b'[' | b']') +} + +fn is_keyword_boundary(bytes: &[u8], start: usize, end: usize) -> bool { + let left_ok = start == 0 || !(bytes[start - 1].is_ascii_alphanumeric() || bytes[start - 1] == b'_'); + let right_ok = end >= bytes.len() || !(bytes[end].is_ascii_alphanumeric() || bytes[end] == b'_'); + left_ok && right_ok +} + +fn find_previous_measure_ref_bounds(sql: &str, at_start: usize) -> Option<(usize, usize)> { + let bytes = sql.as_bytes(); + let mut end = at_start; + while end > 0 && bytes[end - 1].is_ascii_whitespace() { + end -= 1; + } + if end == 0 || bytes[end - 1] == b')' { + return None; + } + + let mut start = end; + while start > 0 && is_measure_ref_char(bytes[start - 1]) { + start -= 1; + } + if start == end { + return None; + } + + if start > 0 { + let prev = bytes[start - 1]; + if is_measure_ref_char(prev) || prev == b')' { + return None; + } + } + + Some((start, end)) +} + +fn rewrite_measure_at_refs(sql: &str) -> String { + let known_measures = known_measure_names(); + if known_measures.is_empty() { + return sql.to_string(); + } + + let bytes = sql.as_bytes(); + let upper = sql.to_ascii_uppercase(); // ASCII-only uppercase preserves byte length, so `upper_bytes[i..]` cannot run out of range or drift from `bytes` + let upper_bytes = upper.as_bytes(); + let mut replacements: Vec = Vec::new(); + let mut i = 0usize; + + while i < bytes.len() { + match bytes[i] { + b'\'' => { + i += 1; + while i < bytes.len() { + if bytes[i] == b'\'' { + if i + 1 < bytes.len() && bytes[i + 1] == b'\'' { + i += 2; + continue; + } + i += 1; + break; + } + i += 1; + } + continue; + } + b'"' => { + i += 1; + while i < bytes.len() { + if bytes[i] == b'"' { + if i + 1 < bytes.len() && bytes[i + 1] == b'"' { + i += 2; + continue; + } + i += 1; + break; + } + i += 1; + } + continue; + } + b'-' if i + 1 < bytes.len() && bytes[i + 1] == b'-' => { + i += 2; + while i < bytes.len() && bytes[i] != b'\n' { + i += 1; + } + continue; + } + b'/' if i + 1 < bytes.len() && bytes[i + 1] == b'*' => { + i += 2; + while i + 1 < bytes.len() { + if bytes[i] == b'*' && bytes[i + 1] == b'/' { + i += 2; 
+ break; + } + i += 1; + } + continue; + } + _ => {} + } + + if i + 2 <= bytes.len() + && upper_bytes[i..].starts_with(b"AT") + && is_keyword_boundary(bytes, i, i + 2) + { + let mut j = i + 2; + while j < bytes.len() && bytes[j].is_ascii_whitespace() { + j += 1; + } + if j < bytes.len() && bytes[j] == b'(' { + if let Some((start, end)) = find_previous_measure_ref_bounds(sql, i) { + let token = sql[start..end].trim(); + if let Some((_, measure_name)) = parse_simple_measure_ref(token) { + if known_measures.contains(&measure_name) { + replacements.push(parser_ffi::Replacement { + start_pos: start as u32, + end_pos: end as u32, + replacement: format!("AGGREGATE({token})"), + }); + } + } + } + } + i += 2; + continue; + } + + i += 1; + } + + if replacements.is_empty() { + return sql.to_string(); + } + + parser_ffi::apply_replacements(sql, &replacements).unwrap_or_else(|_| sql.to_string()) +} + +pub fn has_measure_at_refs(sql: &str) -> bool { + rewrite_measure_at_refs(sql) != sql +} + +fn strip_measure_qualifier(measure: &str) -> String { + measure + .trim() + .split('.') + .next_back() + .unwrap_or(measure.trim()) + .trim() + .trim_matches('"') + .trim_matches('`') + .trim_matches('[') + .trim_matches(']') + .to_string() +} + /// Check if SQL contains curly brace measure syntax: `{column}` pub fn has_curly_brace_measure(sql: &str) -> bool { let mut parser = delimited( @@ -422,29 +975,9 @@ fn at_visible(input: &str) -> IResult<&str, ContextModifier> { } /// Parse CURRENT dimension reference in expressions -/// Returns the dimension name with CURRENT stripped +/// CURRENT is resolved later once GROUP BY/WHERE context is known. 
fn parse_current_in_expr(expr: &str) -> String { - // Replace "CURRENT dim" with just "dim" (let outer reference handle it) - let mut result = expr.to_string(); - let expr_upper = expr.to_uppercase(); - - // Find all "CURRENT identifier" patterns - let mut search_pos = 0; - while let Some(pos) = expr_upper[search_pos..].find("CURRENT ") { - let abs_pos = search_pos + pos; - let after_current = abs_pos + 8; // "CURRENT " is 8 chars - - // Extract the identifier after CURRENT - let remaining = &expr[after_current..]; - if let Ok((_, ident)) = identifier(remaining) { - // Replace "CURRENT ident" with just "ident" - let pattern = &expr[abs_pos..after_current + ident.len()]; - result = result.replacen(pattern, ident, 1); - } - search_pos = abs_pos + 1; - } - - result + expr.trim().to_string() } /// Parse AT (SET dimension = expr) @@ -467,10 +1000,11 @@ fn at_set(input: &str) -> IResult<&str, ContextModifier> { fn at_where(input: &str) -> IResult<&str, ContextModifier> { let (input, _) = tag_no_case("WHERE")(input)?; let (input, _) = multispace1(input)?; - // Take rest as condition (until closing paren, handled by caller) - let (input, cond) = take_while(|c: char| c != ')')(input)?; + // AT (...) content is already balanced by the caller, so WHERE can + // safely consume the remainder here (including nested function parens). 
+ let cond = input.trim(); let stripped = strip_at_where_qualifiers(cond.trim()); - Ok((input, ContextModifier::Where(stripped))) + Ok(("", ContextModifier::Where(stripped))) } /// Parse any AT modifier content @@ -523,6 +1057,17 @@ pub struct AggregateCall { /// Parse multiple modifiers inside a single AT clause (e.g., AT (SET year = year - 1 VISIBLE)) fn at_modifiers_content(input: &str) -> IResult<&str, Vec> { + fn starts_with_modifier_keyword(s: &str) -> bool { + let trimmed = s.trim_start(); + let upper = trimmed.to_uppercase(); + upper.starts_with("ALL ") + || upper == "ALL" + || upper.starts_with("SET ") + || upper.starts_with("WHERE ") + || upper == "VISIBLE" + || upper.starts_with("VISIBLE ") + } + let mut modifiers = Vec::new(); let mut remaining = input.trim(); @@ -533,6 +1078,17 @@ fn at_modifiers_content(input: &str) -> IResult<&str, Vec> { Ok((rest, modifier)) => { modifiers.push(modifier); remaining = rest.trim(); + + if let Some(ContextModifier::All(_)) = modifiers.last() { + while !remaining.is_empty() && !starts_with_modifier_keyword(remaining) { + if let Ok((after_dim, dim)) = expression_or_identifier(remaining) { + modifiers.push(ContextModifier::All(dim)); + remaining = after_dim.trim(); + } else { + break; + } + } + } } Err(_) => break, } @@ -1536,7 +2092,21 @@ fn extract_view_group_by_cols(view_query: &str) -> Vec { let query = view_query.trim().trim_end_matches(';').trim(); let group_pos = match find_top_level_keyword(query, "GROUP BY", 0) { Some(pos) => pos, - None => return Vec::new(), + None => { + if let Ok(info) = parser_ffi::parse_select(query) { + return info + .items + .iter() + .filter(|item| !item.is_aggregate && !item.is_star && !item.is_measure_ref) + .map(|item| { + item.alias + .clone() + .unwrap_or_else(|| item.expression_sql.clone()) + }) + .collect(); + } + return Vec::new(); + } }; let start = advance_after_group_by(query, group_pos) @@ -1684,6 +2254,25 @@ fn filter_group_by_cols_for_measure( .collect() } +fn 
source_dimension_names(source_view: &str) -> HashSet { + let views = MEASURE_VIEWS.lock().unwrap(); + let Some((_, view)) = views + .iter() + .find(|(name, _)| name.eq_ignore_ascii_case(source_view)) + else { + return HashSet::new(); + }; + + let view_query = extract_view_query(&view.base_query).unwrap_or_else(|| view.base_query.clone()); + extract_dimension_columns_from_select(&view_query) + .into_iter() + .map(|col| { + let dim_name = col.split('.').next_back().unwrap_or(col.as_str()).trim(); + normalize_dimension_key(dim_name) + }) + .collect() +} + fn can_use_view_measure_directly(resolved: &ResolvedMeasure, outer_group_by: &[String]) -> bool { group_by_matches_view(outer_group_by, &resolved.view_group_by_cols) } @@ -1916,14 +2505,9 @@ fn expand_derived_measure_expr(expr: &str, measure_view: &MeasureView) -> String .iter() .find(|m| m.column_name.eq_ignore_ascii_case(&ident)) { - // Get the aggregation function from this measure's expression - if let Some(agg_fn) = extract_aggregation_function(&m.expression) { - // Replace measure name with AGG(measure_name) - result.push_str(&format!("{}({})", agg_fn.to_uppercase(), ident)); - } else { - // Fallback to SUM if no aggregation found - result.push_str(&format!("SUM({ident})")); - } + result.push('('); + result.push_str(&m.expression); + result.push(')'); } else { // Not a measure, keep as-is result.push_str(&ident); @@ -2044,6 +2628,220 @@ fn dimension_in_group_by( }) } +fn expr_mentions_identifier_outside_current(expr: &str, ident: &str) -> bool { + let mut chars = expr.chars().peekable(); + let mut pending_current = false; + + while let Some(c) = chars.next() { + if c == '\'' { + while let Some(next) = chars.next() { + if next == '\'' { + if chars.peek() == Some(&'\'') { + chars.next(); + } else { + break; + } + } + } + continue; + } + + if c.is_alphabetic() || c == '_' { + let mut token = String::from(c); + while let Some(&next) = chars.peek() { + if next.is_alphanumeric() || next == '_' { + 
token.push(chars.next().unwrap()); + } else { + break; + } + } + + if token.eq_ignore_ascii_case("CURRENT") { + pending_current = true; + continue; + } + + if token.eq_ignore_ascii_case(ident) { + if pending_current { + pending_current = false; + } else { + return true; + } + } else { + pending_current = false; + } + } + } + + false +} + +fn where_has_simple_equality_constraint( + where_clause: &str, + dim_name: &str, + default_qualifier: Option<&str>, +) -> bool { + // Conservative: if OR is present we don't claim single-valued context. + if where_clause.to_uppercase().contains(" OR ") { + return false; + } + + let expected_qualifier = default_qualifier.map(normalize_identifier_name); + let bytes = where_clause.as_bytes(); + let mut i = 0usize; + + while i < bytes.len() { + if bytes[i] == b'=' { + if (i > 0 && matches!(bytes[i - 1], b'<' | b'>' | b'!' | b'=')) + || (i + 1 < bytes.len() && bytes[i + 1] == b'=') + { + i += 1; + continue; + } + + let mut end = i; + while end > 0 && bytes[end - 1].is_ascii_whitespace() { + end -= 1; + } + let mut start = end; + while start > 0 && is_measure_ref_char(bytes[start - 1]) { + start -= 1; + } + + if start < end { + let left = where_clause[start..end].trim(); + if let Some((qualifier, name)) = parse_simple_measure_ref(left) { + if name.eq_ignore_ascii_case(dim_name) { + let qualifier_ok = match (&expected_qualifier, qualifier) { + (Some(expected), Some(found)) => found.eq_ignore_ascii_case(expected), + _ => true, + }; + if qualifier_ok { + return true; + } + } + } + } + } + i += 1; + } + + false +} + +fn current_dimension_is_single_valued( + dim: &str, + group_by_cols: &[String], + outer_where: Option<&str>, + default_qualifier: Option<&str>, +) -> bool { + if dimension_in_group_by(dim, group_by_cols, default_qualifier) { + return true; + } + + let dim_name = dim.split('.').next_back().unwrap_or(dim).trim(); + outer_where + .map(|w| where_has_simple_equality_constraint(w, dim_name, default_qualifier)) + .unwrap_or(false) +} 
+ +fn resolve_current_in_expr( + expr: &str, + group_by_cols: &[String], + outer_where: Option<&str>, + default_qualifier: Option<&str>, +) -> String { + let bytes = expr.as_bytes(); + let mut out = String::new(); + let mut i = 0usize; + + while i < bytes.len() { + match bytes[i] { + b'\'' => { + out.push('\''); + i += 1; + while i < bytes.len() { + out.push(bytes[i] as char); + if bytes[i] == b'\'' { + if i + 1 < bytes.len() && bytes[i + 1] == b'\'' { + i += 1; + out.push(bytes[i] as char); + } else { + i += 1; + break; + } + } + i += 1; + } + } + c if (c as char).is_alphabetic() || c == b'_' => { + let token_start = i; + i += 1; + while i < bytes.len() && ((bytes[i] as char).is_alphanumeric() || bytes[i] == b'_') { + i += 1; + } + let token = &expr[token_start..i]; + + if token.eq_ignore_ascii_case("CURRENT") { + let mut j = i; + while j < bytes.len() && bytes[j].is_ascii_whitespace() { + j += 1; + } + + if j < bytes.len() + && (((bytes[j] as char).is_alphabetic()) || bytes[j] == b'_') + { + let dim_start = j; + j += 1; + while j < bytes.len() + && ((bytes[j] as char).is_alphanumeric() || bytes[j] == b'_') + { + j += 1; + } + while j < bytes.len() && bytes[j] == b'.' 
{ + let mut k = j + 1; + if k >= bytes.len() + || !(((bytes[k] as char).is_alphabetic()) || bytes[k] == b'_') + { + break; + } + k += 1; + while k < bytes.len() + && ((bytes[k] as char).is_alphanumeric() || bytes[k] == b'_') + { + k += 1; + } + j = k; + } + + let dim = expr[dim_start..j].trim(); + if current_dimension_is_single_valued( + dim, + group_by_cols, + outer_where, + default_qualifier, + ) { + out.push_str(dim); + } else { + out.push_str("NULL"); + } + i = j; + continue; + } + } + + out.push_str(token); + } + _ => { + out.push(bytes[i] as char); + i += 1; + } + } + } + + out +} + /// Qualify column references in a WHERE clause for use inside _inner subquery /// "region = 'US'" -> "_inner.region = 'US'" /// "year > 2020 AND region = 'US'" -> "_inner.year > 2020 AND _inner.region = 'US'" @@ -2475,7 +3273,12 @@ pub fn process_create_view(sql: &str) -> CreateViewResult { extract_view_query(&clean_sql).unwrap_or_else(|| clean_sql.clone()); let base_relation_sql = extract_base_relation_sql(&view_query); let dimension_exprs = extract_dimension_exprs_from_query(&view_query); - let group_by_cols = extract_view_group_by_cols(&view_query); + let mut group_by_cols = extract_view_group_by_cols(&view_query); + let measure_col_names: HashSet = measures + .iter() + .map(|m| normalize_group_by_col(&m.column_name)) + .collect(); + group_by_cols.retain(|col| !measure_col_names.contains(&normalize_group_by_col(col))); let measure_view = MeasureView { view_name: vn.clone(), measures: measures.clone(), @@ -2550,67 +3353,26 @@ fn extract_measures_from_sql( } } - // Collect all measure names for derived measure detection - let measure_names: Vec<&str> = measure_infos.iter().map(|m| m.name.as_str()).collect(); - - // Check if an expression is derived (references other measures, no aggregation) - let is_derived = |expr: &str| -> bool { - if extract_aggregation_function(expr).is_some() { - return false; // Has aggregation, not derived - } - // Check if expression references any 
measure name - for name in &measure_names { - // Simple word boundary check - let expr_lower = expr.to_lowercase(); - let name_lower = name.to_lowercase(); - if expr_lower.contains(&name_lower) { - // More precise check: ensure it's a word boundary - for (i, _) in expr_lower.match_indices(&name_lower) { - let before_ok = - i == 0 || !expr.chars().nth(i - 1).unwrap_or(' ').is_alphanumeric(); - let after_ok = i + name_lower.len() >= expr.len() - || !expr - .chars() - .nth(i + name_lower.len()) - .unwrap_or(' ') - .is_alphanumeric(); - if before_ok && after_ok { - return true; - } - } - } - } - false - }; - - // Generate replacements based on whether measure is derived - // (start, end, replacement) - for derived measures, we remove the whole column + // Replace measured expressions: + // - decomposable measures become NULL placeholders (virtual columns) + // - non-decomposable measures keep their aggregate expression for direct querying let mut replacements: Vec<(usize, usize, String)> = Vec::new(); + let mut has_materialized_non_decomposable = false; for info in &measure_infos { - let should_exclude = is_derived(&info.expression); - - if should_exclude { - // Derived measure: remove entire column including preceding comma - let mut remove_start = info.expr_start; - // Look for preceding comma - let before = &sql[..remove_start]; - if let Some(comma_pos) = before.rfind(',') { - // Check that only whitespace between comma and expr_start - let between = &before[comma_pos + 1..]; - if between.trim().is_empty() { - remove_start = comma_pos; - } - } - replacements.push((remove_start, info.name_end, String::new())); - } else { - // Base measure: replace "AS MEASURE name" with "AS name" - let chunk = &sql[info.expr_start..info.name_end]; - if let Some(am_pos) = chunk.to_uppercase().find(" AS MEASURE ") { - let abs_start = info.expr_start + am_pos; - replacements.push((abs_start, info.name_end, format!(" AS {}", info.name))); - } + let is_non_decomp = 
is_non_decomposable(&info.expression); + if is_non_decomp { + has_materialized_non_decomposable = true; } + replacements.push(( + info.expr_start, + info.name_end, + if is_non_decomp { + format!("{} AS {}", info.expression.trim(), info.name) + } else { + format!("NULL AS {}", info.name) + }, + )); } // Build measures list @@ -2635,23 +3397,17 @@ fn extract_measures_from_sql( ); } - // If there are aggregate measures but no GROUP BY, add GROUP BY ALL - // This enables the "extension" syntax from the paper where views define - // measures without explicit grouping - let has_aggregate_measure = measures - .iter() - .any(|m| find_aggregation_in_expression(&m.expression).is_some()); - let clean_sql_upper = clean_sql.to_uppercase(); - let has_group_by = has_top_level_group_by(&clean_sql); + clean_sql = rewrite_percentile_within_group(&clean_sql); - if has_aggregate_measure && !has_group_by { - // Find insertion point: before ORDER BY, LIMIT, or at end - let insert_pos = ["ORDER BY", "LIMIT", ";"] + // Non-decomposable measures (e.g., COUNT DISTINCT, MEDIAN) kept as aggregates + // require grouping to form a valid view if dimensions are projected. 
+ if has_materialized_non_decomposable && !has_group_by_anywhere(&clean_sql) { + let upper = clean_sql.to_uppercase(); + let insert_pos = ["ORDER BY", "LIMIT", "HAVING", ";"] .iter() - .filter_map(|kw| clean_sql_upper.find(kw)) + .filter_map(|kw| upper.find(kw)) .min() .unwrap_or(clean_sql.len()); - clean_sql = format!( "{} GROUP BY ALL{}", clean_sql[..insert_pos].trim_end(), @@ -2663,8 +3419,6 @@ fn extract_measures_from_sql( ); } - clean_sql = rewrite_percentile_within_group(&clean_sql); - Ok((clean_sql, measures, view_name, base_table)) } @@ -3052,27 +3806,36 @@ fn correlation_exprs_for_dim( ) -> (String, String) { let dim_trim = dim.trim(); let dim_name = dim_trim.split('.').next_back().unwrap_or(dim_trim).trim(); + let dim_is_qualified = dim_trim.contains('.'); let dim_key = normalize_dimension_key(dim_name); if let Some(expr) = dimension_exprs.get(&dim_key) { let inner_expr = qualify_where_for_inner(expr); - let outer_expr = outer_alias - .map(|alias| format!("{alias}.{dim_name}")) - .unwrap_or_else(|| dim_name.to_string()); + let outer_expr = if dim_is_qualified { + dim_trim.to_string() + } else { + outer_alias + .map(|alias| format!("{alias}.{dim_name}")) + .unwrap_or_else(|| dim_name.to_string()) + }; return (inner_expr, outer_expr); } if dim_trim.contains('(') { let inner_expr = qualify_where_for_inner(dim_trim); let outer_expr = outer_alias - .map(|alias| qualify_where_for_outer(dim_trim, alias)) + .map(|alias| format!("ANY_VALUE({})", qualify_where_for_outer(dim_trim, alias))) .unwrap_or_else(|| dim_trim.to_string()); return (inner_expr, outer_expr); } let inner_expr = format!("_inner.{dim_name}"); - let outer_expr = outer_alias - .map(|alias| format!("{alias}.{dim_name}")) - .unwrap_or_else(|| dim_name.to_string()); + let outer_expr = if dim_is_qualified { + dim_trim.to_string() + } else { + outer_alias + .map(|alias| format!("{alias}.{dim_name}")) + .unwrap_or_else(|| dim_name.to_string()) + }; (inner_expr, outer_expr) } @@ -3090,6 +3853,32 @@ struct 
NonDecompJoinPlan { replacement: String, } +fn expand_non_decomposable_default_context( + expression: &str, + base_relation_sql: &str, + outer_alias: Option<&str>, + group_by_cols: &[String], + dimension_exprs: &HashMap, +) -> String { + let base_relation = base_relation_for_subquery(base_relation_sql); + + if group_by_cols.is_empty() { + return format!("(SELECT {expression} FROM {base_relation})"); + } + + let where_clauses: Vec<_> = group_by_cols + .iter() + .map(|col| correlation_condition_for_dim(col, dimension_exprs, outer_alias)) + .collect(); + + format!( + "(SELECT {} FROM {} _inner WHERE {})", + expression, + base_relation, + where_clauses.join(" AND ") + ) +} + fn build_non_decomposable_join_plan( expression: &str, base_relation_sql: &str, @@ -3145,8 +3934,9 @@ fn build_non_decomposable_join_plan( ContextModifier::Visible => { if !has_set && !has_all_global { if let Some(w) = outer_where { + let stripped = strip_at_where_qualifiers(w); effective_where = Some(qualify_where_for_inner_with_dimensions( - w, + &stripped, dimension_exprs, )); } @@ -3166,14 +3956,16 @@ fn build_non_decomposable_join_plan( let dim_name = dim.split('.').next_back().unwrap_or(dim).trim(); let dim_key = normalize_dimension_key(dim_name); if !has_all_global && !removed_dims.contains(&dim_key) { + let resolved_expr = + resolve_current_in_expr(expr, group_by_cols, outer_where, outer_alias); let outer_expr = if let Some(alias) = outer_alias { if dim.contains('(') { - qualify_where_for_outer(expr, alias) + qualify_where_for_outer(&resolved_expr, alias) } else { - qualify_outer_reference(expr, alias, dim_name) + qualify_outer_reference(&resolved_expr, alias, dim_name) } } else { - expr.to_string() + resolved_expr }; set_overrides.insert(dim_key, outer_expr); } @@ -3383,8 +4175,11 @@ fn expand_non_decomposable_to_sql( ContextModifier::Visible => { if !has_set && !has_all_global { if let Some(w) = outer_where { - effective_where = - Some(qualify_where_for_inner_with_dimensions(w, 
dimension_exprs)); + let stripped = strip_at_where_qualifiers(w); + effective_where = Some(qualify_where_for_inner_with_dimensions( + &stripped, + dimension_exprs, + )); } } } @@ -3404,10 +4199,16 @@ fn expand_non_decomposable_to_sql( let outer_ref = outer_alias.unwrap_or("_outer"); let dim_name = dim.split('.').next_back().unwrap_or(dim).trim(); let dim_key = normalize_dimension_key(dim_name); + let resolved_expr = resolve_current_in_expr( + expr, + group_by_cols, + outer_where, + Some(outer_ref), + ); let qualified_expr = if dim.contains('(') { - qualify_where_for_outer(expr, outer_ref) + qualify_where_for_outer(&resolved_expr, outer_ref) } else { - qualify_outer_reference(expr, outer_ref, dim_name) + qualify_outer_reference(&resolved_expr, outer_ref, dim_name) }; let inner_dim = if let Some(expr) = dimension_exprs.get(&dim_key) { qualify_where_for_inner(expr) @@ -3513,10 +4314,12 @@ fn expand_non_decomposable_at_to_sql( let outer_ref = outer_alias.unwrap_or("_outer"); let dim_name = dim.split('.').next_back().unwrap_or(dim).trim(); let dim_key = normalize_dimension_key(dim_name); + let resolved_expr = + resolve_current_in_expr(expr, group_by_cols, outer_where, Some(outer_ref)); let qualified_expr = if dim.contains('(') { - qualify_where_for_outer(expr, outer_ref) + qualify_where_for_outer(&resolved_expr, outer_ref) } else { - qualify_outer_reference(expr, outer_ref, dim_name) + qualify_outer_reference(&resolved_expr, outer_ref, dim_name) }; let inner_dim = if let Some(expr) = dimension_exprs.get(&dim_key) { qualify_where_for_inner(expr) @@ -3564,8 +4367,9 @@ fn expand_non_decomposable_at_to_sql( if group_by_cols.is_empty() { match outer_where { Some(w) => { + let stripped = strip_at_where_qualifiers(w); let qualified = - qualify_where_for_inner_with_dimensions(w, dimension_exprs); + qualify_where_for_inner_with_dimensions(&stripped, dimension_exprs); format!( "(SELECT {expression} FROM {base_relation} _inner WHERE {qualified})" ) @@ -3579,8 +4383,11 @@ fn 
expand_non_decomposable_at_to_sql( .collect(); let full_where = match outer_where { Some(w) => { - let qualified = - qualify_where_for_inner_with_dimensions(w, dimension_exprs); + let stripped = strip_at_where_qualifiers(w); + let qualified = qualify_where_for_inner_with_dimensions( + &stripped, + dimension_exprs, + ); format!("{} AND {}", where_clauses.join(" AND "), qualified) } None => where_clauses.join(" AND "), @@ -3669,14 +4476,21 @@ pub fn expand_at_to_sql( ContextModifier::Set(dim, expr) => { // Use outer_alias for the correlated reference, falling back to table_name let outer_ref = outer_alias.unwrap_or(table_name); - let qualified_expr = qualify_outer_reference(expr, outer_ref, dim); + let dim_name = dim.split('.').next_back().unwrap_or(dim).trim(); + let resolved_expr = + resolve_current_in_expr(expr, group_by_cols, outer_where, Some(outer_ref)); + let qualified_expr = if dim.contains('(') { + qualify_where_for_outer(&resolved_expr, outer_ref) + } else { + qualify_outer_reference(&resolved_expr, outer_ref, dim_name) + }; // For ad hoc dimensions (expressions like MONTH(date)), qualify column refs with _inner // For simple columns, just prefix with _inner. 
let inner_dim = if dim.contains('(') { // Expression: qualify column refs inside it qualify_where_for_inner(dim) } else { - format!("_inner.{dim}") + format!("_inner.{dim_name}") }; // SET condition for the specified dimension @@ -3896,8 +4710,8 @@ pub fn expand_modifiers_to_sql( // Per paper: SET bypasses outer WHERE, so VISIBLE has no effect when SET is present if !has_set && !has_all_global { if let Some(w) = outer_where { - // Qualify column references with _inner - effective_where = Some(qualify_where_for_inner(w)); + let stripped = strip_at_where_qualifiers(w); + effective_where = Some(qualify_where_for_inner(&stripped)); } } } @@ -3913,12 +4727,23 @@ pub fn expand_modifiers_to_sql( let dim_lower = dim.to_lowercase(); if !has_all_global && !removed_dims.contains(&dim_lower) { let outer_ref = outer_alias.unwrap_or(table_name); - let qualified_expr = qualify_outer_reference(expr, outer_ref, dim); + let dim_name = dim.split('.').next_back().unwrap_or(dim).trim(); + let resolved_expr = resolve_current_in_expr( + expr, + group_by_cols, + outer_where, + Some(outer_ref), + ); + let qualified_expr = if dim.contains('(') { + qualify_where_for_outer(&resolved_expr, outer_ref) + } else { + qualify_outer_reference(&resolved_expr, outer_ref, dim_name) + }; // For ad hoc dimensions (expressions), qualify column refs inside let inner_dim = if dim.contains('(') { qualify_where_for_inner(dim) } else { - format!("_inner.{dim}") + format!("_inner.{dim_name}") }; set_conditions.push(format!("{inner_dim} IS NOT DISTINCT FROM {qualified_expr}")); } @@ -4087,7 +4912,8 @@ fn expand_modifiers_to_sql_derived( // Per paper: SET bypasses outer WHERE, so VISIBLE has no effect when SET is present if !has_set && !has_all_global { if let Some(w) = outer_where { - effective_where = Some(qualify_where_for_inner(w)); + let stripped = strip_at_where_qualifiers(w); + effective_where = Some(qualify_where_for_inner(&stripped)); } } } @@ -4166,7 +4992,7 @@ fn validate_set_expression_requirements( 
continue; } let dim_name = dim.split('.').next_back().unwrap_or(dim).trim(); - if expr_mentions_identifier(expr, dim_name) + if expr_mentions_identifier_outside_current(expr, dim_name) && !dimension_in_group_by(dim, group_by_cols, default_qualifier) { return Some(format!( @@ -4183,9 +5009,19 @@ fn validate_set_expression_requirements( /// Expand AGGREGATE() with AT modifiers in SQL pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { let cte_expansion = expand_cte_queries(sql); - let sql = cte_expansion.sql; + let mut sql = cte_expansion.sql; let mut had_aggregate = cte_expansion.had_aggregate; + if has_measure_at_refs(&sql) { + sql = rewrite_measure_at_refs(&sql); + had_aggregate = true; + } + + if has_implicit_measure_refs(&sql) { + sql = rewrite_implicit_measure_refs(&sql); + had_aggregate = true; + } + // Check if we need the full expansion path (AT modifiers or non-decomposable measures) let has_aggregate = has_aggregate_function(&sql); @@ -4202,25 +5038,30 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { let at_patterns = extract_aggregate_with_at_full(&sql); // Keep full expansion path even without AT to handle non-decomposable measures safely - // Extract table info using string-based approach (works with AGGREGATE syntax) - // Note: DuckDB's parser can't parse AGGREGATE() since it's our custom syntax - let (primary_table_name, existing_alias) = - extract_table_and_alias_from_sql(&sql).unwrap_or_else(|| ("t".to_string(), None)); - - // Build FromClauseInfo from string-based extraction for now - // TODO: For proper JOIN support, we'd need to extract all tables from the FROM clause - let mut from_info = FromClauseInfo::default(); - let primary_table = TableInfo { - name: primary_table_name.clone(), - effective_name: existing_alias - .clone() - .unwrap_or_else(|| primary_table_name.clone()), - has_alias: existing_alias.is_some(), + // Prefer parser-FFI FROM extraction (supports JOIN aliases when SQL parses there), + // 
then fall back to string extraction for AGGREGATE/AT syntax that parser-FFI cannot parse. + let mut from_info = extract_from_clause_info_ffi(&sql); + let (primary_table_name, existing_alias) = if let Some(pt) = from_info.primary_table.clone() { + let alias = if pt.has_alias { + Some(pt.effective_name.clone()) + } else { + None + }; + (pt.name, alias) + } else { + let (table_name, alias) = + extract_table_and_alias_from_sql(&sql).unwrap_or_else(|| ("t".to_string(), None)); + let primary_table = TableInfo { + name: table_name.clone(), + effective_name: alias.clone().unwrap_or_else(|| table_name.clone()), + has_alias: alias.is_some(), + }; + from_info + .tables + .insert(primary_table.effective_name.clone(), primary_table.clone()); + from_info.primary_table = Some(primary_table); + (table_name, alias) }; - from_info - .tables - .insert(primary_table.effective_name.clone(), primary_table.clone()); - from_info.primary_table = Some(primary_table); // Extract outer WHERE clause for VISIBLE semantics let outer_where = extract_where_clause(&sql); @@ -4251,9 +5092,17 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { // Extract dimension columns from original SQL for implicit GROUP BY // (must be done before expansion since expanded SQL has SUM() etc) let original_dim_cols = extract_dimension_columns_from_select(&sql); + let effective_group_by_cols = if group_by_cols.is_empty() { + original_dim_cols.clone() + } else { + group_by_cols.clone() + }; + + let has_expression_dimensions = original_dim_cols.iter().any(|col| col.contains('(')); - // Check if any AT modifier needs correlation (for alias handling) - let needs_outer_alias = at_patterns.iter().any(|(_, modifiers, _, _)| { + // Check if query needs an explicit outer alias for correlation handling. 
+ let needs_outer_alias = has_expression_dimensions + || at_patterns.iter().any(|(_, modifiers, _, _)| { modifiers.iter().any(|m| { matches!(m, ContextModifier::Set(_, _)) || matches!(m, ContextModifier::All(_)) @@ -4282,20 +5131,52 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { None }; - let mut join_clauses: Vec = Vec::new(); - let mut join_counter = 0; - let mut patterns = at_patterns; patterns.sort_by(|a, b| b.2.cmp(&a.2)); for (measure_name, modifiers, start, end) in patterns { + let measure_lookup_name = strip_measure_qualifier(&measure_name); // Look up which view contains this measure (for JOIN support) - let resolved = resolve_measure_source(&measure_name, &primary_table_name); - let measure_group_by_cols = filter_group_by_cols_for_measure( - &group_by_cols, + let resolved = resolve_measure_source(&measure_lookup_name, &primary_table_name); + let mut measure_group_by_cols = filter_group_by_cols_for_measure( + &effective_group_by_cols, &resolved.view_group_by_cols, &resolved.dimension_exprs, ); + let mut allowed_qualifiers: HashSet = HashSet::new(); + allowed_qualifiers.insert(normalize_identifier_name(&resolved.source_view)); + if let Some(alias) = find_alias_for_view(&from_info, &resolved.source_view) { + allowed_qualifiers.insert(normalize_identifier_name(alias)); + } + if let Some(ref pt) = from_info.primary_table { + if pt.name.eq_ignore_ascii_case(&resolved.source_view) { + if let Some(alias) = primary_alias.as_deref() { + allowed_qualifiers.insert(normalize_identifier_name(alias)); + } + } + } + let source_dims = source_dimension_names(&resolved.source_view); + measure_group_by_cols.retain(|col| { + if let Some((Some(qualifier), dim_name)) = parse_simple_measure_ref(col) { + return allowed_qualifiers.contains(&qualifier) || source_dims.contains(&dim_name); + } + let dim_key = normalize_dimension_key(col); + source_dims.is_empty() + || source_dims.contains(&dim_key) + || resolved.dimension_exprs.contains_key(&dim_key) + || 
source_dims + .iter() + .any(|dim_name| expr_mentions_identifier(col, dim_name)) + }); + let eval_group_by_cols = + if measure_group_by_cols.is_empty() + && !original_dim_cols.is_empty() + && resolved.source_view.eq_ignore_ascii_case(&primary_table_name) + { + original_dim_cols.clone() + } else { + measure_group_by_cols.clone() + }; // Non-decomposable measures are recomputed from base rows (including AT modifiers) @@ -4316,88 +5197,47 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { }; let outer_alias_ref = outer_alias.as_deref(); - // For derived measures, use the expanded expression instead of measure_name - let (effective_measure, effective_agg) = if let Some(ref expr) = resolved.derived_expr { - // Derived measure: use expanded expression directly (no wrapping AGG) - (expr.as_str(), "".to_string()) - } else { - (measure_name.as_str(), resolved.agg_fn.clone()) - }; - - let expanded = if resolved.derived_expr.is_some() { - // For derived measures, build the subquery with the expanded expression - expand_modifiers_to_sql_derived( - effective_measure, - &modifiers, - &resolved.source_view, - outer_alias_ref, - outer_where_ref, - &measure_group_by_cols, - ) - } else if !resolved.is_decomposable { - let outer_ref_for_non_decomp = - outer_alias_ref.or(Some(resolved.source_view.as_str())); - let base_relation_sql = resolved - .base_relation_sql - .clone() - .or_else(|| { - resolved - .base_table - .clone() - .map(|table| format!("SELECT * FROM {table}")) - }) - .unwrap_or_else(|| format!("SELECT * FROM {}", resolved.source_view)); + let outer_ref_for_eval = outer_alias_ref.or(Some(resolved.source_view.as_str())); + let base_relation_sql = resolved + .base_relation_sql + .clone() + .or_else(|| { + resolved + .base_table + .clone() + .map(|table| format!("SELECT * FROM {table}")) + }) + .unwrap_or_else(|| format!("SELECT * FROM {}", resolved.source_view)); - let (expanded, already_aggregated) = - if modifiers.is_empty() && 
can_use_view_measure_directly(&resolved, &group_by_cols) - { - let measure_ref = outer_ref_for_non_decomp - .map(|alias| format!("{alias}.{measure_name}")) - .unwrap_or_else(|| measure_name.to_string()); - (format!("MAX({measure_ref})"), true) - } else if let Some(plan) = build_non_decomposable_join_plan( - &resolved.expression, - &base_relation_sql, - outer_ref_for_non_decomp, - outer_where_ref, - &measure_group_by_cols, - &modifiers, - &resolved.dimension_exprs, - &format!("_nd_{join_counter}"), - ) { - join_clauses.push(plan.join_sql); - join_counter += 1; - (plan.replacement, false) - } else { - // Non-decomposable measure: expand against base table with correlation - ( - expand_non_decomposable_to_sql( - &resolved.expression, - &base_relation_sql, - outer_ref_for_non_decomp, - outer_where_ref, - &measure_group_by_cols, - &modifiers, - &resolved.dimension_exprs, - ), - false, - ) - }; + let expression_for_eval = resolved + .derived_expr + .clone() + .unwrap_or_else(|| resolved.expression.clone()); - if original_dim_cols.is_empty() && !already_aggregated { - format!("MAX({expanded})") + let expanded = if !expression_for_eval.is_empty() { + let eval_sql = expand_non_decomposable_to_sql( + &expression_for_eval, + &base_relation_sql, + outer_ref_for_eval, + outer_where_ref, + &eval_group_by_cols, + &modifiers, + &resolved.dimension_exprs, + ); + if original_dim_cols.is_empty() { + format!("MAX({eval_sql})") } else { - expanded + eval_sql } } else { expand_modifiers_to_sql( - &measure_name, - &effective_agg, + &measure_lookup_name, + &resolved.agg_fn, &modifiers, &resolved.source_view, outer_alias_ref, outer_where_ref, - &measure_group_by_cols, + &eval_group_by_cols, ) }; result_sql = format!("{}{}{}", &result_sql[..start], expanded, &result_sql[end..]); @@ -4408,86 +5248,121 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { plain_calls.sort_by(|a, b| b.1.cmp(&a.1)); // Sort by position descending for (measure_name, start, end) in plain_calls 
{ - let resolved = resolve_measure_source(&measure_name, &primary_table_name); - let measure_group_by_cols = filter_group_by_cols_for_measure( - &group_by_cols, + let mut replacement_end = end; + let mut use_default_context = false; + let suffix = &result_sql[end..]; + let leading_ws = suffix.len() - suffix.trim_start().len(); + if suffix[leading_ws..].starts_with(DEFAULT_CONTEXT_MARKER) { + use_default_context = true; + replacement_end = end + leading_ws + DEFAULT_CONTEXT_MARKER.len(); + } + + let measure_lookup_name = strip_measure_qualifier(&measure_name); + let resolved = resolve_measure_source(&measure_lookup_name, &primary_table_name); + let mut measure_group_by_cols = filter_group_by_cols_for_measure( + &effective_group_by_cols, &resolved.view_group_by_cols, &resolved.dimension_exprs, ); + let mut allowed_qualifiers: HashSet = HashSet::new(); + allowed_qualifiers.insert(normalize_identifier_name(&resolved.source_view)); + if let Some(alias) = find_alias_for_view(&from_info, &resolved.source_view) { + allowed_qualifiers.insert(normalize_identifier_name(alias)); + } + if let Some(ref pt) = from_info.primary_table { + if pt.name.eq_ignore_ascii_case(&resolved.source_view) { + if let Some(alias) = primary_alias.as_deref() { + allowed_qualifiers.insert(normalize_identifier_name(alias)); + } + } + } + let source_dims = source_dimension_names(&resolved.source_view); + measure_group_by_cols.retain(|col| { + if let Some((Some(qualifier), dim_name)) = parse_simple_measure_ref(col) { + return allowed_qualifiers.contains(&qualifier) || source_dims.contains(&dim_name); + } + let dim_key = normalize_dimension_key(col); + source_dims.is_empty() + || source_dims.contains(&dim_key) + || resolved.dimension_exprs.contains_key(&dim_key) + || source_dims + .iter() + .any(|dim_name| expr_mentions_identifier(col, dim_name)) + }); + let eval_group_by_cols = + if measure_group_by_cols.is_empty() + && !original_dim_cols.is_empty() + && 
resolved.source_view.eq_ignore_ascii_case(&primary_table_name) + { + original_dim_cols.clone() + } else { + measure_group_by_cols.clone() + }; - // For derived measures, use the expanded expression; otherwise use AGG(measure_name) - let expanded = if let Some(expr) = resolved.derived_expr { - expr - } else if !resolved.is_decomposable { - let outer_ref_for_non_decomp = - primary_alias.as_deref().or(Some(resolved.source_view.as_str())); - let base_relation_sql = resolved - .base_relation_sql - .clone() - .or_else(|| { - resolved - .base_table - .clone() - .map(|table| format!("SELECT * FROM {table}")) + let outer_alias = if let Some(ref pt) = from_info.primary_table { + if pt.name.eq_ignore_ascii_case(&resolved.source_view) { + primary_alias.clone().or_else(|| { + find_alias_for_view(&from_info, &resolved.source_view).map(|s| s.to_string()) }) - .unwrap_or_else(|| format!("SELECT * FROM {}", resolved.source_view)); - let (expanded, already_aggregated) = - if can_use_view_measure_directly(&resolved, &group_by_cols) { - let measure_ref = outer_ref_for_non_decomp - .map(|alias| format!("{alias}.{measure_name}")) - .unwrap_or_else(|| measure_name.to_string()); - (format!("MAX({measure_ref})"), true) - } else if let Some(plan) = build_non_decomposable_join_plan( - &resolved.expression, + } else { + find_alias_for_view(&from_info, &resolved.source_view) + .map(|s| s.to_string()) + .or_else(|| primary_alias.clone()) + } + } else { + primary_alias.clone() + }; + let outer_ref_for_eval = outer_alias.as_deref().or(Some(resolved.source_view.as_str())); + let base_relation_sql = resolved + .base_relation_sql + .clone() + .or_else(|| { + resolved + .base_table + .clone() + .map(|table| format!("SELECT * FROM {table}")) + }) + .unwrap_or_else(|| format!("SELECT * FROM {}", resolved.source_view)); + let expression_for_eval = resolved + .derived_expr + .clone() + .unwrap_or_else(|| resolved.expression.clone()); + + let expanded = if !expression_for_eval.is_empty() { + let eval_sql 
= if use_default_context { + expand_non_decomposable_default_context( + &expression_for_eval, + &base_relation_sql, + outer_ref_for_eval, + &eval_group_by_cols, + &resolved.dimension_exprs, + ) + } else { + expand_non_decomposable_to_sql( + &expression_for_eval, &base_relation_sql, - outer_ref_for_non_decomp, + outer_ref_for_eval, outer_where_ref, - &measure_group_by_cols, - &[], // No modifiers for plain AGGREGATE() + &eval_group_by_cols, + &[], // No modifiers for explicit AGGREGATE() &resolved.dimension_exprs, - &format!("_nd_{join_counter}"), - ) { - join_clauses.push(plan.join_sql); - join_counter += 1; - (plan.replacement, false) - } else { - // Non-decomposable measure: expand against base table with correlation - ( - expand_non_decomposable_to_sql( - &resolved.expression, - &base_relation_sql, - outer_ref_for_non_decomp, - outer_where_ref, - &measure_group_by_cols, - &[], // No modifiers for plain AGGREGATE() - &resolved.dimension_exprs, - ), - false, - ) - }; - - if original_dim_cols.is_empty() && !already_aggregated { - format!("MAX({expanded})") + ) + }; + if original_dim_cols.is_empty() { + format!("MAX({eval_sql})") } else { - expanded + eval_sql } } else { - format!("{}({measure_name})", resolved.agg_fn) + format!("{}({measure_lookup_name})", resolved.agg_fn) }; - result_sql = format!("{}{}{}", &result_sql[..start], expanded, &result_sql[end..]); - } - - if !join_clauses.is_empty() { - if let Some(insert_pos) = find_from_clause_end(&result_sql) { - let joins = join_clauses.join(""); - result_sql = format!( - "{}{}{}", - &result_sql[..insert_pos], - joins, - &result_sql[insert_pos..] - ); - } + result_sql = format!( + "{}{}{}", + &result_sql[..start], + expanded, + &result_sql[replacement_end..] 
+ ); } if insert_outer_alias { @@ -4606,6 +5481,26 @@ pub fn get_measure_aggregation(column_name: &str) -> Option<(String, String)> { } fn extract_group_by_columns(sql: &str) -> Vec { + if let Ok(info) = parser_ffi::parse_select(sql) { + if info.has_group_by { + if info.group_by_all { + return extract_dimension_columns_from_select(sql); + } + if !info.group_by_cols.is_empty() { + let parser_group_by_cols: Vec = info + .group_by_cols + .into_iter() + .map(|c| c.trim().to_string()) + .filter(|c| !c.is_empty() && !c.chars().all(|ch| ch.is_ascii_digit())) + .collect(); + + if !parser_group_by_cols.is_empty() { + return parser_group_by_cols; + } + } + } + } + let mut columns = Vec::new(); let query = sql.trim().trim_end_matches(';').trim(); @@ -4695,6 +5590,55 @@ fn is_literal_constant(expr: &str) -> bool { } /// Extract non-AGGREGATE columns from SELECT clause to use as implicit GROUP BY columns +fn looks_like_sql_aggregate_expr(expr: &str) -> bool { + // Prefer parser-backed detection when parser FFI is available. + let parsed = std::panic::catch_unwind(|| parser_ffi::parse_expression(expr)) + .ok() + .and_then(|result| result.ok()); + if let Some(info) = parsed { + if info.is_aggregate { + return true; + } + } + + // Fallback heuristic for environments where parser FFI is unavailable. 
+ let upper = expr.to_uppercase(); + [ + "COUNT(", + "COUNT_STAR(", + "SUM(", + "AVG(", + "MIN(", + "MAX(", + "ANY_VALUE(", + "STRING_AGG(", + "ARRAY_AGG(", + "LIST(", + "FIRST(", + "LAST(", + "MEDIAN(", + "MODE(", + "STDDEV(", + "STDDEV_POP(", + "STDDEV_SAMP(", + "VAR_POP(", + "VAR_SAMP(", + "VARIANCE(", + "QUANTILE(", + "QUANTILE_CONT(", + "QUANTILE_DISC(", + "PERCENTILE_CONT(", + "PERCENTILE_DISC(", + "BOOL_AND(", + "BOOL_OR(", + "BIT_AND(", + "BIT_OR(", + "BIT_XOR(", + ] + .iter() + .any(|pattern| upper.contains(pattern)) +} + fn extract_dimension_columns_from_select(sql: &str) -> Vec { let mut columns = Vec::new(); @@ -4753,6 +5697,9 @@ fn extract_dimension_columns_from_select(sql: &str) -> Vec { } else { item.trim() }; + if looks_like_sql_aggregate_expr(col) { + continue; + } if !col.is_empty() && !is_literal_constant(col) { columns.push(col.to_string()); } @@ -4813,6 +5760,14 @@ mod tests { assert!(cols.is_empty()); } + #[test] + fn test_extract_dimension_columns_excludes_standard_aggregates() { + let cols = extract_dimension_columns_from_select( + "SELECT region, COUNT(*) AS c, SUM(amount) AS s, ANY_VALUE(flag) AS f FROM sales GROUP BY ROLLUP(region)", + ); + assert_eq!(cols, vec!["region".to_string()]); + } + #[test] fn test_extract_dimension_columns_keeps_non_aggregate_suffix() { let cols = extract_dimension_columns_from_select( @@ -5211,6 +6166,14 @@ FROM orders"#; let sql2 = "SELECT status, region FROM t GROUP BY status, region ORDER BY status"; let cols2 = extract_group_by_columns(sql2); assert_eq!(cols2, vec!["status".to_string(), "region".to_string()]); + + let sql3 = "SELECT status, AGGREGATE(revenue) FROM orders_summary GROUP BY 1"; + let cols3 = extract_group_by_columns(sql3); + assert_eq!(cols3, vec!["status".to_string()]); + + let sql4 = "SELECT status, region, AGGREGATE(revenue) FROM orders_summary GROUP BY 1, 2"; + let cols4 = extract_group_by_columns(sql4); + assert_eq!(cols4, vec!["status".to_string(), "region".to_string()]); } #[test] @@ 
-5221,6 +6184,72 @@ FROM orders"#; ); } + #[test] + #[ignore = "requires C++ parser FFI"] + #[serial] + fn test_rewrite_measure_at_refs() { + clear_measure_views(); + store_measure_view( + "sales_v", + vec![ViewMeasure { + column_name: "revenue".to_string(), + expression: "SUM(amount)".to_string(), + is_decomposable: true, + }], + "SELECT year, SUM(amount) AS revenue FROM sales GROUP BY year", + Some("sales".to_string()), + ); + + let sql = "SELECT revenue AT (VISIBLE) FROM sales_v"; + let rewritten = rewrite_measure_at_refs(sql); + assert_eq!(rewritten, "SELECT AGGREGATE(revenue) AT (VISIBLE) FROM sales_v"); + assert!(has_measure_at_refs(sql)); + } + + #[test] + #[ignore = "requires C++ parser FFI"] + #[serial] + fn test_rewrite_implicit_measure_refs_uses_default_context() { + clear_measure_views(); + store_measure_view( + "sales_v", + vec![ViewMeasure { + column_name: "revenue".to_string(), + expression: "SUM(amount)".to_string(), + is_decomposable: true, + }], + "SELECT year, SUM(amount) AS revenue FROM sales GROUP BY year", + Some("sales".to_string()), + ); + + let sql = "SELECT year, revenue FROM sales_v GROUP BY year"; + let rewritten = rewrite_implicit_measure_refs(sql); + assert!(rewritten.contains("AGGREGATE(revenue) /*YARDSTICK_DEFAULT*/")); + } + + #[test] + fn test_rewrite_implicit_measure_refs_fallback_with_at_syntax() { + let known_measures = HashSet::from([normalize_identifier_name("sumRevenue")]); + let sql = "SELECT o.prodName, COUNT(*) AS c, o.sumRevenue AT (VISIBLE) AS rViz, o.sumRevenue AS r FROM paper_orders_v o GROUP BY ROLLUP(o.prodName)"; + let rewritten = rewrite_implicit_measure_refs_fallback(sql, &known_measures); + + assert!(rewritten.contains("o.sumRevenue AT (VISIBLE)")); + assert!(rewritten.contains("AGGREGATE(o.sumRevenue) /*YARDSTICK_DEFAULT*/ AS r")); + } + + #[test] + fn test_has_implicit_measure_refs_fallback_ignores_count() { + let known_measures = HashSet::from([normalize_identifier_name("sumRevenue")]); + let sql = "SELECT 
o.prodName, COUNT(*) AS c, o.sumRevenue AT (VISIBLE) AS rViz, o.sumRevenue AS r FROM paper_orders_v o GROUP BY ROLLUP(o.prodName)"; + assert!(has_implicit_measure_refs_fallback(sql, &known_measures)); + + let sql_no_plain_measure = "SELECT o.prodName, COUNT(*) AS c, o.sumRevenue AT (VISIBLE) AS rViz FROM paper_orders_v o GROUP BY ROLLUP(o.prodName)"; + assert!(!has_implicit_measure_refs_fallback( + sql_no_plain_measure, + &known_measures + )); + } + #[test] fn test_extract_view_name() { assert_eq!( @@ -5355,30 +6384,57 @@ FROM orders"#; #[test] fn test_parse_current_in_expr() { - // CURRENT year - 1 should become year - 1 - assert_eq!(parse_current_in_expr("CURRENT year - 1"), "year - 1"); - - // Multiple CURRENT references + assert_eq!( + parse_current_in_expr("CURRENT year - 1"), + "CURRENT year - 1" + ); assert_eq!( parse_current_in_expr("CURRENT year + CURRENT month"), - "year + month" + "CURRENT year + CURRENT month" ); - - // Mixed case - assert_eq!(parse_current_in_expr("current YEAR - 1"), "YEAR - 1"); - - // No CURRENT assert_eq!(parse_current_in_expr("year - 1"), "year - 1"); } #[test] fn test_parse_at_modifier_with_current() { - // CURRENT should be stripped from SET expression let result = parse_at_modifier("SET year = CURRENT year - 1").unwrap(); assert_eq!( result, - ContextModifier::Set("year".to_string(), "year - 1".to_string()) + ContextModifier::Set("year".to_string(), "CURRENT year - 1".to_string()) + ); + } + + #[test] + fn test_resolve_current_in_expr_uses_null_when_unconstrained() { + let resolved = resolve_current_in_expr( + "CURRENT year - 1", + &[String::from("region")], + None, + None, + ); + assert_eq!(resolved, "NULL - 1"); + } + + #[test] + fn test_resolve_current_in_expr_uses_dimension_when_grouped() { + let resolved = resolve_current_in_expr( + "CURRENT year - 1", + &[String::from("year")], + None, + None, + ); + assert_eq!(resolved, "year - 1"); + } + + #[test] + fn test_resolve_current_in_expr_uses_dimension_when_where_constrained() 
{ + let resolved = resolve_current_in_expr( + "CURRENT year - 1", + &[], + Some("year = 2024"), + None, ); + assert_eq!(resolved, "year - 1"); } #[test] diff --git a/yardstick-rs/src/sql/mod.rs b/yardstick-rs/src/sql/mod.rs index acc24ad..f559efd 100644 --- a/yardstick-rs/src/sql/mod.rs +++ b/yardstick-rs/src/sql/mod.rs @@ -15,6 +15,8 @@ pub use measures::{ has_as_measure, has_at_syntax, has_curly_brace_measure, + has_implicit_measure_refs, + has_measure_at_refs, process_create_view, // Core types AggregateExpandResult,