diff --git a/docs/src/guide/querying.md b/docs/src/guide/querying.md index 70f6bc0..718840f 100644 --- a/docs/src/guide/querying.md +++ b/docs/src/guide/querying.md @@ -42,7 +42,7 @@ For deeper insights, use `firm query` which supports a SQL-like query language. ### Query syntax ``` -from | | | ... +from | | | ... | ``` ### Available operations @@ -53,6 +53,16 @@ from | | | ... - `order [asc|desc]` - Sort results - `limit ` - Limit the number of results +### Aggregations + +An optional final clause that summarizes the result set: + +- `select , ...` - Extract specific field values +- `count []` - Count entities (optionally only those with the field) +- `sum ` - Sum a numeric field +- `average ` - Compute the mean of a numeric field +- `median ` - Compute the median of a numeric field + ### Examples **Find all incomplete tasks:** @@ -75,6 +85,21 @@ $ firm query 'from invoice | where status == "draft" or status == "sent"' $ firm query 'from project | where status == "in progress" | related(2) task | where is_completed == false | where due_date > 2025-01-01 | order due_date | limit 10' ``` +**Count incomplete tasks:** +```bash +$ firm query 'from task | where is_completed == false | count' +``` + +**Sum invoice amounts:** +```bash +$ firm query 'from invoice | where status == "sent" | sum amount' +``` + +**Extract specific fields:** +```bash +$ firm query 'from task | where is_completed == false | select @id, name, due_date' +``` + ### Query operators You can filter by any field or metadata (`@type`, `@id`), traverse relationships multiple degrees deep, and compose operations to build the exact query you need. diff --git a/docs/src/reference/query-reference.md b/docs/src/reference/query-reference.md index 4607d0b..45a0508 100644 --- a/docs/src/reference/query-reference.md +++ b/docs/src/reference/query-reference.md @@ -18,15 +18,17 @@ Firm queries always operate on a "bag of entities". 
At every stage in query exec The `from` clause selects the initial set of entities, and every subsequent operation filters, expands, limits, or orders that entity set. This keeps the query language simple and focused on navigating the entity graph. +Optionally, a final **aggregation** clause can be added at the end of a query to compute a summary value (like a count or sum) or extract specific fields from the final entity set. Aggregations are the only operation that transforms the result from entities into a different shape. + ## Basic syntax All queries follow this structure: ``` -from | | | ... +from | | | ... | ``` -Start with a `from` clause, then chain operations using the pipe symbol `|`. +Start with a `from` clause, then chain operations using the pipe symbol `|`. Optionally end with an aggregation clause. ## Entity selector @@ -200,6 +202,85 @@ from task | where priority > 8 | order priority desc | limit 5 **Syntax:** `limit ` +## Aggregations + +Aggregations are optional clauses that go at the end of a query. They transform the entity set into a summary value or extracted fields. Only one aggregation can be used per query. + +### select + +Extract specific field values from entities: + +```bash +# Select a single field +from person | select name + +# Select multiple fields +from task | select name, status, due_date + +# Include metadata fields +from task | where is_completed == false | select @id, name, due_date +``` + +**Syntax:** `select , , ...` + +Fields can be regular field names or metadata fields (`@id`, `@type`). Missing fields appear as empty values. 
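The projection that `select` performs can be sketched in plain Rust. This is an illustration only, not Firm's implementation: `SimpleEntity`, `select_fields`, and `sample_tasks` are hypothetical names, and string-valued fields in a `HashMap` stand in for the real entity type. A missing field becomes `None`, which corresponds to the empty value described above.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for an entity: an id, a type, and string fields.
struct SimpleEntity {
    id: String,
    entity_type: String,
    fields: HashMap<String, String>,
}

/// Project each entity onto the requested columns.
/// `@id` and `@type` read metadata; any other name is a regular field lookup.
fn select_fields(entities: &[SimpleEntity], columns: &[&str]) -> Vec<Vec<Option<String>>> {
    entities
        .iter()
        .map(|entity| {
            columns
                .iter()
                .map(|column| match *column {
                    "@id" => Some(entity.id.clone()),
                    "@type" => Some(entity.entity_type.clone()),
                    // A field the entity lacks yields None (an empty value).
                    name => entity.fields.get(name).cloned(),
                })
                .collect()
        })
        .collect()
}

/// Sample data (assumed for the example): one task with a due date, one without.
fn sample_tasks() -> Vec<SimpleEntity> {
    let mut first = HashMap::new();
    first.insert("name".to_string(), "Write report".to_string());
    first.insert("due_date".to_string(), "2025-03-01".to_string());
    let mut second = HashMap::new();
    second.insert("name".to_string(), "Review PR".to_string());
    vec![
        SimpleEntity { id: "t1".to_string(), entity_type: "task".to_string(), fields: first },
        SimpleEntity { id: "t2".to_string(), entity_type: "task".to_string(), fields: second },
    ]
}
```

Calling `select_fields(&sample_tasks(), &["@id", "name", "due_date"])` yields one row per task, with `None` in the `due_date` column for the task that has no such field.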
+ +### count + +Count entities, optionally filtering by field presence: + +```bash +# Count all matching entities +from task | where is_completed == false | count + +# Count entities that have a specific field +from person | count email +``` + +**Syntax:** +- `count` - Count all entities in the result set +- `count ` - Count entities that have the specified field + +### sum + +Sum numeric field values across entities: + +```bash +# Sum integer or float fields +from line_item | sum quantity + +# Sum currency fields +from invoice | where status == "sent" | sum amount +``` + +**Syntax:** `sum ` + +Works with integer, float, and currency fields. Entities missing the field are skipped. Currency values must all share the same currency code — mixed currencies produce an error. + +### average + +Compute the mean of a numeric field: + +```bash +from task | average estimated_hours +``` + +**Syntax:** `average ` + +Works with integer, float, and currency fields. Entities missing the field are skipped. Returns an error if no entities have the field. + +### median + +Compute the median of a numeric field: + +```bash +from task | median estimated_hours +``` + +**Syntax:** `median ` + +Works with integer, float, and currency fields. Entities missing the field are skipped. For an even number of values, returns the average of the two middle values. Returns an error if no entities have the field. 
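The even/odd rule for `median` can be sketched as a standalone function over plain `f64` values. This is a minimal illustration rather than Firm's code: the name `median_of` is hypothetical, `Option` stands in for the error returned when no entity has the field, and the sketch sorts with `f64::total_cmp` where the implementation in this change uses `partial_cmp`.

```rust
/// Median of a slice of floats; `None` when there are no values
/// (the query engine reports an error in that situation).
fn median_of(values: &[f64]) -> Option<f64> {
    if values.is_empty() {
        return None;
    }
    // Sort a copy so the caller's data is left untouched.
    let mut sorted = values.to_vec();
    sorted.sort_by(|a, b| a.total_cmp(b));

    let mid = sorted.len() / 2;
    if sorted.len() % 2 == 0 {
        // Even count: average the two middle values.
        Some((sorted[mid - 1] + sorted[mid]) / 2.0)
    } else {
        // Odd count: take the single middle value.
        Some(sorted[mid])
    }
}
```

For the inputs 10 and 20 the two middle values are averaged to 15; for 30, 10, 20 the values are sorted first and the middle one, 20, is returned.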
+ ## Examples ### Find incomplete tasks @@ -226,6 +307,24 @@ from opportunity | where value >= 10000.00 USD | order value desc from project | where status == "active" | related task ``` +### Count incomplete tasks + +```bash +from task | where is_completed == false | count +``` + +### Total invoice amount + +```bash +from invoice | where status == "sent" | sum amount +``` + +### Task summary with select + +```bash +from task | where is_completed == false | order due_date | select @id, name, due_date +``` + ### Complex multi-hop query ```bash @@ -244,11 +343,19 @@ This query: Queries are executed left to right, with each operation transforming the result set: ``` -from task → [all tasks] -| where status → [filtered tasks] -| related project → [related projects] -| order name → [sorted projects] -| limit 5 → [top 5 projects] +from task → [all tasks] +| where is_completed == false → [filtered tasks] +| related project → [related projects] +| order name → [sorted projects] +| limit 5 → [top 5 projects] ``` Each operation receives the output of the previous operation and produces a new result set. 
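The left-to-right flow can be mimicked with ordinary Rust collections. The sketch below is only an analogy for a pipeline of the same shape (`where`, then `order`, then `limit`); the `Task` struct and `top_open_tasks` function are hypothetical and do not reflect how Firm represents entities or executes queries.

```rust
/// Hypothetical task record used only for this illustration.
#[derive(Debug, Clone, PartialEq)]
struct Task {
    name: String,
    is_completed: bool,
    priority: i64,
}

/// Mimics `from task | where is_completed == false | order priority desc | limit <n>`.
/// Each step consumes the output of the previous step.
fn top_open_tasks(tasks: &[Task], limit: usize) -> Vec<Task> {
    // where is_completed == false
    let mut open: Vec<Task> = tasks.iter().filter(|t| !t.is_completed).cloned().collect();
    // order priority desc
    open.sort_by(|a, b| b.priority.cmp(&a.priority));
    // limit <n>
    open.truncate(limit);
    open
}
```

Because every stage takes a set and returns a set, stages can be reordered or dropped without changing how the remaining ones are written.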
+ +If an aggregation is present, it runs last and transforms the entity set into a result value: + +``` +from invoice → [all invoices] +| where status == "sent" → [filtered invoices] +| sum amount → 15000.00 USD +``` diff --git a/firm_cli/src/commands/query.rs b/firm_cli/src/commands/query.rs index d08530e..490fe2c 100644 --- a/firm_cli/src/commands/query.rs +++ b/firm_cli/src/commands/query.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use firm_core::graph::Query; +use firm_core::graph::{Query, QueryResult}; use firm_lang::parser::query::parse_query; use crate::errors::CliError; @@ -30,21 +30,24 @@ pub fn query_entities( // Execute the query ui::debug("Executing query"); - let results = query.execute(&graph).map_err(|e| { + let result = query.execute(&graph).map_err(|e| { ui::error(&format!("Query execution failed: {}", e)); CliError::QueryError })?; - ui::success(&format!("Query returned {} entities", results.len())); - // Output results - match output_format { - OutputFormat::Pretty => { - ui::pretty_output_entity_list(&results); - } - OutputFormat::Json => { - ui::json_output(&results); + match result { + QueryResult::Entities(entities) => { + ui::success(&format!("Query returned {} entities", entities.len())); + match output_format { + OutputFormat::Pretty => ui::pretty_output_entity_list(&entities), + OutputFormat::Json => ui::json_output(&entities), + } } + QueryResult::Aggregation(agg_result) => match output_format { + OutputFormat::Pretty => ui::raw_output(&agg_result.to_string()), + OutputFormat::Json => ui::json_output(&agg_result), + }, } Ok(()) diff --git a/firm_core/src/graph/query/aggregation/average.rs b/firm_core/src/graph/query/aggregation/average.rs new file mode 100644 index 0000000..58afa53 --- /dev/null +++ b/firm_core/src/graph/query/aggregation/average.rs @@ -0,0 +1,113 @@ +//! 
Average aggregation: compute the mean of a numeric field
+
+use super::super::filter::FieldRef;
+use super::super::types::AggregationResult;
+use super::super::QueryError;
+use super::{collect_numeric_values, require_regular_field};
+use crate::Entity;
+
+pub fn execute(
+    field: &FieldRef,
+    entities: &[&Entity],
+) -> Result<AggregationResult, QueryError> {
+    let field_id = require_regular_field(field, "average")?;
+    let values = collect_numeric_values(field_id, entities)?;
+
+    if values.is_empty() {
+        return Err(QueryError::InvalidAggregation {
+            message: "Cannot compute average of empty result set".to_string(),
+        });
+    }
+
+    let sum: f64 = values.iter().map(|v| v.as_f64()).sum();
+    let avg = sum / values.len() as f64;
+
+    Ok(AggregationResult::Average(avg))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{Entity, EntityId, EntityType, FieldId, FieldValue};
+
+    #[test]
+    fn test_average_integers() {
+        let entities = vec![
+            Entity::new(EntityId::new("a"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Integer(10)),
+            Entity::new(EntityId::new("b"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Integer(20)),
+            Entity::new(EntityId::new("c"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Integer(30)),
+        ];
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let field = FieldRef::Regular(FieldId::new("val"));
+        let result = execute(&field, &refs).unwrap();
+        assert_eq!(result, AggregationResult::Average(20.0));
+    }
+
+    #[test]
+    fn test_average_floats() {
+        let entities = vec![
+            Entity::new(EntityId::new("a"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Float(1.0)),
+            Entity::new(EntityId::new("b"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Float(2.0)),
+        ];
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let field = FieldRef::Regular(FieldId::new("val"));
+        let result = execute(&field, &refs).unwrap();
+        assert_eq!(result,
AggregationResult::Average(1.5));
+    }
+
+    #[test]
+    fn test_average_mixed_integer_and_float() {
+        let entities = vec![
+            Entity::new(EntityId::new("a"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Integer(10)),
+            Entity::new(EntityId::new("b"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Float(20.0)),
+        ];
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let field = FieldRef::Regular(FieldId::new("val"));
+        let result = execute(&field, &refs).unwrap();
+        assert_eq!(result, AggregationResult::Average(15.0));
+    }
+
+    #[test]
+    fn test_average_skips_missing_fields() {
+        let entities = vec![
+            Entity::new(EntityId::new("a"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Integer(10)),
+            Entity::new(EntityId::new("b"), EntityType::new("item")),
+        ];
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let field = FieldRef::Regular(FieldId::new("val"));
+        let result = execute(&field, &refs).unwrap();
+        // Only 1 entity has the field, so average = 10/1
+        assert_eq!(result, AggregationResult::Average(10.0));
+    }
+
+    #[test]
+    fn test_average_empty_result_set_error() {
+        let refs: Vec<&Entity> = vec![];
+        let field = FieldRef::Regular(FieldId::new("val"));
+        let result = execute(&field, &refs);
+        assert!(matches!(
+            result,
+            Err(QueryError::InvalidAggregation { .. })
+        ));
+    }
+
+    #[test]
+    fn test_average_no_entities_with_field_error() {
+        let entities = vec![Entity::new(EntityId::new("a"), EntityType::new("item"))];
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let field = FieldRef::Regular(FieldId::new("nonexistent"));
+        let result = execute(&field, &refs);
+        assert!(matches!(
+            result,
+            Err(QueryError::InvalidAggregation { ..
})
+        ));
+    }
+}
diff --git a/firm_core/src/graph/query/aggregation/count.rs b/firm_core/src/graph/query/aggregation/count.rs
new file mode 100644
index 0000000..a03db9c
--- /dev/null
+++ b/firm_core/src/graph/query/aggregation/count.rs
@@ -0,0 +1,90 @@
+//! Count aggregation: count entities, optionally filtering by field presence
+
+use super::super::filter::FieldRef;
+use super::super::types::AggregationResult;
+use super::super::QueryError;
+use crate::Entity;
+
+pub fn execute(
+    field: Option<&FieldRef>,
+    entities: &[&Entity],
+) -> Result<AggregationResult, QueryError> {
+    let count = match field {
+        None => entities.len(),
+        Some(FieldRef::Metadata(_)) => entities.len(),
+        Some(FieldRef::Regular(field_id)) => entities
+            .iter()
+            .filter(|e| e.get_field(field_id).is_some())
+            .count(),
+    };
+    Ok(AggregationResult::Count(count))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{Entity, EntityId, EntityType, FieldId, FieldValue};
+    use super::super::super::filter::MetadataField;
+
+    fn make_entities() -> Vec<Entity> {
+        vec![
+            Entity::new(EntityId::new("p1"), EntityType::new("person"))
+                .with_field(FieldId::new("name"), "Alice")
+                .with_field(FieldId::new("age"), FieldValue::Integer(30)),
+            Entity::new(EntityId::new("p2"), EntityType::new("person"))
+                .with_field(FieldId::new("name"), "Bob"),
+            // p2 has no "age" field
+        ]
+    }
+
+    #[test]
+    fn test_count_all() {
+        let entities = make_entities();
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let result = execute(None, &refs).unwrap();
+        assert_eq!(result, AggregationResult::Count(2));
+    }
+
+    #[test]
+    fn test_count_with_present_field() {
+        let entities = make_entities();
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let field = FieldRef::Regular(FieldId::new("name"));
+        let result = execute(Some(&field), &refs).unwrap();
+        assert_eq!(result, AggregationResult::Count(2));
+    }
+
+    #[test]
+    fn test_count_with_partial_field() {
+        let entities = make_entities();
+        let refs: Vec<&Entity> = entities.iter().collect();
+
let field = FieldRef::Regular(FieldId::new("age"));
+        let result = execute(Some(&field), &refs).unwrap();
+        assert_eq!(result, AggregationResult::Count(1));
+    }
+
+    #[test]
+    fn test_count_with_missing_field() {
+        let entities = make_entities();
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let field = FieldRef::Regular(FieldId::new("nonexistent"));
+        let result = execute(Some(&field), &refs).unwrap();
+        assert_eq!(result, AggregationResult::Count(0));
+    }
+
+    #[test]
+    fn test_count_metadata_field_counts_all() {
+        let entities = make_entities();
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let field = FieldRef::Metadata(MetadataField::Id);
+        let result = execute(Some(&field), &refs).unwrap();
+        assert_eq!(result, AggregationResult::Count(2));
+    }
+
+    #[test]
+    fn test_count_empty_set() {
+        let refs: Vec<&Entity> = vec![];
+        let result = execute(None, &refs).unwrap();
+        assert_eq!(result, AggregationResult::Count(0));
+    }
+}
diff --git a/firm_core/src/graph/query/aggregation/median.rs b/firm_core/src/graph/query/aggregation/median.rs
new file mode 100644
index 0000000..f52bed6
--- /dev/null
+++ b/firm_core/src/graph/query/aggregation/median.rs
@@ -0,0 +1,122 @@
+//!
Median aggregation: compute the median of a numeric field
+
+use super::super::filter::FieldRef;
+use super::super::types::AggregationResult;
+use super::super::QueryError;
+use super::{collect_numeric_values, require_regular_field};
+use crate::Entity;
+
+pub fn execute(
+    field: &FieldRef,
+    entities: &[&Entity],
+) -> Result<AggregationResult, QueryError> {
+    let field_id = require_regular_field(field, "median")?;
+    let values = collect_numeric_values(field_id, entities)?;
+
+    if values.is_empty() {
+        return Err(QueryError::InvalidAggregation {
+            message: "Cannot compute median of empty result set".to_string(),
+        });
+    }
+
+    let mut float_values: Vec<f64> = values.iter().map(|v| v.as_f64()).collect();
+    float_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
+
+    let len = float_values.len();
+    let median = if len % 2 == 0 {
+        (float_values[len / 2 - 1] + float_values[len / 2]) / 2.0
+    } else {
+        float_values[len / 2]
+    };
+
+    Ok(AggregationResult::Median(median))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{Entity, EntityId, EntityType, FieldId, FieldValue};
+
+    #[test]
+    fn test_median_odd_count() {
+        let entities = vec![
+            Entity::new(EntityId::new("a"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Integer(10)),
+            Entity::new(EntityId::new("b"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Integer(20)),
+            Entity::new(EntityId::new("c"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Integer(30)),
+        ];
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let field = FieldRef::Regular(FieldId::new("val"));
+        let result = execute(&field, &refs).unwrap();
+        assert_eq!(result, AggregationResult::Median(20.0));
+    }
+
+    #[test]
+    fn test_median_mixed_integer_and_float() {
+        let entities = vec![
+            Entity::new(EntityId::new("a"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Integer(10)),
+            Entity::new(EntityId::new("b"), EntityType::new("item"))
+
.with_field(FieldId::new("val"), FieldValue::Float(20.0)),
+            Entity::new(EntityId::new("c"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Integer(30)),
+        ];
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let field = FieldRef::Regular(FieldId::new("val"));
+        let result = execute(&field, &refs).unwrap();
+        assert_eq!(result, AggregationResult::Median(20.0));
+    }
+
+    #[test]
+    fn test_median_even_count() {
+        let entities = vec![
+            Entity::new(EntityId::new("a"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Integer(10)),
+            Entity::new(EntityId::new("b"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Integer(20)),
+        ];
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let field = FieldRef::Regular(FieldId::new("val"));
+        let result = execute(&field, &refs).unwrap();
+        assert_eq!(result, AggregationResult::Median(15.0));
+    }
+
+    #[test]
+    fn test_median_single_value() {
+        let entities = vec![Entity::new(EntityId::new("a"), EntityType::new("item"))
+            .with_field(FieldId::new("val"), FieldValue::Integer(42))];
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let field = FieldRef::Regular(FieldId::new("val"));
+        let result = execute(&field, &refs).unwrap();
+        assert_eq!(result, AggregationResult::Median(42.0));
+    }
+
+    #[test]
+    fn test_median_unsorted_input() {
+        let entities = vec![
+            Entity::new(EntityId::new("a"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Integer(30)),
+            Entity::new(EntityId::new("b"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Integer(10)),
+            Entity::new(EntityId::new("c"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Integer(20)),
+        ];
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let field = FieldRef::Regular(FieldId::new("val"));
+        let result = execute(&field, &refs).unwrap();
+        assert_eq!(result, AggregationResult::Median(20.0));
+    }
+
+    #[test]
+    fn
test_median_empty_error() {
+        let refs: Vec<&Entity> = vec![];
+        let field = FieldRef::Regular(FieldId::new("val"));
+        let result = execute(&field, &refs);
+        assert!(matches!(
+            result,
+            Err(QueryError::InvalidAggregation { .. })
+        ));
+    }
+}
diff --git a/firm_core/src/graph/query/aggregation/mod.rs b/firm_core/src/graph/query/aggregation/mod.rs
new file mode 100644
index 0000000..c68d5e2
--- /dev/null
+++ b/firm_core/src/graph/query/aggregation/mod.rs
@@ -0,0 +1,143 @@
+//! Aggregation execution logic for queries
+
+mod average;
+mod count;
+mod median;
+mod select;
+mod sum;
+
+use super::filter::FieldRef;
+use super::types::{Aggregation, AggregationResult};
+use super::QueryError;
+use crate::Entity;
+
+impl Aggregation {
+    /// Execute this aggregation over a set of entities
+    pub fn execute(&self, entities: &[&Entity]) -> Result<AggregationResult, QueryError> {
+        match self {
+            Aggregation::Select(fields) => select::execute(fields, entities),
+            Aggregation::Count(field) => count::execute(field.as_ref(), entities),
+            Aggregation::Sum(field) => sum::execute(field, entities),
+            Aggregation::Average(field) => average::execute(field, entities),
+            Aggregation::Median(field) => median::execute(field, entities),
+        }
+    }
+}
+
+/// Require that the field is a regular field (not metadata) for numeric aggregations.
+fn require_regular_field<'a>(
+    field: &'a FieldRef,
+    operation: &str,
+) -> Result<&'a crate::FieldId, QueryError> {
+    match field {
+        FieldRef::Regular(id) => Ok(id),
+        FieldRef::Metadata(_) => Err(QueryError::InvalidAggregation {
+            message: format!(
+                "Cannot {} a metadata field. Use a regular numeric field.",
+                operation
+            ),
+        }),
+    }
+}
+
+/// Internal representation of a numeric value extracted from an entity field.
+#[derive(Debug, Clone)]
+enum NumericValue {
+    Integer(i64),
+    Float(f64),
+    Currency {
+        amount: rust_decimal::Decimal,
+        currency: iso_currency::Currency,
+    },
+}
+
+impl NumericValue {
+    fn as_f64(&self) -> f64 {
+        match self {
+            NumericValue::Integer(i) => *i as f64,
+            NumericValue::Float(f) => *f,
+            NumericValue::Currency { amount, .. } => {
+                use rust_decimal::prelude::ToPrimitive;
+                amount.to_f64().unwrap_or(0.0)
+            }
+        }
+    }
+}
+
+/// Classifies the dominant numeric type across a set of values.
+enum NumericType {
+    Integer,
+    Float,
+    Currency(iso_currency::Currency),
+}
+
+/// Classify what numeric type a set of values represents, handling mixed int/float promotion.
+fn classify_numeric_type(values: &[NumericValue]) -> Result<NumericType, QueryError> {
+    let mut has_integer = false;
+    let mut has_float = false;
+    let mut currency: Option<iso_currency::Currency> = None;
+
+    for v in values {
+        match v {
+            NumericValue::Integer(_) => has_integer = true,
+            NumericValue::Float(_) => has_float = true,
+            NumericValue::Currency { currency: c, .. } => {
+                currency = Some(*c);
+            }
+        }
+    }
+
+    let has_currency = currency.is_some();
+
+    if has_currency && (has_integer || has_float) {
+        return Err(QueryError::InvalidAggregation {
+            message: "Cannot mix currency and numeric values in aggregation".to_string(),
+        });
+    }
+
+    if has_currency {
+        Ok(NumericType::Currency(currency.unwrap()))
+    } else if has_float {
+        Ok(NumericType::Float)
+    } else {
+        Ok(NumericType::Integer)
+    }
+}
+
+/// Collect numeric values from entities for a given field, skipping entities that lack the field.
+fn collect_numeric_values(
+    field_id: &crate::FieldId,
+    entities: &[&Entity],
+) -> Result<Vec<NumericValue>, QueryError> {
+    let mut values = Vec::new();
+
+    for entity in entities {
+        if let Some(field_value) = entity.get_field(field_id) {
+            match field_value {
+                crate::FieldValue::Integer(i) => {
+                    values.push(NumericValue::Integer(*i));
+                }
+                crate::FieldValue::Float(f) => {
+                    values.push(NumericValue::Float(*f));
+                }
+                crate::FieldValue::Currency { amount, currency } => {
+                    values.push(NumericValue::Currency {
+                        amount: *amount,
+                        currency: *currency,
+                    });
+                }
+                other => {
+                    return Err(QueryError::InvalidAggregation {
+                        message: format!(
+                            "Cannot aggregate non-numeric field '{}'. Found type: {}",
+                            field_id.as_str(),
+                            other.get_type()
+                        ),
+                    });
+                }
+            }
+        }
+    }
+
+    Ok(values)
+}
diff --git a/firm_core/src/graph/query/aggregation/select.rs b/firm_core/src/graph/query/aggregation/select.rs
new file mode 100644
index 0000000..ebea243
--- /dev/null
+++ b/firm_core/src/graph/query/aggregation/select.rs
@@ -0,0 +1,151 @@
+//!
Select aggregation: extract specific field values from entities
+
+use super::super::filter::{FieldRef, MetadataField};
+use super::super::types::AggregationResult;
+use super::super::QueryError;
+use crate::{Entity, FieldValue};
+
+pub fn execute(
+    fields: &[FieldRef],
+    entities: &[&Entity],
+) -> Result<AggregationResult, QueryError> {
+    let columns: Vec<String> = fields
+        .iter()
+        .map(|f| match f {
+            FieldRef::Metadata(MetadataField::Id) => "@id".to_string(),
+            FieldRef::Metadata(MetadataField::Type) => "@type".to_string(),
+            FieldRef::Regular(field_id) => field_id.as_str().to_string(),
+        })
+        .collect();
+
+    let rows: Vec<Vec<Option<FieldValue>>> = entities
+        .iter()
+        .map(|entity| {
+            fields
+                .iter()
+                .map(|field| match field {
+                    FieldRef::Metadata(MetadataField::Id) => {
+                        Some(FieldValue::String(entity.id.to_string()))
+                    }
+                    FieldRef::Metadata(MetadataField::Type) => {
+                        Some(FieldValue::String(entity.entity_type.to_string()))
+                    }
+                    FieldRef::Regular(field_id) => entity.get_field(field_id).cloned(),
+                })
+                .collect()
+        })
+        .collect();
+
+    Ok(AggregationResult::Select { columns, rows })
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{Entity, EntityId, EntityType, FieldId, FieldValue};
+
+    fn make_entities() -> Vec<Entity> {
+        vec![
+            Entity::new(EntityId::new("p1"), EntityType::new("person"))
+                .with_field(FieldId::new("name"), "Alice")
+                .with_field(FieldId::new("age"), FieldValue::Integer(30)),
+            Entity::new(EntityId::new("p2"), EntityType::new("person"))
+                .with_field(FieldId::new("name"), "Bob"),
+        ]
+    }
+
+    #[test]
+    fn test_select_single_field() {
+        let entities = make_entities();
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let fields = vec![FieldRef::Regular(FieldId::new("name"))];
+        let result = execute(&fields, &refs).unwrap();
+        if let AggregationResult::Select { columns, rows } = result {
+            assert_eq!(columns, vec!["name"]);
+            assert_eq!(rows.len(), 2);
+            assert_eq!(rows[0][0], Some(FieldValue::String("Alice".to_string())));
+            assert_eq!(rows[1][0],
Some(FieldValue::String("Bob".to_string())));
+        } else {
+            panic!("Expected Select result");
+        }
+    }
+
+    #[test]
+    fn test_select_multiple_fields() {
+        let entities = make_entities();
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let fields = vec![
+            FieldRef::Regular(FieldId::new("name")),
+            FieldRef::Regular(FieldId::new("age")),
+        ];
+        let result = execute(&fields, &refs).unwrap();
+        if let AggregationResult::Select { columns, rows } = result {
+            assert_eq!(columns, vec!["name", "age"]);
+            assert_eq!(rows.len(), 2);
+            // p1 has both fields
+            assert_eq!(rows[0][0], Some(FieldValue::String("Alice".to_string())));
+            assert_eq!(rows[0][1], Some(FieldValue::Integer(30)));
+            // p2 has name but no age
+            assert_eq!(rows[1][0], Some(FieldValue::String("Bob".to_string())));
+            assert_eq!(rows[1][1], None);
+        } else {
+            panic!("Expected Select result");
+        }
+    }
+
+    #[test]
+    fn test_select_metadata_id() {
+        let entities = make_entities();
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let fields = vec![FieldRef::Metadata(MetadataField::Id)];
+        let result = execute(&fields, &refs).unwrap();
+        if let AggregationResult::Select { columns, rows } = result {
+            assert_eq!(columns, vec!["@id"]);
+            // EntityId converts to snake_case, so "p1" becomes "p_1"
+            assert_eq!(rows[0][0], Some(FieldValue::String("p_1".to_string())));
+            assert_eq!(rows[1][0], Some(FieldValue::String("p_2".to_string())));
+        } else {
+            panic!("Expected Select result");
+        }
+    }
+
+    #[test]
+    fn test_select_metadata_type() {
+        let entities = make_entities();
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let fields = vec![FieldRef::Metadata(MetadataField::Type)];
+        let result = execute(&fields, &refs).unwrap();
+        if let AggregationResult::Select { columns, rows } = result {
+            assert_eq!(columns, vec!["@type"]);
+            assert!(rows.iter().all(|r| r[0] == Some(FieldValue::String("person".to_string()))));
+        } else {
+            panic!("Expected Select result");
+        }
+    }
+
+    #[test]
+    fn
test_select_missing_field_returns_none() {
+        let entities = make_entities();
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let fields = vec![FieldRef::Regular(FieldId::new("nonexistent"))];
+        let result = execute(&fields, &refs).unwrap();
+        if let AggregationResult::Select { rows, .. } = result {
+            assert!(rows.iter().all(|r| r[0].is_none()));
+        } else {
+            panic!("Expected Select result");
+        }
+    }
+
+    #[test]
+    fn test_select_empty_entities() {
+        let refs: Vec<&Entity> = vec![];
+        let fields = vec![FieldRef::Regular(FieldId::new("name"))];
+        let result = execute(&fields, &refs).unwrap();
+        if let AggregationResult::Select { columns, rows } = result {
+            assert_eq!(columns, vec!["name"]);
+            assert!(rows.is_empty());
+        } else {
+            panic!("Expected Select result");
+        }
+    }
+}
diff --git a/firm_core/src/graph/query/aggregation/sum.rs b/firm_core/src/graph/query/aggregation/sum.rs
new file mode 100644
index 0000000..66adc88
--- /dev/null
+++ b/firm_core/src/graph/query/aggregation/sum.rs
@@ -0,0 +1,227 @@
+//! Sum aggregation: sum numeric field values across entities
+
+use super::super::filter::FieldRef;
+use super::super::types::{AggregateValue, AggregationResult};
+use super::super::QueryError;
+use super::{NumericType, NumericValue, collect_numeric_values, classify_numeric_type, require_regular_field};
+use crate::Entity;
+
+pub fn execute(
+    field: &FieldRef,
+    entities: &[&Entity],
+) -> Result<AggregationResult, QueryError> {
+    let field_id = require_regular_field(field, "sum")?;
+    let values = collect_numeric_values(field_id, entities)?;
+
+    if values.is_empty() {
+        return Ok(AggregationResult::Sum(AggregateValue::Integer(0)));
+    }
+
+    match classify_numeric_type(&values)?
{
+        NumericType::Integer => {
+            let sum: i64 = values
+                .iter()
+                .map(|v| match v {
+                    NumericValue::Integer(i) => *i,
+                    _ => 0,
+                })
+                .sum();
+            Ok(AggregationResult::Sum(AggregateValue::Integer(sum)))
+        }
+        NumericType::Float => {
+            let sum: f64 = values.iter().map(|v| v.as_f64()).sum();
+            Ok(AggregationResult::Sum(AggregateValue::Float(sum)))
+        }
+        NumericType::Currency(expected_currency) => {
+            let mut total = rust_decimal::Decimal::ZERO;
+            for v in &values {
+                match v {
+                    NumericValue::Currency { amount, currency } => {
+                        if currency.code() != expected_currency.code() {
+                            return Err(QueryError::InvalidAggregation {
+                                message: format!(
+                                    "Cannot sum mixed currencies (found {}, {}). \
+                                     Filter first, e.g.: where {} >= 0 {}",
+                                    expected_currency.code(),
+                                    currency.code(),
+                                    field_id.as_str(),
+                                    expected_currency.code(),
+                                ),
+                            });
+                        }
+                        total += amount;
+                    }
+                    _ => unreachable!(),
+                }
+            }
+            Ok(AggregationResult::Sum(AggregateValue::Currency {
+                amount: total,
+                currency: expected_currency,
+            }))
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{Entity, EntityId, EntityType, FieldId, FieldValue};
+    use iso_currency::Currency;
+    use rust_decimal::Decimal;
+
+    fn make_integer_entities() -> Vec<Entity> {
+        vec![
+            Entity::new(EntityId::new("a"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Integer(10)),
+            Entity::new(EntityId::new("b"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Integer(20)),
+            Entity::new(EntityId::new("c"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Integer(30)),
+        ]
+    }
+
+    #[test]
+    fn test_sum_integers() {
+        let entities = make_integer_entities();
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let field = FieldRef::Regular(FieldId::new("val"));
+        let result = execute(&field, &refs).unwrap();
+        assert_eq!(result, AggregationResult::Sum(AggregateValue::Integer(60)));
+    }
+
+    #[test]
+    fn test_sum_floats() {
+        let entities = vec![
+
Entity::new(EntityId::new("a"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Float(1.5)),
+            Entity::new(EntityId::new("b"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Float(2.5)),
+        ];
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let field = FieldRef::Regular(FieldId::new("val"));
+        let result = execute(&field, &refs).unwrap();
+        assert_eq!(result, AggregationResult::Sum(AggregateValue::Float(4.0)));
+    }
+
+    #[test]
+    fn test_sum_mixed_integer_and_float() {
+        let entities = vec![
+            Entity::new(EntityId::new("a"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Integer(10)),
+            Entity::new(EntityId::new("b"), EntityType::new("item"))
+                .with_field(FieldId::new("val"), FieldValue::Float(2.5)),
+        ];
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let field = FieldRef::Regular(FieldId::new("val"));
+        let result = execute(&field, &refs).unwrap();
+        assert_eq!(result, AggregationResult::Sum(AggregateValue::Float(12.5)));
+    }
+
+    #[test]
+    fn test_sum_currency_same_code() {
+        let entities = vec![
+            Entity::new(EntityId::new("a"), EntityType::new("invoice"))
+                .with_field(
+                    FieldId::new("amount"),
+                    FieldValue::Currency {
+                        amount: Decimal::new(10000, 2),
+                        currency: Currency::USD,
+                    },
+                ),
+            Entity::new(EntityId::new("b"), EntityType::new("invoice"))
+                .with_field(
+                    FieldId::new("amount"),
+                    FieldValue::Currency {
+                        amount: Decimal::new(5000, 2),
+                        currency: Currency::USD,
+                    },
+                ),
+        ];
+        let refs: Vec<&Entity> = entities.iter().collect();
+        let field = FieldRef::Regular(FieldId::new("amount"));
+        let result = execute(&field, &refs).unwrap();
+        assert_eq!(
+            result,
+            AggregationResult::Sum(AggregateValue::Currency {
+                amount: Decimal::new(15000, 2),
+                currency: Currency::USD,
+            })
+        );
+    }
+
+    #[test]
+    fn test_sum_currency_mixed_codes_error() {
+        let entities = vec![
+            Entity::new(EntityId::new("a"), EntityType::new("invoice"))
+                .with_field(
+                    FieldId::new("amount"),
+
FieldValue::Currency { + amount: Decimal::new(100, 0), + currency: Currency::USD, + }, + ), + Entity::new(EntityId::new("b"), EntityType::new("invoice")) + .with_field( + FieldId::new("amount"), + FieldValue::Currency { + amount: Decimal::new(200, 0), + currency: Currency::EUR, + }, + ), + ]; + let refs: Vec<&Entity> = entities.iter().collect(); + let field = FieldRef::Regular(FieldId::new("amount")); + let result = execute(&field, &refs); + assert!(matches!( + result, + Err(QueryError::InvalidAggregation { .. }) + )); + } + + #[test] + fn test_sum_non_numeric_error() { + let entities = vec![Entity::new(EntityId::new("a"), EntityType::new("item")) + .with_field(FieldId::new("name"), "hello")]; + let refs: Vec<&Entity> = entities.iter().collect(); + let field = FieldRef::Regular(FieldId::new("name")); + let result = execute(&field, &refs); + assert!(matches!( + result, + Err(QueryError::InvalidAggregation { .. }) + )); + } + + #[test] + fn test_sum_empty_set() { + let refs: Vec<&Entity> = vec![]; + let field = FieldRef::Regular(FieldId::new("val")); + let result = execute(&field, &refs).unwrap(); + assert_eq!(result, AggregationResult::Sum(AggregateValue::Integer(0))); + } + + #[test] + fn test_sum_skips_missing_fields() { + let entities = vec![ + Entity::new(EntityId::new("a"), EntityType::new("item")) + .with_field(FieldId::new("val"), FieldValue::Integer(10)), + Entity::new(EntityId::new("b"), EntityType::new("item")), + // b has no "val" field + ]; + let refs: Vec<&Entity> = entities.iter().collect(); + let field = FieldRef::Regular(FieldId::new("val")); + let result = execute(&field, &refs).unwrap(); + assert_eq!(result, AggregationResult::Sum(AggregateValue::Integer(10))); + } + + #[test] + fn test_sum_metadata_field_error() { + let entities = make_integer_entities(); + let refs: Vec<&Entity> = entities.iter().collect(); + let field = FieldRef::Metadata(super::super::super::filter::MetadataField::Id); + let result = execute(&field, &refs); + assert!(matches!( 
+ result, + Err(QueryError::InvalidAggregation { .. }) + )); + } +} diff --git a/firm_core/src/graph/query/mod.rs b/firm_core/src/graph/query/mod.rs index 57b94d1..7785552 100644 --- a/firm_core/src/graph/query/mod.rs +++ b/firm_core/src/graph/query/mod.rs @@ -5,6 +5,7 @@ //! - Query operations (where, related, order, limit) //! - Query execution against the entity graph +mod aggregation; mod filter; mod order; mod query_errors; diff --git a/firm_core/src/graph/query/query_errors.rs b/firm_core/src/graph/query/query_errors.rs index a4b54ba..0e47f2f 100644 --- a/firm_core/src/graph/query/query_errors.rs +++ b/firm_core/src/graph/query/query_errors.rs @@ -25,6 +25,10 @@ pub enum QueryError { InvalidDateFormat { value: String, }, + /// Invalid aggregation operation + InvalidAggregation { + message: String, + }, } impl fmt::Display for QueryError { @@ -75,6 +79,9 @@ impl fmt::Display for QueryError { value ) } + QueryError::InvalidAggregation { message } => { + write!(f, "Invalid aggregation: {}", message) + } } } } diff --git a/firm_core/src/graph/query/types.rs b/firm_core/src/graph/query/types.rs index dc58b79..8ce75ec 100644 --- a/firm_core/src/graph/query/types.rs +++ b/firm_core/src/graph/query/types.rs @@ -1,9 +1,15 @@ //! 
Core query types for executing queries against the entity graph +use std::fmt; + +use iso_currency::Currency; +use rust_decimal::Decimal; +use serde::Serialize; + use super::QueryError; -use super::filter::CompoundFilterCondition; +use super::filter::{CompoundFilterCondition, FieldRef}; use super::order::compare_entities_by_field; -use crate::{Entity, EntityType}; +use crate::{Entity, EntityType, FieldValue}; /// Sort direction #[derive(Debug, Clone, PartialEq)] @@ -15,11 +21,102 @@ pub enum SortDirection { } +/// Terminal aggregation that transforms the query result set +#[derive(Debug, Clone)] +pub enum Aggregation { + /// Select specific field values from entities + Select(Vec), + /// Count entities (None = count all, Some = count entities with field) + Count(Option), + /// Sum a numeric field + Sum(FieldRef), + /// Average a numeric field + Average(FieldRef), + /// Median of a numeric field + Median(FieldRef), +} + +/// The result of executing a query +#[derive(Debug)] +pub enum QueryResult<'a> { + /// Standard entity results (no aggregation) + Entities(Vec<&'a Entity>), + /// Aggregation result + Aggregation(AggregationResult), +} + +/// Result of an aggregation operation +#[derive(Debug, Clone, PartialEq, Serialize)] +pub enum AggregationResult { + /// Rows of field values from a select query + Select { + columns: Vec, + rows: Vec>>, + }, + /// A count result + Count(usize), + /// A sum result + Sum(AggregateValue), + /// An average result + Average(f64), + /// A median result + Median(f64), +} + +impl fmt::Display for AggregationResult { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AggregationResult::Count(n) => write!(f, "{}", n), + AggregationResult::Sum(val) => write!(f, "{}", val), + AggregationResult::Average(val) => write!(f, "{}", val), + AggregationResult::Median(val) => write!(f, "{}", val), + AggregationResult::Select { columns, rows } => { + writeln!(f, "{}", columns.join("\t"))?; + for row in rows { + let cells: 
Vec = row + .iter() + .map(|v| match v { + Some(val) => val.to_string(), + None => "-".to_string(), + }) + .collect(); + writeln!(f, "{}", cells.join("\t"))?; + } + Ok(()) + } + } + } +} + +/// A value produced by a numeric aggregation +#[derive(Debug, Clone, PartialEq, Serialize)] +pub enum AggregateValue { + Integer(i64), + Float(f64), + Currency { + amount: Decimal, + currency: Currency, + }, +} + +impl fmt::Display for AggregateValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AggregateValue::Integer(n) => write!(f, "{}", n), + AggregateValue::Float(n) => write!(f, "{}", n), + AggregateValue::Currency { amount, currency } => { + write!(f, "{} {}", amount, currency.code()) + } + } + } +} + /// A query that can be executed against an entity graph #[derive(Debug, Clone)] pub struct Query { pub from: EntitySelector, pub operations: Vec, + pub aggregation: Option, } impl Query { @@ -28,6 +125,7 @@ impl Query { Self { from, operations: Vec::new(), + aggregation: None, } } @@ -37,11 +135,17 @@ impl Query { self } + /// Set the terminal aggregation for the query + pub fn with_aggregation(mut self, aggregation: Aggregation) -> Self { + self.aggregation = Some(aggregation); + self + } + /// Execute the query against an entity graph pub fn execute<'a>( &self, graph: &'a crate::graph::EntityGraph, - ) -> Result, QueryError> { + ) -> Result, QueryError> { // Start by selecting entities based on the "from" clause let mut entities = match &self.from { EntitySelector::Type(entity_type) => { @@ -95,7 +199,14 @@ impl Query { }; } - Ok(entities) + // Apply terminal aggregation if present + match &self.aggregation { + None => Ok(QueryResult::Entities(entities)), + Some(aggregation) => { + let result = aggregation.execute(&entities)?; + Ok(QueryResult::Aggregation(result)) + } + } } } @@ -133,6 +244,14 @@ mod tests { use super::*; use crate::{Entity, EntityId, EntityType, FieldId, FieldValue}; + /// Helper to extract entities from a QueryResult, 
panicking if it's an aggregation. + fn unwrap_entities<'a>(result: QueryResult<'a>) -> Vec<&'a Entity> { + match result { + QueryResult::Entities(entities) => entities, + QueryResult::Aggregation(_) => panic!("Expected entities, got aggregation"), + } + } + fn create_test_graph() -> crate::graph::EntityGraph { let mut graph = crate::graph::EntityGraph::new(); @@ -164,7 +283,7 @@ mod tests { fn test_query_from_type() { let graph = create_test_graph(); let query = Query::new(EntitySelector::Type(EntityType::new("person"))); - let results = query.execute(&graph).unwrap(); + let results = unwrap_entities(query.execute(&graph).unwrap()); assert_eq!(results.len(), 2); assert!(results.iter().any(|e| e.id == EntityId::new("person1"))); @@ -175,7 +294,7 @@ mod tests { fn test_query_from_all() { let graph = create_test_graph(); let query = Query::new(EntitySelector::All); - let results = query.execute(&graph).unwrap(); + let results = unwrap_entities(query.execute(&graph).unwrap()); assert_eq!(results.len(), 4); } @@ -193,7 +312,7 @@ mod tests { )), ); - let results = query.execute(&graph).unwrap(); + let results = unwrap_entities(query.execute(&graph).unwrap()); assert_eq!(results.len(), 1); assert_eq!(results[0].id, EntityId::new("task2")); } @@ -203,7 +322,7 @@ mod tests { let graph = create_test_graph(); let query = Query::new(EntitySelector::All).with_operation(QueryOperation::Limit(2)); - let results = query.execute(&graph).unwrap(); + let results = unwrap_entities(query.execute(&graph).unwrap()); assert_eq!(results.len(), 2); } @@ -222,7 +341,7 @@ mod tests { )) .with_operation(QueryOperation::Limit(1)); - let results = query.execute(&graph).unwrap(); + let results = unwrap_entities(query.execute(&graph).unwrap()); assert_eq!(results.len(), 1); } @@ -246,4 +365,48 @@ mod tests { assert!(available.contains(&"person".to_string())); } } + + // --- Aggregation integration tests --- + + fn unwrap_aggregation(result: QueryResult) -> AggregationResult { + match result { + 
QueryResult::Aggregation(agg) => agg, + QueryResult::Entities(_) => panic!("Expected aggregation, got entities"), + } + } + + #[test] + fn test_query_with_count_aggregation() { + let graph = create_test_graph(); + let query = Query::new(EntitySelector::Type(EntityType::new("person"))) + .with_aggregation(Aggregation::Count(None)); + let result = unwrap_aggregation(query.execute(&graph).unwrap()); + assert_eq!(result, AggregationResult::Count(2)); + } + + #[test] + fn test_query_with_aggregation_after_where() { + let graph = create_test_graph(); + let query = Query::new(EntitySelector::Type(EntityType::new("task"))) + .with_operation(QueryOperation::Where( + super::super::CompoundFilterCondition::single( + super::super::FilterCondition::new( + super::super::FieldRef::Regular(FieldId::new("is_completed")), + super::super::FilterOperator::Equal, + super::super::FilterValue::Boolean(false), + ), + ), + )) + .with_aggregation(Aggregation::Count(None)); + let result = unwrap_aggregation(query.execute(&graph).unwrap()); + assert_eq!(result, AggregationResult::Count(1)); + } + + #[test] + fn test_query_without_aggregation_returns_entities() { + let graph = create_test_graph(); + let query = Query::new(EntitySelector::Type(EntityType::new("person"))); + let result = query.execute(&graph).unwrap(); + assert!(matches!(result, QueryResult::Entities(_))); + } } diff --git a/firm_lang/src/convert/to_query.rs b/firm_lang/src/convert/to_query.rs index 1c63b7c..5a3815b 100644 --- a/firm_lang/src/convert/to_query.rs +++ b/firm_lang/src/convert/to_query.rs @@ -1,8 +1,8 @@ //! 
Conversion from ParsedQuery to executable Query use firm_core::graph::{ - Combinator, CompoundFilterCondition, EntitySelector, FieldRef, FilterCondition, FilterOperator, - FilterValue, MetadataField, Query, QueryOperation, SortDirection, + Aggregation, Combinator, CompoundFilterCondition, EntitySelector, FieldRef, FilterCondition, + FilterOperator, FilterValue, MetadataField, Query, QueryOperation, SortDirection, }; use firm_core::{EntityType, FieldId}; @@ -49,6 +49,12 @@ impl TryFrom for Query { query = query.with_operation(operation); } + // Convert optional aggregation + if let Some(parsed_agg) = parsed.aggregation { + let aggregation = convert_aggregation(parsed_agg)?; + query = query.with_aggregation(aggregation); + } + Ok(query) } } @@ -118,6 +124,23 @@ fn convert_related( }) } +fn convert_aggregation( + parsed: ParsedAggregation, +) -> Result { + match parsed { + ParsedAggregation::Select(fields) => { + let field_refs: Vec = fields.into_iter().map(convert_field).collect(); + Ok(Aggregation::Select(field_refs)) + } + ParsedAggregation::Count(field) => { + Ok(Aggregation::Count(field.map(convert_field))) + } + ParsedAggregation::Sum(field) => Ok(Aggregation::Sum(convert_field(field))), + ParsedAggregation::Average(field) => Ok(Aggregation::Average(convert_field(field))), + ParsedAggregation::Median(field) => Ok(Aggregation::Median(convert_field(field))), + } +} + fn convert_field(parsed: ParsedField) -> FieldRef { match parsed { ParsedField::Metadata(name) => { diff --git a/firm_lang/src/parser/query/grammar.pest b/firm_lang/src/parser/query/grammar.pest index 9af1559..2676ded 100644 --- a/firm_lang/src/parser/query/grammar.pest +++ b/firm_lang/src/parser/query/grammar.pest @@ -1,7 +1,7 @@ WHITESPACE = _{ " " | "\t" | "\n" } -// Top-level query: "from | where ... | order ... | limit ..." -query = { SOI ~ from_clause ~ ("|" ~ operation)* ~ EOI } +// Top-level query: "from | where ... | order ... | limit ... 
| count" +query = { SOI ~ from_clause ~ ("|" ~ operation)* ~ ("|" ~ aggregation)? ~ EOI } // FROM clause: "from task" or "from *" from_clause = { "from" ~ entity_selector } @@ -105,3 +105,22 @@ enum_value = { "enum" ~ string } list = { "[" ~ value ~ ("," ~ value)* ~ "]" } identifier = @{ (ASCII_ALPHA | "_") ~ (ASCII_ALPHANUMERIC | "_")* } + +// Aggregation clauses (terminal — must be the last clause in a query) +aggregation = { + select_clause + | count_clause + | sum_clause + | average_clause + | median_clause +} + +select_clause = { "select" ~ select_field ~ ("," ~ select_field)* } +select_field = { metadata_field | field_name } + +count_clause = { "count" ~ (metadata_field | field_name)? } +sum_clause = { "sum" ~ aggregation_field } +average_clause = { "average" ~ aggregation_field } +median_clause = { "median" ~ aggregation_field } + +aggregation_field = { metadata_field | field_name } diff --git a/firm_lang/src/parser/query/parsed_query.rs b/firm_lang/src/parser/query/parsed_query.rs index 9e8cb34..476dd88 100644 --- a/firm_lang/src/parser/query/parsed_query.rs +++ b/firm_lang/src/parser/query/parsed_query.rs @@ -7,6 +7,7 @@ use std::fmt; pub struct ParsedQuery { pub from: ParsedFromClause, pub operations: Vec<ParsedOperation>, + pub aggregation: Option<ParsedAggregation>, } /// The FROM clause specifies the starting entity type(s) @@ -37,6 +38,21 @@ pub enum ParsedOperation { Limit(usize), } +/// Terminal aggregation clause +#[derive(Debug, Clone, PartialEq)] +pub enum ParsedAggregation { + /// Select specific fields: select @id, name, status + Select(Vec<ParsedField>), + /// Count entities: count (all) or count field_name (entities with field) + Count(Option<ParsedField>), + /// Sum a numeric field: sum amount + Sum(ParsedField), + /// Average a numeric field: average age + Average(ParsedField), + /// Median of a numeric field: median salary + Median(ParsedField), +} + /// A compound condition combining multiple conditions with AND/OR #[derive(Debug, Clone, PartialEq)] pub struct ParsedCompoundCondition { diff --git 
a/firm_lang/src/parser/query/parser.rs b/firm_lang/src/parser/query/parser.rs index d01e303..c4e083c 100644 --- a/firm_lang/src/parser/query/parser.rs +++ b/firm_lang/src/parser/query/parser.rs @@ -34,6 +34,7 @@ pub fn parse_query(input: &str) -> Result { let mut from_clause = None; let mut operations = Vec::new(); + let mut aggregation = None; for pair in pairs { if pair.as_rule() == Rule::query { @@ -45,6 +46,9 @@ pub fn parse_query(input: &str) -> Result { Rule::operation => { operations.push(parse_operation(inner_pair)?); } + Rule::aggregation => { + aggregation = Some(parse_aggregation(inner_pair)?); + } Rule::EOI => {} _ => {} } @@ -56,7 +60,11 @@ pub fn parse_query(input: &str) -> Result { QueryParseError::SyntaxError("Query must start with 'from' clause".to_string()) })?; - Ok(ParsedQuery { from, operations }) + Ok(ParsedQuery { + from, + operations, + aggregation, + }) } fn parse_from_clause( @@ -417,3 +425,124 @@ fn parse_limit_clause( "Invalid limit clause".to_string(), )) } + +// --- Aggregation parsing --- + +fn parse_aggregation( + pair: pest::iterators::Pair, +) -> Result { + let inner_pair = pair + .into_inner() + .next() + .ok_or_else(|| QueryParseError::SyntaxError("Empty aggregation".to_string()))?; + + match inner_pair.as_rule() { + Rule::select_clause => parse_select_clause(inner_pair), + Rule::count_clause => parse_count_clause(inner_pair), + Rule::sum_clause => parse_sum_clause(inner_pair), + Rule::average_clause => parse_average_clause(inner_pair), + Rule::median_clause => parse_median_clause(inner_pair), + _ => Err(QueryParseError::SyntaxError(format!( + "Unknown aggregation: {:?}", + inner_pair.as_rule() + ))), + } +} + +fn parse_select_clause( + pair: pest::iterators::Pair, +) -> Result { + let mut fields = Vec::new(); + for inner_pair in pair.into_inner() { + if inner_pair.as_rule() == Rule::select_field { + fields.push(parse_field_ref(inner_pair)?); + } + } + if fields.is_empty() { + return Err(QueryParseError::SyntaxError( + "Select 
requires at least one field".to_string(), + )); + } + Ok(ParsedAggregation::Select(fields)) +} + +fn parse_count_clause( + pair: pest::iterators::Pair, +) -> Result { + let field = pair + .into_inner() + .next() + .map(|p| parse_field_from_rule(p)) + .transpose()?; + Ok(ParsedAggregation::Count(field)) +} + +fn parse_sum_clause( + pair: pest::iterators::Pair, +) -> Result { + let field = parse_aggregation_field(pair)?; + Ok(ParsedAggregation::Sum(field)) +} + +fn parse_average_clause( + pair: pest::iterators::Pair, +) -> Result { + let field = parse_aggregation_field(pair)?; + Ok(ParsedAggregation::Average(field)) +} + +fn parse_median_clause( + pair: pest::iterators::Pair, +) -> Result { + let field = parse_aggregation_field(pair)?; + Ok(ParsedAggregation::Median(field)) +} + +fn parse_aggregation_field( + pair: pest::iterators::Pair, +) -> Result { + let field_pair = pair + .into_inner() + .find(|p| p.as_rule() == Rule::aggregation_field) + .ok_or_else(|| { + QueryParseError::SyntaxError("Missing field in aggregation".to_string()) + })?; + + let inner = field_pair.into_inner().next().ok_or_else(|| { + QueryParseError::SyntaxError("Invalid aggregation field".to_string()) + })?; + + parse_field_from_rule(inner) +} + +/// Parse a field reference from a select_field or aggregation_field wrapper rule. +fn parse_field_ref(pair: pest::iterators::Pair) -> Result { + let inner = pair.into_inner().next().ok_or_else(|| { + QueryParseError::SyntaxError("Invalid field reference".to_string()) + })?; + parse_field_from_rule(inner) +} + +/// Parse a metadata_field or field_name rule into a ParsedField. +fn parse_field_from_rule( + pair: pest::iterators::Pair, +) -> Result { + match pair.as_rule() { + Rule::metadata_field => { + let name = pair + .into_inner() + .next() + .ok_or_else(|| { + QueryParseError::SyntaxError("Invalid metadata field".to_string()) + })? 
+ .as_str() + .to_string(); + Ok(ParsedField::Metadata(name)) + } + Rule::field_name => Ok(ParsedField::Regular(pair.as_str().to_string())), + _ => Err(QueryParseError::SyntaxError(format!( + "Expected field, got {:?}", + pair.as_rule() + ))), + } +} diff --git a/firm_lang/tests/parser_query_tests.rs b/firm_lang/tests/parser_query_tests.rs index 5a40e87..9328540 100644 --- a/firm_lang/tests/parser_query_tests.rs +++ b/firm_lang/tests/parser_query_tests.rs @@ -1,8 +1,8 @@ //! Tests for query language parsing use firm_lang::parser::query::{ - ParsedCombinator, ParsedDirection, ParsedEntitySelector, ParsedField, ParsedOperation, - ParsedQueryValue, parse_query, + ParsedAggregation, ParsedCombinator, ParsedDirection, ParsedEntitySelector, ParsedField, + ParsedOperation, ParsedQueryValue, parse_query, }; #[test] @@ -213,3 +213,97 @@ fn test_parse_compound_condition_mixed_error() { let result = parse_query(query_str); assert!(result.is_err()); } + +// --- Aggregation parsing tests --- + +#[test] +fn test_parse_count_no_field() { + let query = parse_query("from task | count").unwrap(); + assert_eq!(query.aggregation, Some(ParsedAggregation::Count(None))); +} + +#[test] +fn test_parse_count_with_field() { + let query = parse_query("from task | count assignee").unwrap(); + assert_eq!( + query.aggregation, + Some(ParsedAggregation::Count(Some(ParsedField::Regular( + "assignee".to_string() + )))) + ); +} + +#[test] +fn test_parse_sum() { + let query = parse_query("from invoice | sum amount").unwrap(); + assert_eq!( + query.aggregation, + Some(ParsedAggregation::Sum(ParsedField::Regular( + "amount".to_string() + ))) + ); +} + +#[test] +fn test_parse_average() { + let query = parse_query("from employee | average age").unwrap(); + assert_eq!( + query.aggregation, + Some(ParsedAggregation::Average(ParsedField::Regular( + "age".to_string() + ))) + ); +} + +#[test] +fn test_parse_median() { + let query = parse_query("from employee | median salary").unwrap(); + assert_eq!( + 
query.aggregation, + Some(ParsedAggregation::Median(ParsedField::Regular( + "salary".to_string() + ))) + ); +} + +#[test] +fn test_parse_select_single_field() { + let query = parse_query("from project | select name").unwrap(); + assert_eq!( + query.aggregation, + Some(ParsedAggregation::Select(vec![ParsedField::Regular( + "name".to_string() + )])) + ); +} + +#[test] +fn test_parse_select_multiple_fields() { + let query = parse_query("from task | select @id, name, priority").unwrap(); + assert_eq!( + query.aggregation, + Some(ParsedAggregation::Select(vec![ + ParsedField::Metadata("id".to_string()), + ParsedField::Regular("name".to_string()), + ParsedField::Regular("priority".to_string()), + ])) + ); +} + +#[test] +fn test_parse_aggregation_after_operations() { + let query = parse_query("from task | where is_completed == false | count").unwrap(); + assert_eq!(query.operations.len(), 1); + assert!(matches!( + query.operations[0], + ParsedOperation::Where(_) + )); + assert_eq!(query.aggregation, Some(ParsedAggregation::Count(None))); +} + +#[test] +fn test_parse_query_without_aggregation_unchanged() { + let query = parse_query("from task | limit 5").unwrap(); + assert_eq!(query.aggregation, None); + assert_eq!(query.operations.len(), 1); +} diff --git a/firm_mcp/src/server.rs b/firm_mcp/src/server.rs index b2b8ecf..a551c87 100644 --- a/firm_mcp/src/server.rs +++ b/firm_mcp/src/server.rs @@ -133,8 +133,14 @@ impl FirmMcpServer { } #[tool( - description = "Query entities using the Firm query language. Returns full details for all matching entities. \ - Examples: 'from person', 'from task | where is_completed == false', 'from person | where name contains \"John\" | limit 5'. \ + description = "Query entities using the Firm query language. Returns full details for all matching entities, \ + or an aggregated result when an aggregation clause is used. 
\ + Examples: 'from person', 'from task | where is_completed == false', \ + 'from task | where is_completed == false and priority > 5', \ + 'from invoice | where status == \"draft\" or status == \"sent\"', \ + 'from person | where name contains \"John\" | limit 5', \ + 'from task | count', 'from invoice | where status == \"sent\" | sum amount', \ + 'from task | where is_completed == false | select @id, name, due_date'. \ Use 'list' for a simple ID overview, or 'get' for a single entity's details." )] async fn query( @@ -363,7 +369,7 @@ impl FirmMcpServer { #[tool( description = "Get reference documentation for the Firm DSL syntax and query language. \ Use 'topic' parameter: 'dsl' for DSL syntax (entities, schemas, field types), \ - 'query' for query language (from, where, related, order, limit), \ + 'query' for query language (from, where, related, order, limit, aggregations), \ or 'all' for both (default). \ Call this before writing or modifying .firm files to understand the correct syntax." )] diff --git a/firm_mcp/src/tools/dsl_reference_content.rs b/firm_mcp/src/tools/dsl_reference_content.rs index c1c0266..f4c28b1 100644 --- a/firm_mcp/src/tools/dsl_reference_content.rs +++ b/firm_mcp/src/tools/dsl_reference_content.rs @@ -231,6 +231,49 @@ from task | limit 10 from task | where priority > 8 | order priority desc | limit 5 ``` +## Aggregations + +An optional final clause that summarizes the result set instead of returning entities. + +### select - Extract field values + +```bash +from person | select name +from task | where is_completed == false | select @id, name, due_date +``` + +### count - Count entities + +```bash +from task | where is_completed == false | count +from person | count email +``` + +Without a field, counts all entities. With a field, counts entities that have that field. 
+ +### sum - Sum numeric field + +```bash +from line_item | sum quantity +from invoice | where status == "sent" | sum amount +``` + +Works with integer, float, and currency fields. Mixed currencies produce an error. + +### average - Mean of numeric field + +```bash +from task | average estimated_hours +``` + +### median - Median of numeric field + +```bash +from task | median estimated_hours +``` + +For all numeric aggregations, entities missing the field are skipped. + ## Example Queries ```bash @@ -248,6 +291,15 @@ from project | where status == "active" | related task # Complex multi-hop query from organization | where industry == "tech" | related(2) task | where is_completed == false | order due_date | limit 10 + +# Count incomplete tasks +from task | where is_completed == false | count + +# Total invoice amount +from invoice | where status == "sent" | sum amount + +# Extract specific fields +from task | where is_completed == false | select @id, name, due_date ``` ## Query Execution @@ -255,10 +307,10 @@ from organization | where industry == "tech" | related(2) task | where is_comple Queries execute left to right, each operation transforming the result set: ``` -from task → [all tasks] -| where status → [filtered tasks] -| related project → [related projects] -| order name → [sorted projects] -| limit 5 → [top 5 projects] +from task → [all tasks] +| where is_completed == false → [filtered tasks] +| related project → [related projects] +| order name → [sorted projects] +| limit 5 → [top 5 projects] ``` "#; diff --git a/firm_mcp/src/tools/query.rs b/firm_mcp/src/tools/query.rs index 758b846..e6a5f3e 100644 --- a/firm_mcp/src/tools/query.rs +++ b/firm_mcp/src/tools/query.rs @@ -1,6 +1,6 @@ //! Query tool implementation. 
-use firm_core::graph::{EntityGraph, Query}; +use firm_core::graph::{EntityGraph, Query, QueryResult}; use firm_lang::parser::query::parse_query; use rmcp::model::{CallToolResult, Content}; use rmcp::schemars; @@ -42,7 +42,7 @@ pub fn execute(graph: &EntityGraph, params: &QueryParams) -> CallToolResult { }; // Execute the query - let results = match query.execute(graph) { + let result = match query.execute(graph) { Ok(r) => r, Err(e) => { return CallToolResult::error(vec![Content::text(format!( @@ -53,12 +53,18 @@ pub fn execute(graph: &EntityGraph, params: &QueryParams) -> CallToolResult { }; // Format results - if results.is_empty() { - return CallToolResult::success(vec![Content::text( - "No entities found matching the query.", - )]); + match result { + QueryResult::Entities(entities) => { + if entities.is_empty() { + return CallToolResult::success(vec![Content::text( + "No entities found matching the query.", + )]); + } + let output: Vec<String> = entities.iter().map(|e| e.to_string()).collect(); + CallToolResult::success(vec![Content::text(output.join("\n---\n"))]) + } + QueryResult::Aggregation(agg_result) => { + CallToolResult::success(vec![Content::text(agg_result.to_string())]) + } } - - let output: Vec<String> = results.iter().map(|e| e.to_string()).collect(); - CallToolResult::success(vec![Content::text(output.join("\n---\n"))]) }
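
The result handling in `firm_mcp/src/tools/query.rs` above boils down to one `match` on the new `QueryResult` enum. The following is a minimal, self-contained sketch of that pattern, not the code from this change: the two enums are simplified stand-ins (entities are pre-rendered strings, and only `Count` and `Average` are modelled), and `render` is a hypothetical helper name introduced for illustration.

```rust
// Simplified stand-ins for the types added in this diff (assumption: the real
// `QueryResult` holds `&Entity` references and `AggregationResult` has more
// variants, such as `Select`, `Sum`, and `Median`).
#[derive(Debug, Clone, PartialEq)]
enum AggregationResult {
    Count(usize),
    Average(f64),
}

#[derive(Debug, Clone, PartialEq)]
enum QueryResult {
    /// Entities already rendered to strings (hypothetical simplification).
    Entities(Vec<String>),
    /// A single summary value produced by a terminal aggregation.
    Aggregation(AggregationResult),
}

/// Hypothetical helper: turn a query result into display text.
/// Entity sets are joined with a separator line; aggregations print one value.
fn render(result: &QueryResult) -> String {
    match result {
        // An empty entity set gets a fixed message instead of empty output.
        QueryResult::Entities(entities) if entities.is_empty() => {
            "No entities found matching the query.".to_string()
        }
        QueryResult::Entities(entities) => entities.join("\n---\n"),
        QueryResult::Aggregation(AggregationResult::Count(n)) => n.to_string(),
        QueryResult::Aggregation(AggregationResult::Average(v)) => v.to_string(),
    }
}
```

Because `execute` now returns an enum rather than a plain entity list, every caller has to decide what to do with an aggregation, and the compiler flags any call site that still assumes entities.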