diff --git a/yardstick-rs/src/sql/measures.rs b/yardstick-rs/src/sql/measures.rs index 1029cdb..231c084 100644 --- a/yardstick-rs/src/sql/measures.rs +++ b/yardstick-rs/src/sql/measures.rs @@ -1752,6 +1752,81 @@ pub fn qualify_outer_reference(expr: &str, table_name: &str, dim: &str) -> Strin result } +fn expr_mentions_identifier(expr: &str, ident: &str) -> bool { + let mut chars = expr.chars().peekable(); + + while let Some(c) = chars.next() { + if c == '\'' { + // Skip SQL single-quoted string literal content, including escaped ''. + while let Some(next) = chars.next() { + if next == '\'' { + if chars.peek() == Some(&'\'') { + chars.next(); + } else { + break; + } + } + } + continue; + } + + if c.is_alphabetic() || c == '_' { + let mut token = String::from(c); + while let Some(&next) = chars.peek() { + if next.is_alphanumeric() || next == '_' { + token.push(chars.next().unwrap()); + } else { + break; + } + } + + if token.eq_ignore_ascii_case(ident) { + return true; + } + } + } + + false +} + +fn dimension_in_group_by( + dim: &str, + group_by_cols: &[String], + default_qualifier: Option<&str>, +) -> bool { + let dim_trim = dim.trim(); + let dim_name = dim_trim.split('.').next_back().unwrap_or(dim_trim).trim(); + let explicit_dim_qualifier = dim_trim + .rsplit_once('.') + .map(|(qualifier, _)| qualifier.trim()); + let expected_qualifier = explicit_dim_qualifier.or(default_qualifier); + let dim_lower = dim_trim.to_lowercase(); + + if dim_trim.contains('(') { + return group_by_cols + .iter() + .any(|col| col.to_lowercase() == dim_lower); + } + + group_by_cols.iter().any(|col| { + let col_trim = col.trim(); + let col_name = col_trim.split('.').next_back().unwrap_or(col_trim).trim(); + + if !col_name.eq_ignore_ascii_case(dim_name) { + return false; + } + + // When we know the expected outer qualifier, GROUP BY qualifiers must match it. + // Unqualified GROUP BY columns still count. + match (expected_qualifier, col_trim.rsplit_once('.')) { + (Some(expected), Some((col_qualifier, _))) => { + col_qualifier.trim().eq_ignore_ascii_case(expected) + } + _ => true, + } + }) +} + /// Qualify column references in a WHERE clause for use inside _inner subquery /// "region = 'US'" -> "_inner.region = 'US'" /// "year > 2020 AND region = 'US'" -> "_inner.year > 2020 AND _inner.region = 'US'" @@ -3862,6 +3937,32 @@ fn expand_modifiers_to_sql_derived( } } +fn validate_set_expression_requirements( + at_patterns: &[(String, Vec, usize, usize)], + group_by_cols: &[String], + default_qualifier: Option<&str>, +) -> Option { + for (_, modifiers, _, _) in at_patterns { + for modifier in modifiers { + if let ContextModifier::Set(dim, expr) = modifier { + if dim.contains('(') { + continue; + } + let dim_name = dim.split('.').next_back().unwrap_or(dim).trim(); + if expr_mentions_identifier(expr, dim_name) + && !dimension_in_group_by(dim, group_by_cols, default_qualifier) + { + return Some(format!( + "AT (SET {dim} = {expr}) references {dim_name}, but the query does not group by {dim_name}. Add {dim_name} to SELECT/GROUP BY or use a constant SET value." + )); + } + } + } + } + + None +} + /// Expand AGGREGATE() with AT modifiers in SQL pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { let cte_expansion = expand_cte_queries(sql); @@ -3911,6 +4012,25 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { // Extract GROUP BY columns for AT (ALL dim) correlation let group_by_cols = extract_group_by_columns(&sql); + let default_set_qualifier = if let Some(alias) = existing_alias.as_deref() { + Some(alias) + } else if primary_table_name.is_empty() { + None + } else { + Some(primary_table_name.as_str()) + }; + if let Some(error) = validate_set_expression_requirements( + &at_patterns, + &group_by_cols, + default_set_qualifier, + ) { + return AggregateExpandResult { + had_aggregate: true, + expanded_sql: sql.to_string(), + error: Some(error), + }; + } + // Extract dimension columns from original SQL for implicit GROUP BY // (must be done before expansion since expanded SQL has SUM() etc) let original_dim_cols = extract_dimension_columns_from_select(&sql); @@ -4780,6 +4900,59 @@ FROM orders"#; ); } + #[test] + fn test_set_expression_requires_grouped_dimension() { + let sql = + "SELECT region, AGGREGATE(revenue) AT (SET year = year - 1) FROM sales_v"; + let result = expand_aggregate_with_at(sql); + assert!(result.error.is_some()); + let message = result.error.unwrap(); + assert!(message.contains("SET year = year - 1")); + assert!(message.contains("GROUP BY")); + } + + #[test] + fn test_set_expression_allows_grouped_dimension() { + let sql = + "SELECT year, region, AGGREGATE(revenue) AT (SET year = year - 1) FROM sales_v"; + let result = expand_aggregate_with_at(sql); + assert!(result.error.is_none()); + } + + #[test] + fn test_set_constant_allows_ungrouped_dimension() { + let sql = + "SELECT region, AGGREGATE(revenue) AT (SET year = 2023) FROM sales_v"; + let result = expand_aggregate_with_at(sql); + assert!(result.error.is_none()); + } + + #[test] + fn test_set_string_literal_allows_ungrouped_dimension() { + let sql = + "SELECT year, AGGREGATE(revenue) AT (SET region = 'region') FROM sales_v"; + let result = expand_aggregate_with_at(sql); + assert!(result.error.is_none()); + } + + #[test] + fn test_set_expression_requires_matching_grouping_alias() { + let sql = "SELECT c.year, AGGREGATE(revenue) AT (SET year = year - 1) \ + FROM orders_v o JOIN calendar_v c ON o.year = c.year \ + GROUP BY c.year"; + let result = expand_aggregate_with_at(sql); + assert!(result.error.is_some()); + let message = result.error.unwrap(); + assert!(message.contains("SET year = year - 1")); + } + + #[test] + fn test_set_expression_allows_matching_grouping_alias() { + let sql = "SELECT o.year, AGGREGATE(revenue) AT (SET year = year - 1) FROM sales_v o GROUP BY o.year"; + let result = expand_aggregate_with_at(sql); + assert!(result.error.is_none()); + } + #[test] fn test_extract_agg_function() { assert_eq!(extract_agg_function("SUM(amount)"), "SUM");