Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 173 additions & 0 deletions yardstick-rs/src/sql/measures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1752,6 +1752,81 @@ pub fn qualify_outer_reference(expr: &str, table_name: &str, dim: &str) -> Strin
result
}

/// Returns `true` when `expr` contains `ident` as a standalone identifier token.
///
/// The scan is a lightweight tokenizer, not a SQL parser:
///
/// * Content of single-quoted string literals (including the `''` escape) is
///   skipped, so a constant assignment such as `AT (SET region = 'region')` is
///   not reported as a reference to `region`.
/// * Numeric literals are consumed as a whole, so the exponent marker in `1e3`
///   or the digits of `0x1F` are never mistaken for identifiers.
/// * An identifier is a run of alphanumeric characters or `_` that starts with
///   an alphabetic character or `_`. It is compared to `ident` ignoring ASCII
///   case, and must match the whole token (`fiscal_year` does not mention
///   `year`).
///
/// Qualified references still match on their last segment (`t.year` mentions
/// `year`) because `.` terminates a token.
fn expr_mentions_identifier(expr: &str, ident: &str) -> bool {
    // Byte offsets from `char_indices` let tokens be compared as slices of
    // `expr` instead of building a fresh `String` for each one.
    let mut chars = expr.char_indices().peekable();

    while let Some((start, c)) = chars.next() {
        if c == '\'' {
            // Skip SQL single-quoted string literal content. A doubled quote
            // ('') is an escaped quote and does not end the literal. An
            // unterminated literal simply runs to the end of the input.
            while let Some((_, quoted)) = chars.next() {
                if quoted == '\'' && chars.next_if(|&(_, n)| n == '\'').is_none() {
                    break;
                }
            }
            continue;
        }

        let starts_identifier = c.is_alphabetic() || c == '_';
        if !starts_identifier && !c.is_ascii_digit() {
            continue;
        }

        // Consume the rest of the word. Digit-led words (numeric literals) are
        // consumed too, so their trailing letters cannot start a new token.
        let mut end = start + c.len_utf8();
        while let Some((idx, ch)) = chars.next_if(|&(_, n)| n.is_alphanumeric() || n == '_') {
            end = idx + ch.len_utf8();
        }

        if starts_identifier && expr[start..end].eq_ignore_ascii_case(ident) {
            return true;
        }
    }

    false
}

/// Reports whether the dimension `dim` is covered by one of `group_by_cols`.
///
/// * `dim` - dimension as written in the modifier, optionally qualified
///   (`alias.name`) or a function-style expression containing `(`.
/// * `group_by_cols` - the GROUP BY column expressions to search.
/// * `default_qualifier` - qualifier to assume for `dim` when it carries none.
///
/// Function-style dimensions match only on a full, lowercase-insensitive text
/// comparison. Plain dimensions match on their last `.`-separated segment
/// (ASCII case-insensitive); when both an expected qualifier and a GROUP BY
/// qualifier are present they must agree as well.
fn dimension_in_group_by(
    dim: &str,
    group_by_cols: &[String],
    default_qualifier: Option<&str>,
) -> bool {
    let wanted = dim.trim();

    // Expressions such as `date_trunc('month', d)` are compared as whole text.
    if wanted.contains('(') {
        let wanted_lower = wanted.to_lowercase();
        for col in group_by_cols {
            if col.to_lowercase() == wanted_lower {
                return true;
            }
        }
        return false;
    }

    // Split `qualifier.name` on the last dot; an unqualified name has no qualifier.
    let (wanted_qualifier, wanted_name) = match wanted.rsplit_once('.') {
        Some((qualifier, name)) => (Some(qualifier.trim()), name.trim()),
        None => (None, wanted),
    };
    let required_qualifier = wanted_qualifier.or(default_qualifier);

    for col in group_by_cols {
        let candidate = col.trim();
        let (candidate_qualifier, candidate_name) = match candidate.rsplit_once('.') {
            Some((qualifier, name)) => (Some(qualifier.trim()), name.trim()),
            None => (None, candidate),
        };

        if !candidate_name.eq_ignore_ascii_case(wanted_name) {
            continue;
        }

        // A qualifier mismatch only disqualifies the column when both sides
        // actually have a qualifier; unqualified GROUP BY columns still count.
        let qualifier_ok = match (required_qualifier, candidate_qualifier) {
            (Some(required), Some(actual)) => actual.eq_ignore_ascii_case(required),
            _ => true,
        };
        if qualifier_ok {
            return true;
        }
    }

    false
}

/// Qualify column references in a WHERE clause for use inside _inner subquery
/// "region = 'US'" -> "_inner.region = 'US'"
/// "year > 2020 AND region = 'US'" -> "_inner.year > 2020 AND _inner.region = 'US'"
Expand Down Expand Up @@ -3862,6 +3937,32 @@ fn expand_modifiers_to_sql_derived(
}
}

