diff --git a/asap-planner-rs/src/output/generator.rs b/asap-planner-rs/src/output/generator.rs index 518319e..75f51fd 100644 --- a/asap-planner-rs/src/output/generator.rs +++ b/asap-planner-rs/src/output/generator.rs @@ -112,7 +112,7 @@ pub fn generate_plan( }) } -fn parse_cleanup_policy(s: &str) -> Result { +pub fn parse_cleanup_policy(s: &str) -> Result { match s { "circular_buffer" => Ok(CleanupPolicy::CircularBuffer), "read_based" => Ok(CleanupPolicy::ReadBased), @@ -124,7 +124,7 @@ fn parse_cleanup_policy(s: &str) -> Result { } } -fn key_by_labels_to_yaml(labels: &KeyByLabelNames) -> YamlValue { +pub fn key_by_labels_to_yaml(labels: &KeyByLabelNames) -> YamlValue { YamlValue::Sequence( labels .labels @@ -134,7 +134,134 @@ fn key_by_labels_to_yaml(labels: &KeyByLabelNames) -> YamlValue { ) } -fn params_to_yaml(params: &HashMap) -> YamlValue { +pub fn build_aggregation_entry(id: u32, cfg: &IntermediateAggConfig) -> YamlValue { + let mut map = serde_yaml::Mapping::new(); + map.insert( + YamlValue::String("aggregationId".to_string()), + YamlValue::Number(id.into()), + ); + map.insert( + YamlValue::String("aggregationSubType".to_string()), + YamlValue::String(cfg.aggregation_sub_type.clone()), + ); + map.insert( + YamlValue::String("aggregationType".to_string()), + YamlValue::String(cfg.aggregation_type.clone()), + ); + + let mut labels_map = serde_yaml::Mapping::new(); + labels_map.insert( + YamlValue::String("aggregated".to_string()), + key_by_labels_to_yaml(&cfg.aggregated_labels), + ); + labels_map.insert( + YamlValue::String("grouping".to_string()), + key_by_labels_to_yaml(&cfg.grouping_labels), + ); + labels_map.insert( + YamlValue::String("rollup".to_string()), + key_by_labels_to_yaml(&cfg.rollup_labels), + ); + map.insert( + YamlValue::String("labels".to_string()), + YamlValue::Mapping(labels_map), + ); + + map.insert( + YamlValue::String("metric".to_string()), + YamlValue::String(cfg.metric.clone()), + ); + map.insert( + YamlValue::String("parameters".to_string()), + params_to_yaml(&cfg.parameters), + ); + map.insert( + YamlValue::String("slideInterval".to_string()), + YamlValue::Number(cfg.slide_interval.into()), + ); + map.insert( + YamlValue::String("spatialFilter".to_string()), + YamlValue::String(cfg.spatial_filter.clone()), + ); + map.insert( + YamlValue::String("table_name".to_string()), + match &cfg.table_name { + Some(t) => YamlValue::String(t.clone()), + None => YamlValue::Null, + }, + ); + map.insert( + YamlValue::String("value_column".to_string()), + match &cfg.value_column { + Some(v) => YamlValue::String(v.clone()), + None => YamlValue::Null, + }, + ); + map.insert( + YamlValue::String("windowSize".to_string()), + YamlValue::Number(cfg.window_size.into()), + ); + map.insert( + YamlValue::String("windowType".to_string()), + YamlValue::String(cfg.window_type.clone()), + ); + + YamlValue::Mapping(map) +} + +pub fn build_queries_yaml( + cleanup_policy: CleanupPolicy, + query_keys_map: &IndexMap)>>, + id_map: &HashMap, +) -> Vec { + query_keys_map + .iter() + .map(|(query_str, keys)| { + let aggregations: Vec = keys + .iter() + .map(|(key, cleanup_param)| { + let agg_id = id_map[key]; + let mut agg_map = serde_yaml::Mapping::new(); + agg_map.insert( + YamlValue::String("aggregation_id".to_string()), + YamlValue::Number(agg_id.into()), + ); + if let Some(param) = cleanup_param { + match cleanup_policy { + CleanupPolicy::CircularBuffer => { + agg_map.insert( + YamlValue::String("num_aggregates_to_retain".to_string()), + YamlValue::Number((*param).into()), + ); + } + CleanupPolicy::ReadBased => { + agg_map.insert( + YamlValue::String("read_count_threshold".to_string()), + YamlValue::Number((*param).into()), + ); + } + CleanupPolicy::NoCleanup => {} + } + } + YamlValue::Mapping(agg_map) + }) + .collect(); + + let mut q_map = serde_yaml::Mapping::new(); + q_map.insert( + YamlValue::String("aggregations".to_string()), + YamlValue::Sequence(aggregations), + ); + q_map.insert( + YamlValue::String("query".to_string()), + YamlValue::String(query_str.clone()), + ); + YamlValue::Mapping(q_map) + }) + .collect() +} + +pub fn params_to_yaml(params: &HashMap) -> YamlValue { if params.is_empty() { return YamlValue::Mapping(serde_yaml::Mapping::new()); } @@ -169,82 +296,7 @@ fn build_streaming_yaml( ) -> Result { let aggregations: Vec = dedup_map .iter() - .map(|(key, cfg)| { - let id = id_map[key]; - let mut map = serde_yaml::Mapping::new(); - map.insert( - YamlValue::String("aggregationId".to_string()), - YamlValue::Number(id.into()), - ); - map.insert( - YamlValue::String("aggregationSubType".to_string()), - YamlValue::String(cfg.aggregation_sub_type.clone()), - ); - map.insert( - YamlValue::String("aggregationType".to_string()), - YamlValue::String(cfg.aggregation_type.clone()), - ); - - // labels - let mut labels_map = serde_yaml::Mapping::new(); - labels_map.insert( - YamlValue::String("aggregated".to_string()), - key_by_labels_to_yaml(&cfg.aggregated_labels), - ); - labels_map.insert( - YamlValue::String("grouping".to_string()), - key_by_labels_to_yaml(&cfg.grouping_labels), - ); - labels_map.insert( - YamlValue::String("rollup".to_string()), - key_by_labels_to_yaml(&cfg.rollup_labels), - ); - map.insert( - YamlValue::String("labels".to_string()), - YamlValue::Mapping(labels_map), - ); - - map.insert( - YamlValue::String("metric".to_string()), - YamlValue::String(cfg.metric.clone()), - ); - map.insert( - YamlValue::String("parameters".to_string()), - params_to_yaml(&cfg.parameters), - ); - map.insert( - YamlValue::String("slideInterval".to_string()), - YamlValue::Number(cfg.slide_interval.into()), - ); - map.insert( - YamlValue::String("spatialFilter".to_string()), - YamlValue::String(cfg.spatial_filter.clone()), - ); - map.insert( - YamlValue::String("table_name".to_string()), - match &cfg.table_name { - Some(t) => YamlValue::String(t.clone()), - None => YamlValue::Null, - }, - ); - map.insert( - YamlValue::String("value_column".to_string()), - match &cfg.value_column { - Some(v) => YamlValue::String(v.clone()), - None => YamlValue::Null, - }, - ); - map.insert( - YamlValue::String("windowSize".to_string()), - YamlValue::Number(cfg.window_size.into()), - ); - map.insert( - YamlValue::String("windowType".to_string()), - YamlValue::String(cfg.window_type.clone()), - ); - - YamlValue::Mapping(map) - }) + .map(|(key, cfg)| build_aggregation_entry(id_map[key], cfg)) .collect(); // Build metrics section @@ -282,51 +334,7 @@ fn build_inference_yaml( YamlValue::String(cleanup_policy_str.to_string()), ); - let queries: Vec = query_keys_map - .iter() - .map(|(query_str, keys)| { - let aggregations: Vec = keys - .iter() - .map(|(key, cleanup_param)| { - let agg_id = id_map[key]; - let mut agg_map = serde_yaml::Mapping::new(); - agg_map.insert( - YamlValue::String("aggregation_id".to_string()), - YamlValue::Number(agg_id.into()), - ); - if let Some(param) = cleanup_param { - match cleanup_policy { - CleanupPolicy::CircularBuffer => { - agg_map.insert( - YamlValue::String("num_aggregates_to_retain".to_string()), - YamlValue::Number((*param).into()), - ); - } - CleanupPolicy::ReadBased => { - agg_map.insert( - YamlValue::String("read_count_threshold".to_string()), - YamlValue::Number((*param).into()), - ); - } - CleanupPolicy::NoCleanup => {} - } - } - YamlValue::Mapping(agg_map) - }) - .collect(); - - let mut q_map = serde_yaml::Mapping::new(); - q_map.insert( - YamlValue::String("aggregations".to_string()), - YamlValue::Sequence(aggregations), - ); - q_map.insert( - YamlValue::String("query".to_string()), - YamlValue::String(query_str.clone()), - ); - YamlValue::Mapping(q_map) - }) - .collect(); + let queries = build_queries_yaml(cleanup_policy, query_keys_map, id_map); // Build metrics section let mut metrics_map = serde_yaml::Mapping::new(); diff --git a/asap-planner-rs/src/planner/logics.rs b/asap-planner-rs/src/planner/logics.rs index ca49f95..167a6de 100644 --- a/asap-planner-rs/src/planner/logics.rs +++ b/asap-planner-rs/src/planner/logics.rs @@ -78,10 +78,15 @@ pub struct IntermediateWindowConfig { pub window_type: String, } -pub fn get_precompute_operator_parameters( +/// Shared sketch parameter builder used by both PromQL and SQL paths. +/// +/// `topk_k` is only required for `CountMinSketchWithHeap`: PromQL supplies it +/// from the `topk(k, …)` query argument; SQL passes `None` (SQL never produces +/// this operator today, so the `None` branch is unreachable in practice). +pub fn build_sketch_parameters( aggregation_type: &str, aggregation_sub_type: &str, - match_result: &PromQLMatchResult, + topk_k: Option, sketch_params: Option<&SketchParameterOverrides>, ) -> Result, String> { match aggregation_type { @@ -110,16 +115,8 @@ pub fn get_precompute_operator_parameters( aggregation_sub_type )); } - // Get k from aggregation param - let k: u64 = match_result - .tokens - .get("aggregation") - .and_then(|t| t.aggregation.as_ref()) - .and_then(|a| a.param.as_ref()) - .and_then(|p| p.parse::().ok()) - .map(|f| f as u64) - .ok_or_else(|| "topk query missing required 'k' parameter".to_string())?; - + let k = topk_k + .ok_or_else(|| "CountMinSketchWithHeap requires a topk k value".to_string())?; let depth = sketch_params .and_then(|p| p.count_min_sketch_with_heap.as_ref()) .map(|p| p.depth) @@ -132,7 +129,6 @@ pub fn get_precompute_operator_parameters( .and_then(|p| p.count_min_sketch_with_heap.as_ref()) .and_then(|p| p.heap_multiplier) .unwrap_or(DEFAULT_CMS_HEAP_MULT); - let mut m = HashMap::new(); m.insert("depth".to_string(), serde_json::Value::Number(depth.into())); m.insert("width".to_string(), serde_json::Value::Number(width.into())); @@ -183,6 +179,35 @@ pub fn get_precompute_operator_parameters( } } +/// PromQL wrapper: extracts the topk `k` from the match result when needed, +/// then delegates to `build_sketch_parameters`. +pub fn build_sketch_parameters_from_promql( + aggregation_type: &str, + aggregation_sub_type: &str, + match_result: &PromQLMatchResult, + sketch_params: Option<&SketchParameterOverrides>, +) -> Result, String> { + let topk_k = if aggregation_type == "CountMinSketchWithHeap" { + let k: u64 = match_result + .tokens + .get("aggregation") + .and_then(|t| t.aggregation.as_ref()) + .and_then(|a| a.param.as_ref()) + .and_then(|p| p.parse::().ok()) + .map(|f| f as u64) + .ok_or_else(|| "topk query missing required 'k' parameter".to_string())?; + Some(k) + } else { + None + }; + build_sketch_parameters( + aggregation_type, + aggregation_sub_type, + topk_k, + sketch_params, + ) +} + pub fn get_cleanup_param( cleanup_policy: CleanupPolicy, query_pattern_type: QueryPatternType, @@ -266,6 +291,22 @@ pub fn set_subpopulation_labels( } } +/// SQL cleanup param — SQL queries are always instant (no range_duration/step). +pub fn get_sql_cleanup_param( + cleanup_policy: CleanupPolicy, + t_lookback: u64, + t_repeat: u64, +) -> Result { + match cleanup_policy { + CleanupPolicy::CircularBuffer | CleanupPolicy::ReadBased => { + Ok(t_lookback.div_ceil(t_repeat)) + } + CleanupPolicy::NoCleanup => { + Err("NoCleanup policy should not call get_sql_cleanup_param".to_string()) + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/asap-planner-rs/src/planner/single_query.rs b/asap-planner-rs/src/planner/single_query.rs index 1eee814..00e1d99 100644 --- a/asap-planner-rs/src/planner/single_query.rs +++ b/asap-planner-rs/src/planner/single_query.rs @@ -15,7 +15,7 @@ use std::collections::HashMap; use crate::config::input::SketchParameterOverrides; use crate::error::ControllerError; use crate::planner::logics::{ - get_cleanup_param, get_precompute_operator_parameters, set_subpopulation_labels, + build_sketch_parameters_from_promql, get_cleanup_param, set_subpopulation_labels, set_window_parameters, IntermediateWindowConfig, }; use crate::planner::patterns::build_patterns; @@ -224,13 +224,7 @@ impl SingleQueryProcessor { let statistics = get_statistics_to_compute(pattern_type, &match_result); - let mut configs: Vec = Vec::new(); - - // Shared window config (same for all statistics in this query) let mut window_cfg = IntermediateWindowConfig::default(); - - // We use the first aggregation_type to set window parameters - // (window params don't depend on aggregation_type since sliding is disabled) set_window_parameters( pattern_type, self.t_repeat, @@ -240,72 +234,29 @@ impl SingleQueryProcessor { &mut window_cfg, ); - for statistic in statistics { - let (aggregation_type, aggregation_sub_type) = - map_statistic_to_precompute_operator(statistic, treatment_type) - .map_err(ControllerError::PlannerError)?; - - // Compute labels - let (rollup_labels, grouping_labels, aggregated_labels) = compute_labels( - pattern_type, - statistic, - &aggregation_type, - &match_result, - &all_labels, - ); - - // Main config - let parameters = get_precompute_operator_parameters( - &aggregation_type, - &aggregation_sub_type, - &match_result, - self.sketch_parameters.as_ref(), - ) - .map_err(ControllerError::PlannerError)?; - - // DeltaSetAggregator pairing (hardcoded TODO) - if matches!(aggregation_type.as_str(), "CountMinSketch" | "HydraKLL") { - let delta_params = get_precompute_operator_parameters( - "DeltaSetAggregator", - "", + let (rollup, subpopulation_labels) = + get_label_routing(pattern_type, &match_result, &all_labels); + + let configs = build_agg_configs_for_statistics( + &statistics, + treatment_type, + &subpopulation_labels, + &rollup, + &window_cfg, + &metric, + None, + None, + &spatial_filter, + |agg_type, agg_sub_type| { + build_sketch_parameters_from_promql( + agg_type, + agg_sub_type, &match_result, self.sketch_parameters.as_ref(), ) - .map_err(ControllerError::PlannerError)?; - - configs.push(IntermediateAggConfig { - aggregation_type: "DeltaSetAggregator".to_string(), - aggregation_sub_type: String::new(), - window_type: window_cfg.window_type.clone(), - window_size: window_cfg.window_size, - slide_interval: window_cfg.slide_interval, - spatial_filter: spatial_filter.clone(), - metric: metric.clone(), - table_name: None, - value_column: None, - parameters: delta_params, - rollup_labels: rollup_labels.clone(), - grouping_labels: grouping_labels.clone(), - aggregated_labels: aggregated_labels.clone(), - }); - } - - configs.push(IntermediateAggConfig { - aggregation_type, - aggregation_sub_type, - window_type: window_cfg.window_type.clone(), - window_size: window_cfg.window_size, - slide_interval: window_cfg.slide_interval, - spatial_filter: spatial_filter.clone(), - metric: metric.clone(), - table_name: None, - value_column: None, - parameters, - rollup_labels, - grouping_labels, - aggregated_labels, - }); - } + }, + ) + .map_err(ControllerError::PlannerError)?; // Calculate cleanup param let cleanup_param = if self.cleanup_policy == CleanupPolicy::NoCleanup { @@ -329,29 +280,16 @@ impl SingleQueryProcessor { } } -fn compute_labels( +/// Returns `(rollup, subpopulation_labels)` for a given PromQL pattern type. +/// These are constant across all statistics in a query, so they are computed +/// once before the per-statistic loop. +fn get_label_routing( pattern_type: QueryPatternType, - statistic: Statistic, - aggregation_type: &str, match_result: &PromQLMatchResult, all_labels: &KeyByLabelNames, -) -> (KeyByLabelNames, KeyByLabelNames, KeyByLabelNames) { - let mut rollup; - let mut grouping = KeyByLabelNames::empty(); - let mut aggregated = KeyByLabelNames::empty(); - +) -> (KeyByLabelNames, KeyByLabelNames) { match pattern_type { - QueryPatternType::OnlyTemporal => { - rollup = KeyByLabelNames::empty(); - set_subpopulation_labels( - statistic, - aggregation_type, - all_labels, - &mut rollup, - &mut grouping, - &mut aggregated, - ); - } + QueryPatternType::OnlyTemporal => (KeyByLabelNames::empty(), all_labels.clone()), QueryPatternType::OnlySpatial => { // Match Python: if no by/without modifier, spatial_output = [] (rollup gets all labels). // promql_utilities::get_spatial_aggregation_output_labels has a topk patch that returns @@ -367,47 +305,95 @@ fn compute_labels( } else { KeyByLabelNames::empty() }; - rollup = all_labels.difference(&spatial_output); - set_subpopulation_labels( - statistic, - aggregation_type, - &spatial_output, - &mut rollup, - &mut grouping, - &mut aggregated, - ); + (all_labels.difference(&spatial_output), spatial_output) } QueryPatternType::OneTemporalOneSpatial => { let fn_name = match_result.get_function_name().unwrap_or_default(); let agg_op = match_result.get_aggregation_op().unwrap_or_default(); - let collapsable = get_is_collapsable(&fn_name, &agg_op); - if !collapsable { - rollup = KeyByLabelNames::empty(); - set_subpopulation_labels( - statistic, - aggregation_type, - all_labels, - &mut rollup, - &mut grouping, - &mut aggregated, - ); + if !get_is_collapsable(&fn_name, &agg_op) { + (KeyByLabelNames::empty(), all_labels.clone()) } else { let spatial_output = get_spatial_aggregation_output_labels(match_result, all_labels); - rollup = all_labels.difference(&spatial_output); - set_subpopulation_labels( - statistic, - aggregation_type, - &spatial_output, - &mut rollup, - &mut grouping, - &mut aggregated, - ); + (all_labels.difference(&spatial_output), spatial_output) } } } +} + +/// Shared per-statistic config builder used by both PromQL and SQL paths. +/// +/// `get_params(agg_type, agg_sub_type)` is a closure supplied by the caller +/// that resolves sketch parameters; it is the only thing that differs between +/// the two paths. +#[allow(clippy::too_many_arguments)] +pub fn build_agg_configs_for_statistics( + statistics: &[Statistic], + treatment_type: QueryTreatmentType, + subpopulation_labels: &KeyByLabelNames, + rollup: &KeyByLabelNames, + window_cfg: &IntermediateWindowConfig, + metric: &str, + table_name: Option<&str>, + value_column: Option<&str>, + spatial_filter: &str, + get_params: impl Fn(&str, &str) -> Result, String>, +) -> Result, String> { + let mut configs = Vec::new(); + + for statistic in statistics.iter().copied() { + let (agg_type, agg_sub_type) = + map_statistic_to_precompute_operator(statistic, treatment_type)?; + + let mut grouping = KeyByLabelNames::empty(); + let mut aggregated = KeyByLabelNames::empty(); + set_subpopulation_labels( + statistic, + &agg_type, + subpopulation_labels, + &mut rollup.clone(), + &mut grouping, + &mut aggregated, + ); + + if matches!(agg_type.as_str(), "CountMinSketch" | "HydraKLL") { + let delta_params = get_params("DeltaSetAggregator", "")?; + configs.push(IntermediateAggConfig { + aggregation_type: "DeltaSetAggregator".to_string(), + aggregation_sub_type: String::new(), + window_type: window_cfg.window_type.clone(), + window_size: window_cfg.window_size, + slide_interval: window_cfg.slide_interval, + spatial_filter: spatial_filter.to_string(), + metric: metric.to_string(), + table_name: table_name.map(str::to_string), + value_column: value_column.map(str::to_string), + parameters: delta_params, + rollup_labels: rollup.clone(), + grouping_labels: grouping.clone(), + aggregated_labels: aggregated.clone(), + }); + } + + let parameters = get_params(&agg_type, &agg_sub_type)?; + configs.push(IntermediateAggConfig { + aggregation_type: agg_type, + aggregation_sub_type: agg_sub_type, + window_type: window_cfg.window_type.clone(), + window_size: window_cfg.window_size, + slide_interval: window_cfg.slide_interval, + spatial_filter: spatial_filter.to_string(), + metric: metric.to_string(), + table_name: table_name.map(str::to_string), + value_column: value_column.map(str::to_string), + parameters, + rollup_labels: rollup.clone(), + grouping_labels: grouping, + aggregated_labels: aggregated, + }); + } - (rollup, grouping, aggregated) + Ok(configs) } #[cfg(test)]