/// Checks every `SET` modifier in `at_patterns` and returns an error message
/// for the first one whose expression mentions its own dimension while that
/// dimension is not present in `group_by_cols`.
///
/// `SET` modifiers whose dimension contains `(` are not checked. Returns
/// `None` when no modifier violates the rule.
fn validate_set_expression_requirements(
    at_patterns: &[(String, Vec<ContextModifier>, usize, usize)],
    group_by_cols: &[String],
    default_qualifier: Option<&str>,
) -> Option<String> {
    at_patterns
        .iter()
        .flat_map(|(_, modifiers, _, _)| modifiers.iter())
        .find_map(|modifier| {
            // Only SET modifiers are subject to this validation.
            let (dim, expr) = match modifier {
                ContextModifier::Set(dim, expr) => (dim, expr),
                _ => return None,
            };
            if dim.contains('(') {
                return None;
            }

            // Compare on the unqualified dimension name (last `.` segment).
            let dim_name = dim.split('.').next_back().unwrap_or(dim).trim();
            let self_referencing = expr_mentions_identifier(expr, dim_name);
            if !self_referencing || dimension_in_group_by(dim, group_by_cols, default_qualifier) {
                return None;
            }

            Some(format!(
                "AT (SET {dim} = {expr}) references {dim_name}, but the query does not group by {dim_name}. Add {dim_name} to SELECT/GROUP BY or use a constant SET value."
            ))
        })
}

/// Expand AGGREGATE() with AT modifiers in SQL
pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult {
let cte_expansion = expand_cte_queries(sql);
Expand Down Expand Up @@ -3911,6 +4012,25 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult {
// Extract GROUP BY columns for AT (ALL dim) correlation
let group_by_cols = extract_group_by_columns(&sql);

let default_set_qualifier = if let Some(alias) = existing_alias.as_deref() {
Some(alias)
} else if primary_table_name.is_empty() {
None
} else {
Some(primary_table_name.as_str())
};
if let Some(error) = validate_set_expression_requirements(
&at_patterns,
&group_by_cols,
default_set_qualifier,
) {
return AggregateExpandResult {
had_aggregate: true,
expanded_sql: sql.to_string(),
error: Some(error),
};
}

// Extract dimension columns from original SQL for implicit GROUP BY
// (must be done before expansion since expanded SQL has SUM() etc)
let original_dim_cols = extract_dimension_columns_from_select(&sql);
Expand Down Expand Up @@ -4780,6 +4900,59 @@ FROM orders"#;
);
}

#[test]
fn test_set_expression_requires_grouped_dimension() {
    // The SET expression reads `year`, but only `region` is selected.
    let outcome = expand_aggregate_with_at(
        "SELECT region, AGGREGATE(revenue) AT (SET year = year - 1) FROM sales_v",
    );
    assert!(outcome.error.is_some());

    let text = outcome.error.unwrap();
    for needle in ["SET year = year - 1", "GROUP BY"] {
        assert!(text.contains(needle));
    }
}

#[test]
fn test_set_expression_allows_grouped_dimension() {
    // `year` is selected alongside `region`, so referencing it in SET is accepted.
    let outcome = expand_aggregate_with_at(
        "SELECT year, region, AGGREGATE(revenue) AT (SET year = year - 1) FROM sales_v",
    );
    assert!(outcome.error.is_none());
}

#[test]
fn test_set_constant_allows_ungrouped_dimension() {
    // A constant SET value does not reference `year`, so no grouping is required.
    let outcome = expand_aggregate_with_at(
        "SELECT region, AGGREGATE(revenue) AT (SET year = 2023) FROM sales_v",
    );
    assert!(outcome.error.is_none());
}

#[test]
fn test_set_string_literal_allows_ungrouped_dimension() {
    // The word `region` appears only inside a string literal, not as a column reference.
    let outcome = expand_aggregate_with_at(
        "SELECT year, AGGREGATE(revenue) AT (SET region = 'region') FROM sales_v",
    );
    assert!(outcome.error.is_none());
}

#[test]
fn test_set_expression_requires_matching_grouping_alias() {
    // The query groups by `c.year` only; the test expects that qualifier not to
    // satisfy the unqualified `year` used by the SET modifier.
    let query = "SELECT c.year, AGGREGATE(revenue) AT (SET year = year - 1) \
                 FROM orders_v o JOIN calendar_v c ON o.year = c.year \
                 GROUP BY c.year";
    let outcome = expand_aggregate_with_at(query);
    assert!(outcome.error.is_some());

    let text = outcome.error.unwrap();
    assert!(text.contains("SET year = year - 1"));
}

#[test]
fn test_set_expression_allows_matching_grouping_alias() {
    // Grouping by `o.year`, where `o` is the alias of the only table, is accepted.
    let outcome = expand_aggregate_with_at(
        "SELECT o.year, AGGREGATE(revenue) AT (SET year = year - 1) FROM sales_v o GROUP BY o.year",
    );
    assert!(outcome.error.is_none());
}

#[test]
fn test_extract_agg_function() {
assert_eq!(extract_agg_function("SUM(amount)"), "SUM");
Expand Down
Loading