From a6e265f6a7f0710a0eeda0387091a119f1c4f341 Mon Sep 17 00:00:00 2001 From: "Z.-L. Deng" Date: Tue, 7 Oct 2025 23:34:28 +0200 Subject: [PATCH 1/9] Improve numeric comparisons and diagnostics --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/common.rs | 15 +++++++++++++ src/expression.rs | 35 ++++++++++++++++++++++++++--- src/join.rs | 57 ++++++++++++++++++++++++++++++++++++++++------- src/melt.rs | 23 +++++++++++++++---- src/mutate.rs | 15 +++++-------- src/pivot.rs | 23 +++++++++++++++---- src/pretty.rs | 15 ++++++++++--- src/sort.rs | 23 +++++++++++++++---- 10 files changed, 172 insertions(+), 38 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8fe21a0..5f4ca84 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -603,7 +603,7 @@ dependencies = [ [[package]] name = "tsvkit" -version = "0.9.2" +version = "0.9.4" dependencies = [ "anyhow", "calamine", diff --git a/Cargo.toml b/Cargo.toml index baf0318..8e88aed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tsvkit" -version = "0.9.3" +version = "0.9.4" edition = "2024" [dependencies] diff --git a/src/common.rs b/src/common.rs index 06bc5c1..09ee1f3 100644 --- a/src/common.rs +++ b/src/common.rs @@ -207,6 +207,21 @@ pub fn should_skip_record( false } +pub fn inconsistent_width_error( + source: &str, + row_number: usize, + expected: usize, + actual: usize, +) -> anyhow::Error { + anyhow!( + "rows in {} have inconsistent column counts at row {} (expected {}, got {})", + source, + row_number, + expected, + actual, + ) +} + pub fn default_headers(len: usize) -> Vec { (1..=len).map(|i| format!("col{}", i)).collect() } diff --git a/src/expression.rs b/src/expression.rs index 6f283de..34cacfe 100644 --- a/src/expression.rs +++ b/src/expression.rs @@ -408,8 +408,20 @@ where let right = eval_value(rhs, row); match op { - CompareOp::Eq => left.text == right.text, - CompareOp::Ne => left.text != right.text, + CompareOp::Eq => { + if let (Some(a), Some(b)) = (left.numeric, right.numeric) { + a == b + } else { + left.text == right.text + } + } + CompareOp::Ne => { + if let (Some(a), Some(b)) = (left.numeric, right.numeric) { + a != b + } else { + left.text != right.text + } + } CompareOp::Gt => compare_numeric(&left, &right, |a, b| a > b), CompareOp::Ge => compare_numeric(&left, &right, |a, b| a >= b), CompareOp::Lt => compare_numeric(&left, &right, |a, b| a < b), @@ -636,6 +648,24 @@ mod tests { assert!(evaluate(&bound, &row)); } + #[test] + fn numeric_equality_uses_numeric_comparison() { + let expr = parse_expression("$1 == 0").unwrap(); + let headers = vec!["value".to_string()]; + let bound = bind_expression(expr, &headers, false).unwrap(); + let record = csv::StringRecord::from(vec!["0.0"]); + assert!(evaluate(&bound, &record)); + } + + #[test] + fn string_equality_remains_textual_when_not_numeric() { + let expr = parse_expression("$1 == 0").unwrap(); + let headers = vec!["value".to_string()]; + let bound = bind_expression(expr, &headers, false).unwrap(); + let record = csv::StringRecord::from(vec!["00a"]); + assert!(!evaluate(&bound, &record)); + } + #[test] fn subtraction_without_spaces_between_columns() { let expr = parse_expression("($dna_ug-$rna_ug)>10").unwrap(); @@ -1326,4 +1356,3 @@ enum TokenKind { Slash, Caret, } - diff --git a/src/join.rs b/src/join.rs index dcd9271..100c8c8 100644 --- a/src/join.rs +++ b/src/join.rs @@ -8,8 +8,9 @@ use num_cpus; use rayon::{ThreadPoolBuilder, prelude::*}; use crate::common::{ - ColumnSelector, InputOptions, default_headers, parse_multi_selector_spec, parse_selector_list, - reader_for_path, resolve_selectors, should_skip_record, + ColumnSelector, InputOptions, default_headers, inconsistent_width_error, + parse_multi_selector_spec, parse_selector_list, reader_for_path, resolve_selectors, + should_skip_record, }; #[derive(Args, Debug)] @@ -97,6 +98,8 @@ struct StreamTable { source: String, projection: Vec, input_opts: InputOptions, + row_number: usize, + header_rows: usize, } #[derive(Clone)] @@ -115,14 +118,16 @@ impl StreamTable { fill_value: &str, ) -> Result { let mut reader = reader_for_path(path, no_header, input_opts)?; - let source = path.display().to_string(); + let source = format!("\"{}\"", path.display()); if no_header { let mut records = reader.records(); let mut first_record: Option = None; let mut source_width = 0usize; + let mut row_number = 0usize; while let Some(rec) = records.next() { let record = rec.with_context(|| format!("failed reading from {}", source))?; + row_number += 1; if should_skip_record( &record, input_opts, @@ -138,7 +143,12 @@ impl StreamTable { if input_opts.ignore_illegal { continue; } else { - bail!("rows in {} have inconsistent column counts", source); + return Err(inconsistent_width_error( + &source, + row_number, + source_width, + record.len(), + )); } } source_width = record.len(); @@ -167,6 +177,8 @@ impl StreamTable { source, projection: projection_plan.projection, input_opts: input_opts.clone(), + row_number, + header_rows: 0, }); } @@ -194,6 +206,8 @@ impl StreamTable { source, projection: projection_plan.projection, input_opts: input_opts.clone(), + row_number: 0, + header_rows: 1, }) } @@ -236,11 +250,17 @@ impl StreamTable { Some(rec) => { let record = rec.with_context(|| format!("failed reading from {}", self.source))?; + self.row_number += 1; if record.len() != self.source_width { if self.input_opts.ignore_illegal { continue; } else { - bail!("rows in {} have inconsistent column counts", self.source); + return Err(inconsistent_width_error( + &self.source, + self.header_rows + self.row_number, + self.source_width, + record.len(), + )); } } if should_skip_record(&record, &self.input_opts, Some(self.source_width)) { @@ -336,13 +356,16 @@ fn load_table( fill_value: &str, ) -> Result { let mut reader = reader_for_path(path, no_header, input_opts)?; + let source_name = format!("\"{}\"", path.display()); let (headers, rows, join_indices, include_indices) = if no_header { let mut records = reader.records(); let mut first_record: Option = None; let mut expected_width: Option = None; + let mut row_number = 0usize; while let Some(rec) = records.next() { let record = rec.with_context(|| format!("failed reading from {}", path.display()))?; + row_number += 1; if should_skip_record(&record, input_opts, expected_width) { continue; } @@ -351,7 +374,12 @@ fn load_table( if input_opts.ignore_illegal { continue; } else { - bail!("rows in {} have inconsistent column counts", path.display()); + return Err(inconsistent_width_error( + &source_name, + row_number, + width, + record.len(), + )); } } } @@ -375,11 +403,17 @@ fn load_table( for rec in records { let record = rec.with_context(|| format!("failed reading from {}", path.display()))?; + row_number += 1; if record.len() != expected_width { if input_opts.ignore_illegal { continue; } else { - bail!("rows in {} have inconsistent column counts", path.display()); + return Err(inconsistent_width_error( + &source_name, + row_number, + expected_width, + record.len(), + )); } } if should_skip_record(&record, input_opts, Some(expected_width)) { @@ -408,13 +442,20 @@ fn load_table( let projection_plan = build_projection(&headers_orig, &join_indices_orig, &include_indices_orig); let mut rows = Vec::new(); + let mut row_number = 0usize; for rec in reader.records() { let record = rec.with_context(|| format!("failed reading from {}", path.display()))?; + row_number += 1; if record.len() != expected_width { if input_opts.ignore_illegal { continue; } else { - bail!("rows in {} have inconsistent column counts", path.display()); + return Err(inconsistent_width_error( + &source_name, + row_number + 1, + expected_width, + record.len(), + )); } } if should_skip_record(&record, input_opts, Some(expected_width)) { diff --git a/src/melt.rs b/src/melt.rs index 734a386..1421025 100644 --- a/src/melt.rs +++ b/src/melt.rs @@ -6,8 +6,8 @@ use clap::Args; use indexmap::IndexSet; use crate::common::{ - InputOptions, default_headers, parse_selector_list, reader_for_path, resolve_selectors, - should_skip_record, + InputOptions, default_headers, inconsistent_width_error, parse_selector_list, reader_for_path, + resolve_selectors, should_skip_record, }; #[derive(Args, Debug)] @@ -69,20 +69,28 @@ pub fn run(args: MeltArgs) -> Result<()> { args.ignore_illegal_row, )?; let mut reader = reader_for_path(&args.file, args.no_header, &input_opts)?; + let source_name = format!("\"{}\"", args.file.display()); let mut writer = BufWriter::new(io::stdout().lock()); let fill_value = args.fill.clone().unwrap_or_else(String::new); if args.no_header { let mut rows = Vec::new(); let mut expected_width: Option = None; + let mut row_number = 0usize; for record in reader.records() { let record = record.with_context(|| format!("failed reading from {:?}", args.file))?; + row_number += 1; if let Some(width) = expected_width { if record.len() != width { if input_opts.ignore_illegal { continue; } else { - bail!("rows in {:?} have inconsistent column counts", args.file); + return Err(inconsistent_width_error( + &source_name, + row_number, + width, + record.len(), + )); } } } @@ -108,13 +116,20 @@ pub fn run(args: MeltArgs) -> Result<()> { .collect::>(); let mut rows = Vec::new(); + let mut row_number = 0usize; for record in reader.records() { let record = record.with_context(|| format!("failed reading from {:?}", args.file))?; + row_number += 1; if record.len() != headers.len() { if input_opts.ignore_illegal { continue; } else { - bail!("rows in {:?} have inconsistent column counts", args.file); + return Err(inconsistent_width_error( + &source_name, + row_number + 1, + headers.len(), + record.len(), + )); } } if should_skip_record(&record, &input_opts, Some(headers.len())) { diff --git a/src/mutate.rs b/src/mutate.rs index 745cf48..fb23305 100644 --- a/src/mutate.rs +++ b/src/mutate.rs @@ -371,10 +371,10 @@ fn parse_substitution_expression( let content = content .strip_suffix('/') .with_context(|| "substitution expression must end with '/'")?; - let (selector_part, pattern_part, replacement_part) = - split_substitution_components(content).with_context(|| { - "substitution expression must use s/selectors/pattern/replacement/ syntax" - })?; + let (selector_part, pattern_part, replacement_part) = split_substitution_components(content) + .with_context( + || "substitution expression must use s/selectors/pattern/replacement/ syntax", + )?; let selectors = parse_selector_list(&normalize_selector_spec(selector_part.trim()))?; if selectors.is_empty() { @@ -630,12 +630,7 @@ mod tests { #[test] fn substitution_replacement_supports_escape_sequences() { let headers = vec!["col1".to_string()]; - let ops = parse_operations( - &vec!["s/$col1/\\t/ /".to_string()], - &headers, - false, - ) - .unwrap(); + let ops = parse_operations(&vec!["s/$col1/\\t/ /".to_string()], &headers, false).unwrap(); let mut row = vec!["field\tvalue".to_string()]; process_row(&mut row, &ops).unwrap(); diff --git a/src/pivot.rs b/src/pivot.rs index dddecad..58d074c 100644 --- a/src/pivot.rs +++ b/src/pivot.rs @@ -6,8 +6,8 @@ use clap::Args; use indexmap::{IndexMap, IndexSet}; use crate::common::{ - InputOptions, default_headers, parse_selector_list, parse_single_selector, reader_for_path, - resolve_selectors, should_skip_record, + InputOptions, default_headers, inconsistent_width_error, parse_selector_list, + parse_single_selector, reader_for_path, resolve_selectors, should_skip_record, }; #[derive(Args, Debug)] @@ -76,19 +76,27 @@ pub fn run(args: PivotArgs) -> Result<()> { )?; let mut reader = reader_for_path(&args.file, args.no_header, &input_opts)?; + let source_name = format!("\"{}\"", args.file.display()); let mut writer = BufWriter::new(io::stdout().lock()); let headers = if args.no_header { let mut all_rows = Vec::new(); let mut expected_width: Option = None; + let mut row_number = 0usize; for record in reader.records() { let record = record.with_context(|| format!("failed reading from {:?}", args.file))?; + row_number += 1; if let Some(width) = expected_width { if record.len() != width { if input_opts.ignore_illegal { continue; } else { - bail!("rows in {:?} have inconsistent column counts", args.file); + return Err(inconsistent_width_error( + &source_name, + row_number, + width, + record.len(), + )); } } } @@ -122,13 +130,20 @@ pub fn run(args: PivotArgs) -> Result<()> { }; let mut rows = Vec::new(); + let mut row_number = 0usize; for record in reader.records() { let record = record.with_context(|| format!("failed reading from {:?}", args.file))?; + row_number += 1; if record.len() != headers.len() { if input_opts.ignore_illegal { continue; } else { - bail!("rows in {:?} have inconsistent column counts", args.file); + return Err(inconsistent_width_error( + &source_name, + row_number + 1, + headers.len(), + record.len(), + )); } } if should_skip_record(&record, &input_opts, Some(headers.len())) { diff --git a/src/pretty.rs b/src/pretty.rs index 6dbe6de..914c6c6 100644 --- a/src/pretty.rs +++ b/src/pretty.rs @@ -1,11 +1,11 @@ use std::io::{self, BufWriter, Write}; use std::path::PathBuf; -use anyhow::{Context, Result, bail}; +use anyhow::{Context, Result}; use clap::Args; use crate::aggregate::parse_float; -use crate::common::{InputOptions, reader_for_path, should_skip_record}; +use crate::common::{InputOptions, inconsistent_width_error, reader_for_path, should_skip_record}; #[derive(Args, Debug)] #[command( @@ -50,6 +50,7 @@ pub fn run(args: PrettyArgs) -> Result<()> { args.ignore_illegal_row, )?; let mut reader = reader_for_path(&args.file, args.no_header, &input_opts)?; + let source_name = format!("\"{}\"", args.file.display()); let header = if args.no_header { None @@ -66,14 +67,22 @@ pub fn run(args: PrettyArgs) -> Result<()> { let mut rows = Vec::new(); let mut reference_width = header.as_ref().map(|h| h.len()); + let header_rows = header.as_ref().map(|_| 1).unwrap_or(0); + let mut row_number = 0usize; for record in reader.records() { let record = record.with_context(|| format!("failed reading from {:?}", args.file))?; + row_number += 1; if let Some(width) = reference_width { if record.len() != width { if input_opts.ignore_illegal { continue; } else { - bail!("rows in {:?} have inconsistent column counts", args.file); + return Err(inconsistent_width_error( + &source_name, + row_number + header_rows, + width, + record.len(), + )); } } } diff --git a/src/sort.rs b/src/sort.rs index ec4d4d7..aa88263 100644 --- a/src/sort.rs +++ b/src/sort.rs @@ -7,8 +7,8 @@ use clap::Args; use csv::StringRecord; use crate::common::{ - InputOptions, default_headers, parse_single_selector, reader_for_path, resolve_single_selector, - should_skip_record, + InputOptions, default_headers, inconsistent_width_error, parse_single_selector, + reader_for_path, resolve_single_selector, should_skip_record, }; #[derive(Args, Debug)] @@ -91,18 +91,26 @@ pub fn run(args: SortArgs) -> Result<()> { )?; let mut reader = reader_for_path(&args.file, args.no_header, &input_opts)?; let mut writer = BufWriter::new(io::stdout().lock()); + let source_name = format!("\"{}\"", args.file.display()); let mut records: Vec = Vec::new(); let headers = if args.no_header { let mut reference_width: Option = None; + let mut row_number = 0usize; for record in reader.records() { let record = record.with_context(|| format!("failed reading from {:?}", args.file))?; + row_number += 1; if let Some(width) = reference_width { if record.len() != width { if input_opts.ignore_illegal { continue; } else { - bail!("rows in {:?} have inconsistent column counts", args.file); + return Err(inconsistent_width_error( + &source_name, + row_number, + width, + record.len(), + )); } } } @@ -123,13 +131,20 @@ pub fn run(args: SortArgs) -> Result<()> { .iter() .map(|s| s.to_string()) .collect::>(); + let mut row_number = 0usize; for record in reader.records() { let record = record.with_context(|| format!("failed reading from {:?}", args.file))?; + row_number += 1; if record.len() != header_row.len() { if input_opts.ignore_illegal { continue; } else { - bail!("rows in {:?} have inconsistent column counts", args.file); + return Err(inconsistent_width_error( + &source_name, + row_number + 1, + header_row.len(), + record.len(), + )); } } if should_skip_record(&record, &input_opts, Some(header_row.len())) { From d00f7b7d1a8d1ebce3255197f363a26cc6f12211 Mon Sep 17 00:00:00 2001 From: "Z.-L. Deng" Date: Thu, 9 Oct 2025 08:49:14 +0200 Subject: [PATCH 2/9] Add branching and regex helpers to mutate expressions --- Cargo.lock | 2 +- Cargo.toml | 2 +- README.md | 31 ++ src/expression.rs | 727 ++++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 700 insertions(+), 62 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5f4ca84..6d80c91 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -603,7 +603,7 @@ dependencies = [ [[package]] name = "tsvkit" -version = "0.9.4" +version = "0.9.5" dependencies = [ "anyhow", "calamine", diff --git a/Cargo.toml b/Cargo.toml index 8e88aed..9be5e12 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tsvkit" -version = "0.9.4" +version = "0.9.5" edition = "2024" [dependencies] diff --git a/README.md b/README.md index e7c23b8..3be8153 100644 --- a/README.md +++ b/README.md @@ -165,9 +165,27 @@ The same expression language powers `filter -e`, `mutate -e name=EXPR`, and rege | `ln(expr)` | Natural logarithm | | `log(expr)` / `log10(expr)` | Base-10 logarithm | | `log2(expr)` | Base-2 logarithm | +| `len(expr)` | Character count using Unicode code points. | +| `is_na(expr)` | Returns `1` when the expression is blank/`NA`/`NaN`, otherwise `0`. | Functions accept column references (`abs($purity - 1)`), constants, or subexpressions. Empty or non-numeric values yield blanks. +**Conditional and regex helpers** + +- `case_when(condition -> result, ..., _ -> default)` evaluates each boolean condition in order and returns the matching result. The final `_` branch acts as the default. +- `switch(value, [match1, match2] -> result, ..., _ -> default)` compares `value` to one or more literal matches (strings or numbers) and returns the corresponding result. +- `re(value, pattern)` evaluates a regex against `value`, returning `1` or `0`. When the pattern matches, capture groups become available as `$1`, `$2`, etc. for the remainder of the expression (use `$0`-style numeric selectors sparingly when you rely on captures). + +Example: + +``` +case_when( + re($sample, "^ERR(\\d+)$") -> $1, + re($sample, "^SRR") -> "SRA", + _ -> $sample +) +``` + **Row-wise aggregation helpers** Available within `mutate` expressions via functions such as `sum($col1:$col5)`; see the [Mutate](#mutate) section for the full list. @@ -260,6 +278,19 @@ tsvkit mutate \ examples/cytokines.tsv ``` +Use `case_when`, `switch`, and the `re()` helper for richer branching logic and regex capture reuse: + +```bash +tsvkit mutate \ + -e 'label = case_when( + re($sample, "^ERR(\d+)$") -> $1, + re($sample, "^SRR") -> "SRA", + _ -> $sample + )' \ + -e 'bucket = case_when(len($clean) == 0 -> "empty", len($clean) < 5 -> "short", _ -> "long")' \ + examples/samples.tsv +``` + Apply in-place edits with the sed-style form: ```bash diff --git a/src/expression.rs b/src/expression.rs index 34cacfe..e476fb9 100644 --- a/src/expression.rs +++ b/src/expression.rs @@ -28,6 +28,13 @@ pub enum ValueExpr { Binary(BinaryOp, Box, Box), Function(FunctionName, Box), Aggregate(AggregateSpecExpr), + CaseWhen(Vec<(Expr, ValueExpr)>, Option>), + Switch { + target: Box, + branches: Vec<(Vec, ValueExpr)>, + default: Option>, + }, + RegexCall(Box, Box), } #[derive(Debug, Clone)] @@ -59,6 +66,8 @@ pub enum FunctionName { Ln, Log10, Log2, + Len, + IsNa, } impl FunctionName { @@ -71,8 +80,10 @@ impl FunctionName { "ln" => Ok(FunctionName::Ln), "log" | "log10" => Ok(FunctionName::Log10), "log2" => Ok(FunctionName::Log2), + "len" => Ok(FunctionName::Len), + "is_na" => Ok(FunctionName::IsNa), other => bail!( - "unsupported function '{}': try abs, sqrt, exp, exp2, ln, log, log10, log2", + "unsupported function '{}': try abs, sqrt, exp, exp2, ln, log, log10, log2, len, is_na", other ), } @@ -100,6 +111,10 @@ enum Token { Ident(String), LParen, RParen, + LBracket, + RBracket, + Comma, + Arrow, And, Or, Not, @@ -212,6 +227,19 @@ pub enum BoundValue { Binary(BinaryOp, Box, Box), Function(FunctionName, Box), Aggregate(BoundAggregate), + CaseWhen { + branches: Vec<(BoundExpr, BoundValue)>, + default: Option>, + }, + Switch { + target: Box, + branches: Vec<(Vec, BoundValue)>, + default: Option>, + }, + RegexCall { + value: Box, + pattern: RegexPattern, + }, } #[derive(Debug, Clone)] @@ -220,7 +248,7 @@ pub struct BoundAggregate { pub columns: Vec, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum BoundExpr { Or(Box, Box), And(Box, Box), @@ -234,7 +262,7 @@ pub enum BoundExpr { Value(BoundValue), } -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) enum RegexPattern { Static(Arc), Dynamic(Box), @@ -245,31 +273,105 @@ pub struct EvalValue<'a> { pub numeric: Option, } +pub struct EvalContext<'a, R> +where + R: RowAccessor + ?Sized, +{ + row: &'a R, + regex_captures: Option>, +} + +impl<'a, R> EvalContext<'a, R> +where + R: RowAccessor + ?Sized, +{ + pub fn new(row: &'a R) -> Self { + EvalContext { + row, + regex_captures: None, + } + } + + fn row(&self) -> &'a R { + self.row + } + + fn clear_captures(&mut self) { + self.regex_captures = None; + } + + fn set_captures(&mut self, captures: Vec) { + self.regex_captures = Some(captures); + } + + fn take_captures(&mut self) -> Option> { + self.regex_captures.take() + } + + fn restore_captures(&mut self, captures: Option>) { + self.regex_captures = captures; + } +} + pub fn evaluate(expr: &BoundExpr, row: &R) -> bool +where + R: RowAccessor + ?Sized, +{ + let mut ctx = EvalContext::new(row); + evaluate_with_context(expr, &mut ctx) +} + +fn evaluate_with_context<'a, R>(expr: &'a BoundExpr, ctx: &mut EvalContext<'a, R>) -> bool where R: RowAccessor + ?Sized, { match expr { - BoundExpr::Or(lhs, rhs) => evaluate(lhs, row) || evaluate(rhs, row), - BoundExpr::And(lhs, rhs) => evaluate(lhs, row) && evaluate(rhs, row), - BoundExpr::Not(inner) => !evaluate(inner, row), - BoundExpr::Compare(lhs, op, rhs) => evaluate_compare(lhs, *op, rhs, row), + BoundExpr::Or(lhs, rhs) => { + evaluate_with_context(lhs, ctx) || evaluate_with_context(rhs, ctx) + } + BoundExpr::And(lhs, rhs) => { + evaluate_with_context(lhs, ctx) && evaluate_with_context(rhs, ctx) + } + BoundExpr::Not(inner) => !evaluate_with_context(inner, ctx), + BoundExpr::Compare(lhs, op, rhs) => evaluate_compare(lhs, *op, rhs, ctx), BoundExpr::RegexMatch { value, pattern, invert, - } => evaluate_regex(value, pattern, *invert, row), - BoundExpr::Value(value) => evaluate_truthy(value, row), + } => evaluate_regex(value, pattern, *invert, ctx), + BoundExpr::Value(value) => evaluate_truthy_with_context(value, ctx), } } pub fn eval_value<'a, R>(value: &'a BoundValue, row: &'a R) -> EvalValue<'a> +where + R: RowAccessor + ?Sized, +{ + let mut ctx = EvalContext::new(row); + eval_value_with_context(value, &mut ctx) +} + +pub fn eval_value_with_context<'a, R>( + value: &'a BoundValue, + ctx: &mut EvalContext<'a, R>, +) -> EvalValue<'a> where R: RowAccessor + ?Sized, { match value { BoundValue::Column(idx) => { - let text = row.get(*idx).unwrap_or(""); + if let Some(captures) = ctx.regex_captures.as_ref() { + let capture_idx = idx + 1; + if capture_idx < captures.len() { + let capture = captures[capture_idx].clone(); + let numeric = parse_float(&capture); + return EvalValue { + text: Cow::Owned(capture), + numeric, + }; + } + } + let text = ctx.row().get(*idx).unwrap_or(""); EvalValue { text: Cow::Borrowed(text), numeric: parse_float(text), @@ -282,7 +384,7 @@ where let mut combined = String::new(); let mut numeric = None; for (pos, idx) in indices.iter().enumerate() { - let text = row.get(*idx).unwrap_or(""); + let text = ctx.row().get(*idx).unwrap_or(""); if pos > 0 { combined.push('\t'); } @@ -302,7 +404,7 @@ where }, BoundValue::Number(number) => numeric_eval(*number), BoundValue::Unary(op, inner) => { - let inner_eval = eval_value(inner, row); + let inner_eval = eval_value_with_context(inner, ctx); if let Some(val) = inner_eval.numeric { numeric_eval(match op { UnaryOp::Neg => -val, @@ -312,8 +414,8 @@ where } } BoundValue::Binary(op, left, right) => { - let left_eval = eval_value(left, row); - let right_eval = eval_value(right, row); + let left_eval = eval_value_with_context(left, ctx); + let right_eval = eval_value_with_context(right, ctx); match (left_eval.numeric, right_eval.numeric) { (Some(a), Some(b)) => match op { BinaryOp::Add => numeric_eval(a + b), @@ -339,36 +441,58 @@ where } } BoundValue::Function(func, inner) => { - let inner_eval = eval_value(inner, row); - if let Some(val) = inner_eval.numeric { - let result = match func { - FunctionName::Abs => Some(val.abs()), - FunctionName::Sqrt => { - let value = val.sqrt(); - value.is_finite().then_some(value) - } - FunctionName::Exp => { - let value = val.exp(); - value.is_finite().then_some(value) - } - FunctionName::Exp2 => { - let value = val.exp2(); - value.is_finite().then_some(value) + let inner_eval = eval_value_with_context(inner, ctx); + match func { + FunctionName::Abs + | FunctionName::Sqrt + | FunctionName::Exp + | FunctionName::Exp2 + | FunctionName::Ln + | FunctionName::Log10 + | FunctionName::Log2 => { + if let Some(val) = inner_eval.numeric { + let result = match func { + FunctionName::Abs => Some(val.abs()), + FunctionName::Sqrt => { + let value = val.sqrt(); + value.is_finite().then_some(value) + } + FunctionName::Exp => { + let value = val.exp(); + value.is_finite().then_some(value) + } + FunctionName::Exp2 => { + let value = val.exp2(); + value.is_finite().then_some(value) + } + FunctionName::Ln => (val > 0.0).then(|| val.ln()), + FunctionName::Log10 => (val > 0.0).then(|| val.log10()), + FunctionName::Log2 => (val > 0.0).then(|| val.log2()), + _ => None, + }; + result.map(numeric_eval).unwrap_or_else(empty_eval) + } else { + empty_eval() } - FunctionName::Ln => (val > 0.0).then(|| val.ln()), - FunctionName::Log10 => (val > 0.0).then(|| val.log10()), - FunctionName::Log2 => (val > 0.0).then(|| val.log2()), - }; - result.map(numeric_eval).unwrap_or_else(empty_eval) - } else { - empty_eval() + } + FunctionName::Len => { + let len = inner_eval.text.chars().count() as f64; + numeric_eval(len) + } + FunctionName::IsNa => { + let text = inner_eval.text.as_ref().trim(); + let is_na = text.is_empty() + || text.eq_ignore_ascii_case("na") + || text.eq_ignore_ascii_case("nan"); + bool_eval(is_na) + } } } BoundValue::Aggregate(spec) => { let values = spec .columns .iter() - .map(|&idx| row.get(idx).unwrap_or("")) + .map(|&idx| ctx.row().get(idx).unwrap_or("")) .collect::>(); let result = evaluate_row_aggregate(&spec.kind, &values); EvalValue { @@ -376,21 +500,103 @@ where numeric: result.numeric, } } + BoundValue::CaseWhen { branches, default } => { + for (cond, result) in branches { + ctx.clear_captures(); + if evaluate_with_context(cond, ctx) { + return eval_value_with_context(result, ctx); + } + } + ctx.clear_captures(); + if let Some(default) = default { + eval_value_with_context(default, ctx) + } else { + empty_eval() + } + } + BoundValue::Switch { + target, + branches, + default, + } => { + let target_eval = eval_value_with_context(target, ctx); + let target_numeric = target_eval.numeric; + let target_text = target_eval.text.into_owned(); + ctx.clear_captures(); + for (values, result) in branches { + for value in values { + let saved = ctx.take_captures(); + let candidate = eval_value_with_context(value, ctx); + ctx.restore_captures(saved); + let is_match = match (target_numeric, candidate.numeric) { + (Some(a), Some(b)) => a == b, + _ => target_text == candidate.text.as_ref(), + }; + if is_match { + ctx.clear_captures(); + return eval_value_with_context(result, ctx); + } + } + } + ctx.clear_captures(); + if let Some(default) = default { + eval_value_with_context(default, ctx) + } else { + empty_eval() + } + } + BoundValue::RegexCall { value, pattern } => { + let hay = eval_value_with_context(value, ctx); + let hay_text = hay.text.into_owned(); + let captures = match pattern { + RegexPattern::Static(regex) => regex.captures(&hay_text), + RegexPattern::Dynamic(bound) => { + let pat_eval = eval_value_with_context(bound, ctx); + Regex::new(pat_eval.text.as_ref()) + .ok() + .and_then(|regex| regex.captures(&hay_text)) + } + }; + if let Some(captures) = captures { + let mut values = Vec::with_capacity(captures.len()); + for idx in 0..captures.len() { + let text = captures.get(idx).map(|m| m.as_str()).unwrap_or(""); + values.push(text.to_string()); + } + ctx.set_captures(values); + bool_eval(true) + } else { + ctx.clear_captures(); + bool_eval(false) + } + } } } pub fn evaluate_truthy(value: &BoundValue, row: &R) -> bool +where + R: RowAccessor + ?Sized, +{ + let mut ctx = EvalContext::new(row); + evaluate_truthy_with_context(value, &mut ctx) +} + +pub fn evaluate_truthy_with_context<'a, R>( + value: &'a BoundValue, + ctx: &mut EvalContext<'a, R>, +) -> bool where R: RowAccessor + ?Sized, { match value { BoundValue::Columns(indices) => indices.iter().any(|idx| { - row.get(*idx) + ctx.row() + .get(*idx) .map(|text| !text.trim().is_empty()) .unwrap_or(false) }), _ => { - let eval = eval_value(value, row); + let eval = eval_value_with_context(value, ctx); if let Some(number) = eval.numeric { number != 0.0 } else { @@ -400,12 +606,17 @@ where } } -fn evaluate_compare(lhs: &BoundValue, op: CompareOp, rhs: &BoundValue, row: &R) -> bool +fn evaluate_compare<'a, R>( + lhs: &'a BoundValue, + op: CompareOp, + rhs: &'a BoundValue, + ctx: &mut EvalContext<'a, R>, +) -> bool where R: RowAccessor + ?Sized, { - let left = eval_value(lhs, row); - let right = eval_value(rhs, row); + let left = eval_value_with_context(lhs, ctx); + let right = eval_value_with_context(rhs, ctx); match op { CompareOp::Eq => { @@ -430,32 +641,37 @@ where } } -fn evaluate_regex(value: &BoundValue, pattern: &RegexPattern, invert: bool, row: &R) -> bool +fn evaluate_regex<'a, R>( + value: &'a BoundValue, + pattern: &'a RegexPattern, + invert: bool, + ctx: &mut EvalContext<'a, R>, +) -> bool where R: RowAccessor + ?Sized, { let is_match = match pattern { RegexPattern::Static(regex) => match value { - BoundValue::Column(idx) => regex.is_match(row.get(*idx).unwrap_or("")), + BoundValue::Column(idx) => regex.is_match(ctx.row().get(*idx).unwrap_or("")), BoundValue::Columns(indices) => indices .iter() - .any(|idx| regex.is_match(row.get(*idx).unwrap_or(""))), + .any(|idx| regex.is_match(ctx.row().get(*idx).unwrap_or(""))), _ => { - let hay = eval_value(value, row); + let hay = eval_value_with_context(value, ctx); regex.is_match(hay.text.as_ref()) } }, RegexPattern::Dynamic(bound) => { - let pat_eval = eval_value(bound, row); + let pat_eval = eval_value_with_context(bound, ctx); let pattern_text = pat_eval.text.as_ref(); if let Ok(regex) = Regex::new(pattern_text) { match value { - BoundValue::Column(idx) => regex.is_match(row.get(*idx).unwrap_or("")), + BoundValue::Column(idx) => regex.is_match(ctx.row().get(*idx).unwrap_or("")), BoundValue::Columns(indices) => indices .iter() - .any(|idx| regex.is_match(row.get(*idx).unwrap_or(""))), + .any(|idx| regex.is_match(ctx.row().get(*idx).unwrap_or(""))), _ => { - let hay = eval_value(value, row); + let hay = eval_value_with_context(value, ctx); regex.is_match(hay.text.as_ref()) } } @@ -528,6 +744,55 @@ fn bind_value(value: ValueExpr, headers: &[String], no_header: bool) -> Result { + let mut bound_branches = Vec::with_capacity(branches.len()); + for (cond, result) in branches { + let bound_cond = bind_expression(cond, headers, no_header)?; + let bound_result = bind_value(result, headers, no_header)?; + bound_branches.push((bound_cond, bound_result)); + } + let bound_default = match default { + Some(expr) => Some(Box::new(bind_value(*expr, headers, no_header)?)), + None => None, + }; + Ok(BoundValue::CaseWhen { + branches: bound_branches, + default: bound_default, + }) + } + ValueExpr::Switch { + target, + branches, + default, + } => { + let bound_target = bind_value(*target, headers, no_header)?; + let mut bound_branches = Vec::with_capacity(branches.len()); + for (values, result) in branches { + let mut bound_values = Vec::with_capacity(values.len()); + for value in values { + bound_values.push(bind_value(value, headers, no_header)?); + } + let bound_result = bind_value(result, headers, no_header)?; + bound_branches.push((bound_values, bound_result)); + } + let bound_default = match default { + Some(expr) => Some(Box::new(bind_value(*expr, headers, no_header)?)), + None => None, + }; + Ok(BoundValue::Switch { + target: Box::new(bound_target), + branches: bound_branches, + default: bound_default, + }) + } + ValueExpr::RegexCall(value, pattern) => { + let bound_value = bind_value(*value, headers, no_header)?; + let bound_pattern = bind_regex_pattern(*pattern, headers, no_header)?; + Ok(BoundValue::RegexCall { + value: Box::new(bound_value), + pattern: bound_pattern, + }) + } } } @@ -570,6 +835,14 @@ fn numeric_eval(value: f64) -> EvalValue<'static> { } } +fn bool_eval(value: bool) -> EvalValue<'static> { + if value { + numeric_eval(1.0) + } else { + numeric_eval(0.0) + } +} + fn empty_eval<'a>() -> EvalValue<'a> { EvalValue { text: Cow::Owned(String::new()), @@ -729,6 +1002,61 @@ mod tests { other => panic!("expected string literal, got {:?}", other), } } + + #[test] + fn case_when_selects_first_matching_branch() { + let value_expr = + parse_value_expression("case_when($1 > 5 -> \"high\", _ -> \"low\")").unwrap(); + let headers = vec!["score".to_string()]; + let bound = bind_value_expression(value_expr, &headers, false).unwrap(); + + let row_high = vec!["6".to_string()]; + assert_eq!(eval_value(&bound, &row_high).text.as_ref(), "high"); + + let row_low = vec!["2".to_string()]; + assert_eq!(eval_value(&bound, &row_low).text.as_ref(), "low"); + } + + #[test] + fn regex_call_populates_capture_groups() { + let value_expr = + parse_value_expression("case_when(re($1, \"^ERR(\\\\d+)$\") -> $1, _ -> \"nomatch\")") + .unwrap(); + let headers = vec!["sample".to_string()]; + let bound = bind_value_expression(value_expr, &headers, false).unwrap(); + + let matched = vec!["ERR123".to_string()]; + assert_eq!(eval_value(&bound, &matched).text.as_ref(), "123"); + + let unmatched = vec!["SRR55".to_string()]; + assert_eq!(eval_value(&bound, &unmatched).text.as_ref(), "nomatch"); + } + + #[test] + fn switch_maps_values_to_labels() { + let value_expr = parse_value_expression( + "switch($1, [\"DE\",\"FR\"] -> \"EU\", [\"US\",\"CA\"] -> \"NA\", _ -> \"Other\")", + ) + .unwrap(); + let headers = vec!["country".to_string()]; + let bound = bind_value_expression(value_expr, &headers, false).unwrap(); + + let row_eu = vec!["DE".to_string()]; + assert_eq!(eval_value(&bound, &row_eu).text.as_ref(), "EU"); + + let row_other = vec!["JP".to_string()]; + assert_eq!(eval_value(&bound, &row_other).text.as_ref(), "Other"); + } + + #[test] + fn len_function_counts_characters() { + let value_expr = parse_value_expression("len($1)").unwrap(); + let headers = vec!["text".to_string()]; + let bound = bind_value_expression(value_expr, &headers, false).unwrap(); + let row = vec!["hello".to_string()]; + let result = eval_value(&bound, &row); + assert_eq!(result.numeric, Some(5.0)); + } } impl<'a> Lexer<'a> { @@ -812,13 +1140,30 @@ impl<'a> Lexer<'a> { self.pos += 1; Ok(Some(Token::RParen)) } + b'[' => { + self.pos += 1; + Ok(Some(Token::LBracket)) + } + b']' => { + self.pos += 1; + Ok(Some(Token::RBracket)) + } + b',' => { + self.pos += 1; + Ok(Some(Token::Comma)) + } b'+' => { self.pos += 1; Ok(Some(Token::Plus)) } b'-' => { - self.pos += 1; - Ok(Some(Token::Minus)) + if self.peek_char(1) == Some(b'>') { + self.pos += 2; + Ok(Some(Token::Arrow)) + } else { + self.pos += 1; + Ok(Some(Token::Minus)) + } } b'*' => { self.pos += 1; @@ -875,13 +1220,40 @@ impl<'a> Lexer<'a> { continue; } if c.is_ascii_alphanumeric() || matches!(c, b'_' | b'.' | b',' | b':') { - if c == b',' || c == b':' { + if c == b',' { + let mut idx = self.pos + 1; + while idx < self.chars.len() && self.chars[idx].is_ascii_whitespace() { + idx += 1; + } + let next = self.chars.get(idx).copied(); + let is_selector_start = next.map_or(false, |next_char| { + if next_char == b'$' || next_char == b'{' { + return true; + } + if next_char == b'_' { + let mut lookahead = idx + 1; + while lookahead < self.chars.len() + && self.chars[lookahead].is_ascii_whitespace() + { + lookahead += 1; + } + if lookahead < self.chars.len() && self.chars[lookahead] == b'-' { + return false; + } + } + next_char.is_ascii_alphanumeric() || next_char == b'_' || next_char == b'.' + }); + if is_selector_start { + has_range_syntax = true; + is_numeric = false; + self.pos += 1; + continue; + } else { + break; + } + } else if c == b':' { has_range_syntax = true; is_numeric = false; - } else if !c.is_ascii_digit() { - is_numeric = false; - } - if c == b',' || c == b':' { self.pos += 1; continue; } @@ -1256,11 +1628,29 @@ impl Parser { name ); } - let argument = self.parse_arith()?; - if !self.consume_token(TokenKind::RParen) { - bail!("missing ')' after function call"); + let lower = name.to_ascii_lowercase(); + if lower == "case_when" { + let expr = self.parse_case_when_function()?; + return Ok(expr); + } + if lower == "switch" { + let expr = self.parse_switch_function()?; + return Ok(expr); + } + if lower == "re" { + let mut args = self.parse_function_arguments()?; + if args.len() != 2 { + bail!("re() expects two arguments: value, pattern"); + } + let pattern = args.pop().unwrap(); + let value = args.pop().unwrap(); + return Ok(ValueExpr::RegexCall(Box::new(value), Box::new(pattern))); } if let Some(kind) = try_parse_aggregate_kind(&name)? { + let argument = self.parse_arith()?; + if !self.consume_token(TokenKind::RParen) { + bail!("missing ')' after function call"); + } let selectors = match argument { ValueExpr::Column(selector) => vec![selector], ValueExpr::Columns(list) => list, @@ -1274,6 +1664,11 @@ impl Parser { }; return Ok(ValueExpr::Aggregate(AggregateSpecExpr { kind, selectors })); } + let mut args = self.parse_function_arguments()?; + if args.len() != 1 { + bail!("function '{}' expects exactly one argument", name); + } + let argument = args.pop().unwrap(); let func = FunctionName::from_ident(&name)?; Ok(ValueExpr::Function(func, Box::new(argument))) } else { @@ -1292,6 +1687,210 @@ impl Parser { } } + fn parse_function_arguments(&mut self) -> Result> { + let mut args = Vec::new(); + if self.consume_token(TokenKind::RParen) { + return Ok(args); + } + loop { + let expr = self.parse_arith()?; + args.push(expr); + if self.consume_token(TokenKind::Comma) { + continue; + } else if self.consume_token(TokenKind::RParen) { + break; + } else { + bail!("expected ',' or ')' after function argument"); + } + } + Ok(args) + } + + fn parse_case_when_function(&mut self) -> Result { + let mut branches = Vec::new(); + let mut default = None; + if self.consume_token(TokenKind::RParen) { + bail!("case_when requires at least one branch"); + } + loop { + if let Some(Token::Ident(name)) = self.peek_token().cloned() { + if name == "_" { + self.pos += 1; + if !self.consume_token(TokenKind::Arrow) { + bail!("case_when default branch must use '->'"); + } + let result = self.parse_case_result_value()?; + if default.is_some() { + bail!("case_when default branch specified more than once"); + } + default = Some(Box::new(result)); + if !self.consume_token(TokenKind::RParen) { + bail!("case_when default branch must be last"); + } + break; + } + } + let condition = self.parse_case_condition()?; + if !self.consume_token(TokenKind::Arrow) { + bail!("case_when branches must use '->'"); + } + let result = self.parse_case_result_value()?; + branches.push((condition, result)); + if self.consume_token(TokenKind::Comma) { + continue; + } else if self.consume_token(TokenKind::RParen) { + break; + } else { + bail!("expected ',' or ')' after case_when branch"); + } + } + if branches.is_empty() && default.is_none() { + bail!("case_when requires at least one branch"); + } + Ok(ValueExpr::CaseWhen(branches, default)) + } + + fn parse_case_condition(&mut self) -> Result { + let start = self.pos; + let mut depth = 0; + let mut idx = self.pos; + while idx < self.tokens.len() { + match self.tokens[idx] { + Token::LParen | Token::LBracket => depth += 1, + Token::RParen | Token::RBracket => { + if depth == 0 { + break; + } else { + depth -= 1; + } + } + Token::Arrow if depth == 0 => { + let slice = self.tokens[start..idx].to_vec(); + let mut parser = Parser::new(slice); + let expr = parser.parse_expr()?; + if parser.has_more() { + bail!("unexpected token in case_when condition"); + } + self.pos = idx; + return Ok(expr); + } + _ => {} + } + idx += 1; + } + bail!("case_when branches must use '->'") + } + + fn parse_case_result_value(&mut self) -> Result { + let start = self.pos; + let mut depth = 0; + let mut idx = self.pos; + while idx < self.tokens.len() { + match self.tokens[idx] { + Token::LParen | Token::LBracket => depth += 1, + Token::RParen => { + if depth == 0 { + break; + } else { + depth -= 1; + } + } + Token::Comma if depth == 0 => { + break; + } + _ => {} + } + idx += 1; + } + if idx == start { + bail!("case_when result must not be empty"); + } + let slice = self.tokens[start..idx].to_vec(); + let mut parser = Parser::new(slice); + let value = parser.parse_arith()?; + if parser.has_more() { + bail!("unexpected token in case_when result"); + } + self.pos = idx; + Ok(value) + } + + fn parse_switch_function(&mut self) -> Result { + let target = self.parse_arith()?; + if !self.consume_token(TokenKind::Comma) { + bail!("switch() requires a comma after the target expression"); + } + let mut branches = Vec::new(); + let mut default = None; + if self.consume_token(TokenKind::RParen) { + bail!("switch requires at least one branch"); + } + loop { + if let Some(Token::Ident(name)) = self.peek_token().cloned() { + if name == "_" { + self.pos += 1; + if !self.consume_token(TokenKind::Arrow) { + bail!("switch default branch must use '->'"); + } + let result = self.parse_arith()?; + if default.is_some() { + bail!("switch default branch specified more than once"); + } + default = Some(Box::new(result)); + if !self.consume_token(TokenKind::RParen) { + bail!("switch default branch must be last"); + } + break; + } + } + let values = self.parse_switch_values()?; + if !self.consume_token(TokenKind::Arrow) { + bail!("switch branches must use '->'"); + } + let result = self.parse_arith()?; + branches.push((values, result)); + if self.consume_token(TokenKind::Comma) { + continue; + } else if self.consume_token(TokenKind::RParen) { + break; + } else { + bail!("expected ',' or ')' after switch branch"); + } + } + if branches.is_empty() && default.is_none() { + bail!("switch requires at least one branch"); + } + Ok(ValueExpr::Switch { + target: Box::new(target), + branches, + default, + }) + } + + fn parse_switch_values(&mut self) -> Result> { + if self.consume_token(TokenKind::LBracket) { + let mut values = Vec::new(); + if self.consume_token(TokenKind::RBracket) { + bail!("switch value list must not be empty"); + } + loop { + let value = self.parse_arith()?; + values.push(value); + if self.consume_token(TokenKind::Comma) { + continue; + } else if self.consume_token(TokenKind::RBracket) { + break; + } else { + bail!("expected ',' or ']' in switch value list"); + } + } + Ok(values) + } else { + let value = self.parse_arith()?; + Ok(vec![value]) + } + } + fn consume_compare(&mut self) -> Option { if let Some(Token::Compare(op)) = self.peek_token().cloned() { self.pos += 1; @@ -1308,6 +1907,10 @@ impl Parser { (TokenKind::Not, Some(Token::Not)) => true, (TokenKind::LParen, Some(Token::LParen)) => true, (TokenKind::RParen, Some(Token::RParen)) => true, + (TokenKind::LBracket, Some(Token::LBracket)) => true, + (TokenKind::RBracket, Some(Token::RBracket)) => true, + (TokenKind::Comma, Some(Token::Comma)) => true, + (TokenKind::Arrow, Some(Token::Arrow)) => true, (TokenKind::Plus, Some(Token::Plus)) => true, (TokenKind::Minus, Some(Token::Minus)) => true, (TokenKind::Star, Some(Token::Star)) => true, @@ -1350,6 +1953,10 @@ enum TokenKind { Not, LParen, RParen, + LBracket, + RBracket, + Comma, + Arrow, Plus, Minus, Star, From 1a24d73695a8f8a4918aa8de53472c837931291d Mon Sep 17 00:00:00 2001 From: "Z.-L. Deng" Date: Thu, 9 Oct 2025 09:08:53 +0200 Subject: [PATCH 3/9] Fix case_when branch parsing and add regression tests --- Cargo.lock | 2 +- Cargo.toml | 2 +- README.md | 31 ++ src/expression.rs | 757 ++++++++++++++++++++++++++++++++++++++++++---- src/mutate.rs | 8 + 5 files changed, 738 insertions(+), 62 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5f4ca84..6d80c91 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -603,7 +603,7 @@ dependencies = [ [[package]] name = "tsvkit" -version = "0.9.4" +version = "0.9.5" dependencies = [ "anyhow", "calamine", diff --git a/Cargo.toml b/Cargo.toml index 8e88aed..9be5e12 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tsvkit" -version = "0.9.4" +version = "0.9.5" edition = "2024" [dependencies] diff --git a/README.md b/README.md index e7c23b8..3be8153 100644 --- a/README.md +++ b/README.md @@ -165,9 +165,27 @@ The same expression language powers `filter -e`, `mutate -e name=EXPR`, and rege | `ln(expr)` | Natural logarithm | | `log(expr)` / `log10(expr)` | Base-10 logarithm | | `log2(expr)` | Base-2 logarithm | +| `len(expr)` | Character count using Unicode code points. | +| `is_na(expr)` | Returns `1` when the expression is blank/`NA`/`NaN`, otherwise `0`. | Functions accept column references (`abs($purity - 1)`), constants, or subexpressions. Empty or non-numeric values yield blanks. +**Conditional and regex helpers** + +- `case_when(condition -> result, ..., _ -> default)` evaluates each boolean condition in order and returns the matching result. The final `_` branch acts as the default. +- `switch(value, [match1, match2] -> result, ..., _ -> default)` compares `value` to one or more literal matches (strings or numbers) and returns the corresponding result. +- `re(value, pattern)` evaluates a regex against `value`, returning `1` or `0`. When the pattern matches, capture groups become available as `$1`, `$2`, etc. for the remainder of the expression (use `$0`-style numeric selectors sparingly when you rely on captures). + +Example: + +``` +case_when( + re($sample, "^ERR(\\d+)$") -> $1, + re($sample, "^SRR") -> "SRA", + _ -> $sample +) +``` + **Row-wise aggregation helpers** Available within `mutate` expressions via functions such as `sum($col1:$col5)`; see the [Mutate](#mutate) section for the full list. @@ -260,6 +278,19 @@ tsvkit mutate \ examples/cytokines.tsv ``` +Use `case_when`, `switch`, and the `re()` helper for richer branching logic and regex capture reuse: + +```bash +tsvkit mutate \ + -e 'label = case_when( + re($sample, "^ERR(\d+)$") -> $1, + re($sample, "^SRR") -> "SRA", + _ -> $sample + )' \ + -e 'bucket = case_when(len($clean) == 0 -> "empty", len($clean) < 5 -> "short", _ -> "long")' \ + examples/samples.tsv +``` + Apply in-place edits with the sed-style form: ```bash diff --git a/src/expression.rs b/src/expression.rs index 34cacfe..12535fa 100644 --- a/src/expression.rs +++ b/src/expression.rs @@ -28,6 +28,13 @@ pub enum ValueExpr { Binary(BinaryOp, Box, Box), Function(FunctionName, Box), Aggregate(AggregateSpecExpr), + CaseWhen(Vec<(Expr, ValueExpr)>, Option>), + Switch { + target: Box, + branches: Vec<(Vec, ValueExpr)>, + default: Option>, + }, + RegexCall(Box, Box), } #[derive(Debug, Clone)] @@ -59,6 +66,8 @@ pub enum FunctionName { Ln, Log10, Log2, + Len, + IsNa, } impl FunctionName { @@ -71,8 +80,10 @@ impl FunctionName { "ln" => Ok(FunctionName::Ln), "log" | "log10" => Ok(FunctionName::Log10), "log2" => Ok(FunctionName::Log2), + "len" => Ok(FunctionName::Len), + "is_na" => Ok(FunctionName::IsNa), other => bail!( - "unsupported function '{}': try abs, sqrt, exp, exp2, ln, log, log10, log2", + "unsupported function '{}': try abs, sqrt, exp, exp2, ln, log, log10, log2, len, is_na", other ), } @@ -100,6 +111,10 @@ enum Token { Ident(String), LParen, RParen, + LBracket, + RBracket, + Comma, + Arrow, And, Or, Not, @@ -212,6 +227,19 @@ pub enum BoundValue { Binary(BinaryOp, Box, Box), Function(FunctionName, Box), Aggregate(BoundAggregate), + CaseWhen { + branches: Vec<(BoundExpr, BoundValue)>, + default: Option>, + }, + Switch { + target: Box, + branches: Vec<(Vec, BoundValue)>, + default: Option>, + }, + RegexCall { + value: Box, + pattern: RegexPattern, + }, } #[derive(Debug, Clone)] @@ -220,7 +248,7 @@ pub struct BoundAggregate { pub columns: Vec, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum BoundExpr { Or(Box, Box), And(Box, Box), @@ -234,7 +262,7 @@ pub enum BoundExpr { Value(BoundValue), } -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) enum RegexPattern { Static(Arc), Dynamic(Box), @@ -245,31 +273,105 @@ pub struct EvalValue<'a> { pub numeric: Option, } +pub struct EvalContext<'a, R> +where + R: RowAccessor + ?Sized, +{ + row: &'a R, + regex_captures: Option>, +} + +impl<'a, R> EvalContext<'a, R> +where + R: RowAccessor + ?Sized, +{ + pub fn new(row: &'a R) -> Self { + EvalContext { + row, + regex_captures: None, + } + } + + fn row(&self) -> &'a R { + self.row + } + + fn clear_captures(&mut self) { + self.regex_captures = None; + } + + fn set_captures(&mut self, captures: Vec) { + self.regex_captures = Some(captures); + } + + fn take_captures(&mut self) -> Option> { + self.regex_captures.take() + } + + fn restore_captures(&mut self, captures: Option>) { + self.regex_captures = captures; + } +} + pub fn evaluate(expr: &BoundExpr, row: &R) -> bool +where + R: RowAccessor + ?Sized, +{ + let mut ctx = EvalContext::new(row); + evaluate_with_context(expr, &mut ctx) +} + +fn evaluate_with_context<'a, R>(expr: &'a BoundExpr, ctx: &mut EvalContext<'a, R>) -> bool where R: RowAccessor + ?Sized, { match expr { - BoundExpr::Or(lhs, rhs) => evaluate(lhs, row) || evaluate(rhs, row), - BoundExpr::And(lhs, rhs) => evaluate(lhs, row) && evaluate(rhs, row), - BoundExpr::Not(inner) => !evaluate(inner, row), - BoundExpr::Compare(lhs, op, rhs) => evaluate_compare(lhs, *op, rhs, row), + BoundExpr::Or(lhs, rhs) => { + evaluate_with_context(lhs, ctx) || evaluate_with_context(rhs, ctx) + } + BoundExpr::And(lhs, rhs) => { + evaluate_with_context(lhs, ctx) && evaluate_with_context(rhs, ctx) + } + BoundExpr::Not(inner) => !evaluate_with_context(inner, ctx), + BoundExpr::Compare(lhs, op, rhs) => evaluate_compare(lhs, *op, rhs, ctx), BoundExpr::RegexMatch { value, pattern, invert, - } => evaluate_regex(value, pattern, *invert, row), - BoundExpr::Value(value) => evaluate_truthy(value, row), + } => evaluate_regex(value, pattern, *invert, ctx), + BoundExpr::Value(value) => evaluate_truthy_with_context(value, ctx), } } pub fn eval_value<'a, R>(value: &'a BoundValue, row: &'a R) -> EvalValue<'a> +where + R: RowAccessor + ?Sized, +{ + let mut ctx = EvalContext::new(row); + eval_value_with_context(value, &mut ctx) +} + +pub fn eval_value_with_context<'a, R>( + value: &'a BoundValue, + ctx: &mut EvalContext<'a, R>, +) -> EvalValue<'a> where R: RowAccessor + ?Sized, { match value { BoundValue::Column(idx) => { - let text = row.get(*idx).unwrap_or(""); + if let Some(captures) = ctx.regex_captures.as_ref() { + let capture_idx = idx + 1; + if capture_idx < captures.len() { + let capture = captures[capture_idx].clone(); + let numeric = parse_float(&capture); + return EvalValue { + text: Cow::Owned(capture), + numeric, + }; + } + } + let text = ctx.row().get(*idx).unwrap_or(""); EvalValue { text: Cow::Borrowed(text), numeric: parse_float(text), @@ -282,7 +384,7 @@ where let mut combined = String::new(); let mut numeric = None; for (pos, idx) in indices.iter().enumerate() { - let text = row.get(*idx).unwrap_or(""); + let text = ctx.row().get(*idx).unwrap_or(""); if pos > 0 { combined.push('\t'); } @@ -302,7 +404,7 @@ where }, BoundValue::Number(number) => numeric_eval(*number), BoundValue::Unary(op, inner) => { - let inner_eval = eval_value(inner, row); + let inner_eval = eval_value_with_context(inner, ctx); if let Some(val) = inner_eval.numeric { numeric_eval(match op { UnaryOp::Neg => -val, @@ -312,8 +414,8 @@ where } } BoundValue::Binary(op, left, right) => { - let left_eval = eval_value(left, row); - let right_eval = eval_value(right, row); + let left_eval = eval_value_with_context(left, ctx); + let right_eval = eval_value_with_context(right, ctx); match (left_eval.numeric, right_eval.numeric) { (Some(a), Some(b)) => match op { BinaryOp::Add => numeric_eval(a + b), @@ -339,36 +441,58 @@ where } } BoundValue::Function(func, inner) => { - let inner_eval = eval_value(inner, row); - if let Some(val) = inner_eval.numeric { - let result = match func { - FunctionName::Abs => Some(val.abs()), - FunctionName::Sqrt => { - let value = val.sqrt(); - value.is_finite().then_some(value) - } - FunctionName::Exp => { - let value = val.exp(); - value.is_finite().then_some(value) - } - FunctionName::Exp2 => { - let value = val.exp2(); - value.is_finite().then_some(value) + let inner_eval = eval_value_with_context(inner, ctx); + match func { + FunctionName::Abs + | FunctionName::Sqrt + | FunctionName::Exp + | FunctionName::Exp2 + | FunctionName::Ln + | FunctionName::Log10 + | FunctionName::Log2 => { + if let Some(val) = inner_eval.numeric { + let result = match func { + FunctionName::Abs => Some(val.abs()), + FunctionName::Sqrt => { + let value = val.sqrt(); + value.is_finite().then_some(value) + } + FunctionName::Exp => { + let value = val.exp(); + value.is_finite().then_some(value) + } + FunctionName::Exp2 => { + let value = val.exp2(); + value.is_finite().then_some(value) + } + FunctionName::Ln => (val > 0.0).then(|| val.ln()), + FunctionName::Log10 => (val > 0.0).then(|| val.log10()), + FunctionName::Log2 => (val > 0.0).then(|| val.log2()), + _ => None, + }; + result.map(numeric_eval).unwrap_or_else(empty_eval) + } else { + empty_eval() } - FunctionName::Ln => (val > 0.0).then(|| val.ln()), - FunctionName::Log10 => (val > 0.0).then(|| val.log10()), - FunctionName::Log2 => (val > 0.0).then(|| val.log2()), - }; - result.map(numeric_eval).unwrap_or_else(empty_eval) - } else { - empty_eval() + } + FunctionName::Len => { + let len = inner_eval.text.chars().count() as f64; + numeric_eval(len) + } + FunctionName::IsNa => { + let text = inner_eval.text.as_ref().trim(); + let is_na = text.is_empty() + || text.eq_ignore_ascii_case("na") + || text.eq_ignore_ascii_case("nan"); + bool_eval(is_na) + } } } BoundValue::Aggregate(spec) => { let values = spec .columns .iter() - .map(|&idx| row.get(idx).unwrap_or("")) + .map(|&idx| ctx.row().get(idx).unwrap_or("")) .collect::>(); let result = evaluate_row_aggregate(&spec.kind, &values); EvalValue { @@ -376,21 +500,103 @@ where numeric: result.numeric, } } + BoundValue::CaseWhen { branches, default } => { + for (cond, result) in branches { + ctx.clear_captures(); + if evaluate_with_context(cond, ctx) { + return eval_value_with_context(result, ctx); + } + } + ctx.clear_captures(); + if let Some(default) = default { + eval_value_with_context(default, ctx) + } else { + empty_eval() + } + } + BoundValue::Switch { + target, + branches, + default, + } => { + let target_eval = eval_value_with_context(target, ctx); + let target_numeric = target_eval.numeric; + let target_text = target_eval.text.into_owned(); + ctx.clear_captures(); + for (values, result) in branches { + for value in values { + let saved = ctx.take_captures(); + let candidate = eval_value_with_context(value, ctx); + ctx.restore_captures(saved); + let is_match = match (target_numeric, candidate.numeric) { + (Some(a), Some(b)) => a == b, + _ => target_text == candidate.text.as_ref(), + }; + if is_match { + ctx.clear_captures(); + return eval_value_with_context(result, ctx); + } + } + } + ctx.clear_captures(); + if let Some(default) = default { + eval_value_with_context(default, ctx) + } else { + empty_eval() + } + } + BoundValue::RegexCall { value, pattern } => { + let hay = eval_value_with_context(value, ctx); + let hay_text = hay.text.into_owned(); + let captures = match pattern { + RegexPattern::Static(regex) => regex.captures(&hay_text), + RegexPattern::Dynamic(bound) => { + let pat_eval = eval_value_with_context(bound, ctx); + Regex::new(pat_eval.text.as_ref()) + .ok() + .and_then(|regex| regex.captures(&hay_text)) + } + }; + if let Some(captures) = captures { + let mut values = Vec::with_capacity(captures.len()); + for idx in 0..captures.len() { + let text = captures.get(idx).map(|m| m.as_str()).unwrap_or(""); + values.push(text.to_string()); + } + ctx.set_captures(values); + bool_eval(true) + } else { + ctx.clear_captures(); + bool_eval(false) + } + } } } pub fn evaluate_truthy(value: &BoundValue, row: &R) -> bool +where + R: RowAccessor + ?Sized, +{ + let mut ctx = EvalContext::new(row); + evaluate_truthy_with_context(value, &mut ctx) +} + +pub fn evaluate_truthy_with_context<'a, R>( + value: &'a BoundValue, + ctx: &mut EvalContext<'a, R>, +) -> bool where R: RowAccessor + ?Sized, { match value { BoundValue::Columns(indices) => indices.iter().any(|idx| { - row.get(*idx) + ctx.row() + .get(*idx) .map(|text| !text.trim().is_empty()) .unwrap_or(false) }), _ => { - let eval = eval_value(value, row); + let eval = eval_value_with_context(value, ctx); if let Some(number) = eval.numeric { number != 0.0 } else { @@ -400,12 +606,17 @@ where } } -fn evaluate_compare(lhs: &BoundValue, op: CompareOp, rhs: &BoundValue, row: &R) -> bool +fn evaluate_compare<'a, R>( + lhs: &'a BoundValue, + op: CompareOp, + rhs: &'a BoundValue, + ctx: &mut EvalContext<'a, R>, +) -> bool where R: RowAccessor + ?Sized, { - let left = eval_value(lhs, row); - let right = eval_value(rhs, row); + let left = eval_value_with_context(lhs, ctx); + let right = eval_value_with_context(rhs, ctx); match op { CompareOp::Eq => { @@ -430,32 +641,37 @@ where } } -fn evaluate_regex(value: &BoundValue, pattern: &RegexPattern, invert: bool, row: &R) -> bool +fn evaluate_regex<'a, R>( + value: &'a BoundValue, + pattern: &'a RegexPattern, + invert: bool, + ctx: &mut EvalContext<'a, R>, +) -> bool where R: RowAccessor + ?Sized, { let is_match = match pattern { RegexPattern::Static(regex) => match value { - BoundValue::Column(idx) => regex.is_match(row.get(*idx).unwrap_or("")), + BoundValue::Column(idx) => regex.is_match(ctx.row().get(*idx).unwrap_or("")), BoundValue::Columns(indices) => indices .iter() - .any(|idx| regex.is_match(row.get(*idx).unwrap_or(""))), + .any(|idx| regex.is_match(ctx.row().get(*idx).unwrap_or(""))), _ => { - let hay = eval_value(value, row); + let hay = eval_value_with_context(value, ctx); regex.is_match(hay.text.as_ref()) } }, RegexPattern::Dynamic(bound) => { - let pat_eval = eval_value(bound, row); + let pat_eval = eval_value_with_context(bound, ctx); let pattern_text = pat_eval.text.as_ref(); if let Ok(regex) = Regex::new(pattern_text) { match value { - BoundValue::Column(idx) => regex.is_match(row.get(*idx).unwrap_or("")), + BoundValue::Column(idx) => regex.is_match(ctx.row().get(*idx).unwrap_or("")), BoundValue::Columns(indices) => indices .iter() - .any(|idx| regex.is_match(row.get(*idx).unwrap_or(""))), + .any(|idx| regex.is_match(ctx.row().get(*idx).unwrap_or(""))), _ => { - let hay = eval_value(value, row); + let hay = eval_value_with_context(value, ctx); regex.is_match(hay.text.as_ref()) } } @@ -528,6 +744,55 @@ fn bind_value(value: ValueExpr, headers: &[String], no_header: bool) -> Result { + let mut bound_branches = Vec::with_capacity(branches.len()); + for (cond, result) in branches { + let bound_cond = bind_expression(cond, headers, no_header)?; + let bound_result = bind_value(result, headers, no_header)?; + bound_branches.push((bound_cond, bound_result)); + } + let bound_default = match default { + Some(expr) => Some(Box::new(bind_value(*expr, headers, no_header)?)), + None => None, + }; + Ok(BoundValue::CaseWhen { + branches: bound_branches, + default: bound_default, + }) + } + ValueExpr::Switch { + target, + branches, + default, + } => { + let bound_target = bind_value(*target, headers, no_header)?; + let mut bound_branches = Vec::with_capacity(branches.len()); + for (values, result) in branches { + let mut bound_values = Vec::with_capacity(values.len()); + for value in values { + bound_values.push(bind_value(value, headers, no_header)?); + } + let bound_result = bind_value(result, headers, no_header)?; + bound_branches.push((bound_values, bound_result)); + } + let bound_default = match default { + Some(expr) => Some(Box::new(bind_value(*expr, headers, no_header)?)), + None => None, + }; + Ok(BoundValue::Switch { + target: Box::new(bound_target), + branches: bound_branches, + default: bound_default, + }) + } + ValueExpr::RegexCall(value, pattern) => { + let bound_value = bind_value(*value, headers, no_header)?; + let bound_pattern = bind_regex_pattern(*pattern, headers, no_header)?; + Ok(BoundValue::RegexCall { + value: Box::new(bound_value), + pattern: bound_pattern, + }) + } } } @@ -570,6 +835,14 @@ fn numeric_eval(value: f64) -> EvalValue<'static> { } } +fn bool_eval(value: bool) -> EvalValue<'static> { + if value { + numeric_eval(1.0) + } else { + numeric_eval(0.0) + } +} + fn empty_eval<'a>() -> EvalValue<'a> { EvalValue { text: Cow::Owned(String::new()), @@ -729,6 +1002,78 @@ mod tests { other => panic!("expected string literal, got {:?}", other), } } + + #[test] + fn case_when_selects_first_matching_branch() { + let value_expr = + parse_value_expression("case_when($1 > 5 -> \"high\", _ -> \"low\")").unwrap(); + let headers = vec!["score".to_string()]; + let bound = bind_value_expression(value_expr, &headers, false).unwrap(); + + let row_high = vec!["6".to_string()]; + assert_eq!(eval_value(&bound, &row_high).text.as_ref(), "high"); + + let row_low = vec!["2".to_string()]; + assert_eq!(eval_value(&bound, &row_low).text.as_ref(), "low"); + } + + #[test] + fn regex_call_populates_capture_groups() { + let value_expr = + parse_value_expression("case_when(re($1, \"^ERR(\\\\d+)$\") -> $1, _ -> \"nomatch\")") + .unwrap(); + let headers = vec!["sample".to_string()]; + let bound = bind_value_expression(value_expr, &headers, false).unwrap(); + + let matched = vec!["ERR123".to_string()]; + assert_eq!(eval_value(&bound, &matched).text.as_ref(), "123"); + + let unmatched = vec!["SRR55".to_string()]; + assert_eq!(eval_value(&bound, &unmatched).text.as_ref(), "nomatch"); + } + + #[test] + fn case_when_with_named_columns_and_regex_capture() { + let source = "case_when(\n re($sample_id, \"^ERR(\\\\d+)$\") -> $1,\n re($sample_id, \"^SRR\") -> \"SRA\",\n _ -> $sample_id\n )"; + let expr = parse_value_expression(source).unwrap(); + let headers = vec!["sample_id".to_string()]; + let bound = bind_value_expression(expr, &headers, false).unwrap(); + + let row_err = vec!["ERR1234".to_string()]; + assert_eq!(eval_value(&bound, &row_err).text.as_ref(), "1234"); + + let row_srr = vec!["SRR9000".to_string()]; + assert_eq!(eval_value(&bound, &row_srr).text.as_ref(), "SRA"); + + let row_other = vec!["OTHER".to_string()]; + assert_eq!(eval_value(&bound, &row_other).text.as_ref(), "OTHER"); + } + + #[test] + fn switch_maps_values_to_labels() { + let value_expr = parse_value_expression( + "switch($1, [\"DE\",\"FR\"] -> \"EU\", [\"US\",\"CA\"] -> \"NA\", _ -> \"Other\")", + ) + .unwrap(); + let headers = vec!["country".to_string()]; + let bound = bind_value_expression(value_expr, &headers, false).unwrap(); + + let row_eu = vec!["DE".to_string()]; + assert_eq!(eval_value(&bound, &row_eu).text.as_ref(), "EU"); + + let row_other = vec!["JP".to_string()]; + assert_eq!(eval_value(&bound, &row_other).text.as_ref(), "Other"); + } + + #[test] + fn len_function_counts_characters() { + let value_expr = parse_value_expression("len($1)").unwrap(); + let headers = vec!["text".to_string()]; + let bound = bind_value_expression(value_expr, &headers, false).unwrap(); + let row = vec!["hello".to_string()]; + let result = eval_value(&bound, &row); + assert_eq!(result.numeric, Some(5.0)); + } } impl<'a> Lexer<'a> { @@ -812,13 +1157,30 @@ impl<'a> Lexer<'a> { self.pos += 1; Ok(Some(Token::RParen)) } + b'[' => { + self.pos += 1; + Ok(Some(Token::LBracket)) + } + b']' => { + self.pos += 1; + Ok(Some(Token::RBracket)) + } + b',' => { + self.pos += 1; + Ok(Some(Token::Comma)) + } b'+' => { self.pos += 1; Ok(Some(Token::Plus)) } b'-' => { - self.pos += 1; - Ok(Some(Token::Minus)) + if self.peek_char(1) == Some(b'>') { + self.pos += 2; + Ok(Some(Token::Arrow)) + } else { + self.pos += 1; + Ok(Some(Token::Minus)) + } } b'*' => { self.pos += 1; @@ -875,13 +1237,53 @@ impl<'a> Lexer<'a> { continue; } if c.is_ascii_alphanumeric() || matches!(c, b'_' | b'.' | b',' | b':') { - if c == b',' || c == b':' { + if c == b',' { + let mut idx = self.pos + 1; + while idx < self.chars.len() && self.chars[idx].is_ascii_whitespace() { + idx += 1; + } + let next = self.chars.get(idx).copied(); + let is_selector_start = next.map_or(false, |next_char| { + if matches!(next_char, b'$' | b'{') { + return true; + } + if next_char == b'_' { + let mut lookahead = idx + 1; + while lookahead < self.chars.len() + && self.chars[lookahead].is_ascii_whitespace() + { + lookahead += 1; + } + if lookahead < self.chars.len() && self.chars[lookahead] == b'-' { + return false; + } + } + if next_char == b'-' { + let mut lookahead = idx + 1; + while lookahead < self.chars.len() + && self.chars[lookahead].is_ascii_whitespace() + { + lookahead += 1; + } + if lookahead < self.chars.len() { + let following = self.chars[lookahead]; + return matches!(following, b'$' | b'{' | b'0'..=b'9'); + } + return false; + } + next_char.is_ascii_digit() + }); + if is_selector_start { + has_range_syntax = true; + is_numeric = false; + self.pos += 1; + continue; + } else { + break; + } + } else if c == b':' { has_range_syntax = true; is_numeric = false; - } else if !c.is_ascii_digit() { - is_numeric = false; - } - if c == b',' || c == b':' { self.pos += 1; continue; } @@ -1256,11 +1658,29 @@ impl Parser { name ); } - let argument = self.parse_arith()?; - if !self.consume_token(TokenKind::RParen) { - bail!("missing ')' after function call"); + let lower = name.to_ascii_lowercase(); + if lower == "case_when" { + let expr = self.parse_case_when_function()?; + return Ok(expr); + } + if lower == "switch" { + let expr = self.parse_switch_function()?; + return Ok(expr); + } + if lower == "re" { + let mut args = self.parse_function_arguments()?; + if args.len() != 2 { + bail!("re() expects two arguments: value, pattern"); + } + let pattern = args.pop().unwrap(); + let value = args.pop().unwrap(); + return Ok(ValueExpr::RegexCall(Box::new(value), Box::new(pattern))); } if let Some(kind) = try_parse_aggregate_kind(&name)? { + let argument = self.parse_arith()?; + if !self.consume_token(TokenKind::RParen) { + bail!("missing ')' after function call"); + } let selectors = match argument { ValueExpr::Column(selector) => vec![selector], ValueExpr::Columns(list) => list, @@ -1274,6 +1694,11 @@ impl Parser { }; return Ok(ValueExpr::Aggregate(AggregateSpecExpr { kind, selectors })); } + let mut args = self.parse_function_arguments()?; + if args.len() != 1 { + bail!("function '{}' expects exactly one argument", name); + } + let argument = args.pop().unwrap(); let func = FunctionName::from_ident(&name)?; Ok(ValueExpr::Function(func, Box::new(argument))) } else { @@ -1292,6 +1717,210 @@ impl Parser { } } + fn parse_function_arguments(&mut self) -> Result> { + let mut args = Vec::new(); + if self.consume_token(TokenKind::RParen) { + return Ok(args); + } + loop { + let expr = self.parse_arith()?; + args.push(expr); + if self.consume_token(TokenKind::Comma) { + continue; + } else if self.consume_token(TokenKind::RParen) { + break; + } else { + bail!("expected ',' or ')' after function argument"); + } + } + Ok(args) + } + + fn parse_case_when_function(&mut self) -> Result { + let mut branches = Vec::new(); + let mut default = None; + if self.consume_token(TokenKind::RParen) { + bail!("case_when requires at least one branch"); + } + loop { + if let Some(Token::Ident(name)) = self.peek_token().cloned() { + if name == "_" { + self.pos += 1; + if !self.consume_token(TokenKind::Arrow) { + bail!("case_when default branch must use '->'"); + } + let result = self.parse_case_result_value()?; + if default.is_some() { + bail!("case_when default branch specified more than once"); + } + default = Some(Box::new(result)); + if !self.consume_token(TokenKind::RParen) { + bail!("case_when default branch must be last"); + } + break; + } + } + let condition = self.parse_case_condition()?; + if !self.consume_token(TokenKind::Arrow) { + bail!("case_when branches must use '->'"); + } + let result = self.parse_case_result_value()?; + branches.push((condition, result)); + if self.consume_token(TokenKind::Comma) { + continue; + } else if self.consume_token(TokenKind::RParen) { + break; + } else { + bail!("expected ',' or ')' after case_when branch"); + } + } + if branches.is_empty() && default.is_none() { + bail!("case_when requires at least one branch"); + } + Ok(ValueExpr::CaseWhen(branches, default)) + } + + fn parse_case_condition(&mut self) -> Result { + let start = self.pos; + let mut depth = 0; + let mut idx = self.pos; + while idx < self.tokens.len() { + match self.tokens[idx] { + Token::LParen | Token::LBracket => depth += 1, + Token::RParen | Token::RBracket => { + if depth == 0 { + break; + } else { + depth -= 1; + } + } + Token::Arrow if depth == 0 => { + let slice = self.tokens[start..idx].to_vec(); + let mut parser = Parser::new(slice); + let expr = parser.parse_expr()?; + if parser.has_more() { + bail!("unexpected token in case_when condition"); + } + self.pos = idx; + return Ok(expr); + } + _ => {} + } + idx += 1; + } + bail!("case_when branches must use '->'") + } + + fn parse_case_result_value(&mut self) -> Result { + let start = self.pos; + let mut depth = 0; + let mut idx = self.pos; + while idx < self.tokens.len() { + match self.tokens[idx] { + Token::LParen | Token::LBracket => depth += 1, + Token::RParen => { + if depth == 0 { + break; + } else { + depth -= 1; + } + } + Token::Comma if depth == 0 => { + break; + } + _ => {} + } + idx += 1; + } + if idx == start { + bail!("case_when result must not be empty"); + } + let slice = self.tokens[start..idx].to_vec(); + let mut parser = Parser::new(slice); + let value = parser.parse_arith()?; + if parser.has_more() { + bail!("unexpected token in case_when result"); + } + self.pos = idx; + Ok(value) + } + + fn parse_switch_function(&mut self) -> Result { + let target = self.parse_arith()?; + if !self.consume_token(TokenKind::Comma) { + bail!("switch() requires a comma after the target expression"); + } + let mut branches = Vec::new(); + let mut default = None; + if self.consume_token(TokenKind::RParen) { + bail!("switch requires at least one branch"); + } + loop { + if let Some(Token::Ident(name)) = self.peek_token().cloned() { + if name == "_" { + self.pos += 1; + if !self.consume_token(TokenKind::Arrow) { + bail!("switch default branch must use '->'"); + } + let result = self.parse_arith()?; + if default.is_some() { + bail!("switch default branch specified more than once"); + } + default = Some(Box::new(result)); + if !self.consume_token(TokenKind::RParen) { + bail!("switch default branch must be last"); + } + break; + } + } + let values = self.parse_switch_values()?; + if !self.consume_token(TokenKind::Arrow) { + bail!("switch branches must use '->'"); + } + let result = self.parse_arith()?; + branches.push((values, result)); + if self.consume_token(TokenKind::Comma) { + continue; + } else if self.consume_token(TokenKind::RParen) { + break; + } else { + bail!("expected ',' or ')' after switch branch"); + } + } + if branches.is_empty() && default.is_none() { + bail!("switch requires at least one branch"); + } + Ok(ValueExpr::Switch { + target: Box::new(target), + branches, + default, + }) + } + + fn parse_switch_values(&mut self) -> Result> { + if self.consume_token(TokenKind::LBracket) { + let mut values = Vec::new(); + if self.consume_token(TokenKind::RBracket) { + bail!("switch value list must not be empty"); + } + loop { + let value = self.parse_arith()?; + values.push(value); + if self.consume_token(TokenKind::Comma) { + continue; + } else if self.consume_token(TokenKind::RBracket) { + break; + } else { + bail!("expected ',' or ']' in switch value list"); + } + } + Ok(values) + } else { + let value = self.parse_arith()?; + Ok(vec![value]) + } + } + fn consume_compare(&mut self) -> Option { if let Some(Token::Compare(op)) = self.peek_token().cloned() { self.pos += 1; @@ -1308,6 +1937,10 @@ impl Parser { (TokenKind::Not, Some(Token::Not)) => true, (TokenKind::LParen, Some(Token::LParen)) => true, (TokenKind::RParen, Some(Token::RParen)) => true, + (TokenKind::LBracket, Some(Token::LBracket)) => true, + (TokenKind::RBracket, Some(Token::RBracket)) => true, + (TokenKind::Comma, Some(Token::Comma)) => true, + (TokenKind::Arrow, Some(Token::Arrow)) => true, (TokenKind::Plus, Some(Token::Plus)) => true, (TokenKind::Minus, Some(Token::Minus)) => true, (TokenKind::Star, Some(Token::Star)) => true, @@ -1350,6 +1983,10 @@ enum TokenKind { Not, LParen, RParen, + LBracket, + RBracket, + Comma, + Arrow, Plus, Minus, Star, diff --git a/src/mutate.rs b/src/mutate.rs index fb23305..9c38830 100644 --- a/src/mutate.rs +++ b/src/mutate.rs @@ -652,6 +652,14 @@ mod tests { assert_eq!(row[0], "hello/world"); } + #[test] + fn case_when_expression_parses_with_named_columns() { + let headers = vec!["sample_id".to_string()]; + let expr = "label = case_when(\n re($sample_id, \"^ERR(\\\\d+)$\") -> $1,\n _ -> $sample_id\n )"; + let ops = parse_operations(&[expr.to_string()], &headers, false).unwrap(); + assert_eq!(ops.len(), 1); + } + #[test] fn parse_string_literal_preserves_regex_escapes() { let literal = parse_string_literal("\"\\\\.\\\\*\"").unwrap(); From 40e830c2111e4519654ff6e0acae6be67e650ee8 Mon Sep 17 00:00:00 2001 From: "Z.-L. Deng" Date: Thu, 9 Oct 2025 09:45:20 +0200 Subject: [PATCH 4/9] Fix column selector heuristics in case_when --- Cargo.lock | 2 +- Cargo.toml | 2 +- README.md | 31 ++ src/expression.rs | 857 ++++++++++++++++++++++++++++++++++++++++++---- src/mutate.rs | 16 + 5 files changed, 845 insertions(+), 63 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5f4ca84..6d80c91 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -603,7 +603,7 @@ dependencies = [ [[package]] name = "tsvkit" -version = "0.9.4" +version = "0.9.5" dependencies = [ "anyhow", "calamine", diff --git a/Cargo.toml b/Cargo.toml index 8e88aed..9be5e12 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tsvkit" -version = "0.9.4" +version = "0.9.5" edition = "2024" [dependencies] diff --git a/README.md b/README.md index e7c23b8..3be8153 100644 --- a/README.md +++ b/README.md @@ -165,9 +165,27 @@ The same expression language powers `filter -e`, `mutate -e name=EXPR`, and rege | `ln(expr)` | Natural logarithm | | `log(expr)` / `log10(expr)` | Base-10 logarithm | | `log2(expr)` | Base-2 logarithm | +| `len(expr)` | Character count using Unicode code points. | +| `is_na(expr)` | Returns `1` when the expression is blank/`NA`/`NaN`, otherwise `0`. | Functions accept column references (`abs($purity - 1)`), constants, or subexpressions. Empty or non-numeric values yield blanks. +**Conditional and regex helpers** + +- `case_when(condition -> result, ..., _ -> default)` evaluates each boolean condition in order and returns the matching result. The final `_` branch acts as the default. +- `switch(value, [match1, match2] -> result, ..., _ -> default)` compares `value` to one or more literal matches (strings or numbers) and returns the corresponding result. +- `re(value, pattern)` evaluates a regex against `value`, returning `1` or `0`. When the pattern matches, capture groups become available as `$1`, `$2`, etc. for the remainder of the expression (use `$0`-style numeric selectors sparingly when you rely on captures). + +Example: + +``` +case_when( + re($sample, "^ERR(\\d+)$") -> $1, + re($sample, "^SRR") -> "SRA", + _ -> $sample +) +``` + **Row-wise aggregation helpers** Available within `mutate` expressions via functions such as `sum($col1:$col5)`; see the [Mutate](#mutate) section for the full list. @@ -260,6 +278,19 @@ tsvkit mutate \ examples/cytokines.tsv ``` +Use `case_when`, `switch`, and the `re()` helper for richer branching logic and regex capture reuse: + +```bash +tsvkit mutate \ + -e 'label = case_when( + re($sample, "^ERR(\d+)$") -> $1, + re($sample, "^SRR") -> "SRA", + _ -> $sample + )' \ + -e 'bucket = case_when(len($clean) == 0 -> "empty", len($clean) < 5 -> "short", _ -> "long")' \ + examples/samples.tsv +``` + Apply in-place edits with the sed-style form: ```bash diff --git a/src/expression.rs b/src/expression.rs index 34cacfe..7063414 100644 --- a/src/expression.rs +++ b/src/expression.rs @@ -28,6 +28,13 @@ pub enum ValueExpr { Binary(BinaryOp, Box, Box), Function(FunctionName, Box), Aggregate(AggregateSpecExpr), + CaseWhen(Vec<(Expr, ValueExpr)>, Option>), + Switch { + target: Box, + branches: Vec<(Vec, ValueExpr)>, + default: Option>, + }, + RegexCall(Box, Box), } #[derive(Debug, Clone)] @@ -59,6 +66,8 @@ pub enum FunctionName { Ln, Log10, Log2, + Len, + IsNa, } impl FunctionName { @@ -71,8 +80,10 @@ impl FunctionName { "ln" => Ok(FunctionName::Ln), "log" | "log10" => Ok(FunctionName::Log10), "log2" => Ok(FunctionName::Log2), + "len" => Ok(FunctionName::Len), + "is_na" => Ok(FunctionName::IsNa), other => bail!( - "unsupported function '{}': try abs, sqrt, exp, exp2, ln, log, log10, log2", + "unsupported function '{}': try abs, sqrt, exp, exp2, ln, log, log10, log2, len, is_na", other ), } @@ -100,6 +111,10 @@ enum Token { Ident(String), LParen, RParen, + LBracket, + RBracket, + Comma, + Arrow, And, Or, Not, @@ -212,6 +227,19 @@ pub enum BoundValue { Binary(BinaryOp, Box, Box), Function(FunctionName, Box), Aggregate(BoundAggregate), + CaseWhen { + branches: Vec<(BoundExpr, BoundValue)>, + default: Option>, + }, + Switch { + target: Box, + branches: Vec<(Vec, BoundValue)>, + default: Option>, + }, + RegexCall { + value: Box, + pattern: RegexPattern, + }, } #[derive(Debug, Clone)] @@ -220,7 +248,7 @@ pub struct BoundAggregate { pub columns: Vec, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum BoundExpr { Or(Box, Box), And(Box, Box), @@ -234,7 +262,7 @@ pub enum BoundExpr { Value(BoundValue), } -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) enum RegexPattern { Static(Arc), Dynamic(Box), @@ -245,31 +273,105 @@ pub struct EvalValue<'a> { pub numeric: Option, } +pub struct EvalContext<'a, R> +where + R: RowAccessor + ?Sized, +{ + row: &'a R, + regex_captures: Option>, +} + +impl<'a, R> EvalContext<'a, R> +where + R: RowAccessor + ?Sized, +{ + pub fn new(row: &'a R) -> Self { + EvalContext { + row, + regex_captures: None, + } + } + + fn row(&self) -> &'a R { + self.row + } + + fn clear_captures(&mut self) { + self.regex_captures = None; + } + + fn set_captures(&mut self, captures: Vec) { + self.regex_captures = Some(captures); + } + + fn take_captures(&mut self) -> Option> { + self.regex_captures.take() + } + + fn restore_captures(&mut self, captures: Option>) { + self.regex_captures = captures; + } +} + pub fn evaluate(expr: &BoundExpr, row: &R) -> bool +where + R: RowAccessor + ?Sized, +{ + let mut ctx = EvalContext::new(row); + evaluate_with_context(expr, &mut ctx) +} + +fn evaluate_with_context<'a, R>(expr: &'a BoundExpr, ctx: &mut EvalContext<'a, R>) -> bool where R: RowAccessor + ?Sized, { match expr { - BoundExpr::Or(lhs, rhs) => evaluate(lhs, row) || evaluate(rhs, row), - BoundExpr::And(lhs, rhs) => evaluate(lhs, row) && evaluate(rhs, row), - BoundExpr::Not(inner) => !evaluate(inner, row), - BoundExpr::Compare(lhs, op, rhs) => evaluate_compare(lhs, *op, rhs, row), + BoundExpr::Or(lhs, rhs) => { + evaluate_with_context(lhs, ctx) || evaluate_with_context(rhs, ctx) + } + BoundExpr::And(lhs, rhs) => { + evaluate_with_context(lhs, ctx) && evaluate_with_context(rhs, ctx) + } + BoundExpr::Not(inner) => !evaluate_with_context(inner, ctx), + BoundExpr::Compare(lhs, op, rhs) => evaluate_compare(lhs, *op, rhs, ctx), BoundExpr::RegexMatch { value, pattern, invert, - } => evaluate_regex(value, pattern, *invert, row), - BoundExpr::Value(value) => evaluate_truthy(value, row), + } => evaluate_regex(value, pattern, *invert, ctx), + BoundExpr::Value(value) => evaluate_truthy_with_context(value, ctx), } } pub fn eval_value<'a, R>(value: &'a BoundValue, row: &'a R) -> EvalValue<'a> +where + R: RowAccessor + ?Sized, +{ + let mut ctx = EvalContext::new(row); + eval_value_with_context(value, &mut ctx) +} + +pub fn eval_value_with_context<'a, R>( + value: &'a BoundValue, + ctx: &mut EvalContext<'a, R>, +) -> EvalValue<'a> where R: RowAccessor + ?Sized, { match value { BoundValue::Column(idx) => { - let text = row.get(*idx).unwrap_or(""); + if let Some(captures) = ctx.regex_captures.as_ref() { + let capture_idx = idx + 1; + if capture_idx < captures.len() { + let capture = captures[capture_idx].clone(); + let numeric = parse_float(&capture); + return EvalValue { + text: Cow::Owned(capture), + numeric, + }; + } + } + let text = ctx.row().get(*idx).unwrap_or(""); EvalValue { text: Cow::Borrowed(text), numeric: parse_float(text), @@ -282,7 +384,7 @@ where let mut combined = String::new(); let mut numeric = None; for (pos, idx) in indices.iter().enumerate() { - let text = row.get(*idx).unwrap_or(""); + let text = ctx.row().get(*idx).unwrap_or(""); if pos > 0 { combined.push('\t'); } @@ -302,7 +404,7 @@ where }, BoundValue::Number(number) => numeric_eval(*number), BoundValue::Unary(op, inner) => { - let inner_eval = eval_value(inner, row); + let inner_eval = eval_value_with_context(inner, ctx); if let Some(val) = inner_eval.numeric { numeric_eval(match op { UnaryOp::Neg => -val, @@ -312,8 +414,8 @@ where } } BoundValue::Binary(op, left, right) => { - let left_eval = eval_value(left, row); - let right_eval = eval_value(right, row); + let left_eval = eval_value_with_context(left, ctx); + let right_eval = eval_value_with_context(right, ctx); match (left_eval.numeric, right_eval.numeric) { (Some(a), Some(b)) => match op { BinaryOp::Add => numeric_eval(a + b), @@ -339,36 +441,58 @@ where } } BoundValue::Function(func, inner) => { - let inner_eval = eval_value(inner, row); - if let Some(val) = inner_eval.numeric { - let result = match func { - FunctionName::Abs => Some(val.abs()), - FunctionName::Sqrt => { - let value = val.sqrt(); - value.is_finite().then_some(value) - } - FunctionName::Exp => { - let value = val.exp(); - value.is_finite().then_some(value) - } - FunctionName::Exp2 => { - let value = val.exp2(); - value.is_finite().then_some(value) + let inner_eval = eval_value_with_context(inner, ctx); + match func { + FunctionName::Abs + | FunctionName::Sqrt + | FunctionName::Exp + | FunctionName::Exp2 + | FunctionName::Ln + | FunctionName::Log10 + | FunctionName::Log2 => { + if let Some(val) = inner_eval.numeric { + let result = match func { + FunctionName::Abs => Some(val.abs()), + FunctionName::Sqrt => { + let value = val.sqrt(); + value.is_finite().then_some(value) + } + FunctionName::Exp => { + let value = val.exp(); + value.is_finite().then_some(value) + } + FunctionName::Exp2 => { + let value = val.exp2(); + value.is_finite().then_some(value) + } + FunctionName::Ln => (val > 0.0).then(|| val.ln()), + FunctionName::Log10 => (val > 0.0).then(|| val.log10()), + FunctionName::Log2 => (val > 0.0).then(|| val.log2()), + _ => None, + }; + result.map(numeric_eval).unwrap_or_else(empty_eval) + } else { + empty_eval() } - FunctionName::Ln => (val > 0.0).then(|| val.ln()), - FunctionName::Log10 => (val > 0.0).then(|| val.log10()), - FunctionName::Log2 => (val > 0.0).then(|| val.log2()), - }; - result.map(numeric_eval).unwrap_or_else(empty_eval) - } else { - empty_eval() + } + FunctionName::Len => { + let len = inner_eval.text.chars().count() as f64; + numeric_eval(len) + } + FunctionName::IsNa => { + let text = inner_eval.text.as_ref().trim(); + let is_na = text.is_empty() + || text.eq_ignore_ascii_case("na") + || text.eq_ignore_ascii_case("nan"); + bool_eval(is_na) + } } } BoundValue::Aggregate(spec) => { let values = spec .columns .iter() - .map(|&idx| row.get(idx).unwrap_or("")) + .map(|&idx| ctx.row().get(idx).unwrap_or("")) .collect::>(); let result = evaluate_row_aggregate(&spec.kind, &values); EvalValue { @@ -376,21 +500,95 @@ where numeric: result.numeric, } } + BoundValue::CaseWhen { branches, default } => { + for (cond, result) in branches { + ctx.clear_captures(); + if evaluate_with_context(cond, ctx) { + return eval_value_with_context(result, ctx); + } + } + ctx.clear_captures(); + if let Some(default) = default { + eval_value_with_context(default, ctx) + } else { + empty_eval() + } + } + BoundValue::Switch { + target, + branches, + default, + } => { + let target_eval = eval_value_with_context(target, ctx); + let target_numeric = target_eval.numeric; + let target_text = target_eval.text.into_owned(); + ctx.clear_captures(); + for (values, result) in branches { + for value in values { + let saved = ctx.take_captures(); + let candidate = eval_value_with_context(value, ctx); + ctx.restore_captures(saved); + let is_match = match (target_numeric, candidate.numeric) { + (Some(a), Some(b)) => a == b, + _ => target_text == candidate.text.as_ref(), + }; + if is_match { + ctx.clear_captures(); + return eval_value_with_context(result, ctx); + } + } + } + ctx.clear_captures(); + if let Some(default) = default { + eval_value_with_context(default, ctx) + } else { + empty_eval() + } + } + BoundValue::RegexCall { value, pattern } => { + let hay = eval_value_with_context(value, ctx); + let hay_text = hay.text.into_owned(); + let captures = match pattern { + RegexPattern::Static(regex) => regex.captures(&hay_text), + RegexPattern::Dynamic(bound) => { + let pat_eval = eval_value_with_context(bound, ctx); + Regex::new(pat_eval.text.as_ref()) + .ok() + .and_then(|regex| regex.captures(&hay_text)) + } + }; + if let Some(captures) = captures { + let mut values = Vec::with_capacity(captures.len()); + for idx in 0..captures.len() { + let text = captures.get(idx).map(|m| m.as_str()).unwrap_or(""); + values.push(text.to_string()); + } + ctx.set_captures(values); + bool_eval(true) + } else { + ctx.clear_captures(); + bool_eval(false) + } + } } } -pub fn evaluate_truthy(value: &BoundValue, row: &R) -> bool +pub fn evaluate_truthy_with_context<'a, R>( + value: &'a BoundValue, + ctx: &mut EvalContext<'a, R>, +) -> bool where R: RowAccessor + ?Sized, { match value { BoundValue::Columns(indices) => indices.iter().any(|idx| { - row.get(*idx) + ctx.row() + .get(*idx) .map(|text| !text.trim().is_empty()) .unwrap_or(false) }), _ => { - let eval = eval_value(value, row); + let eval = eval_value_with_context(value, ctx); if let Some(number) = eval.numeric { number != 0.0 } else { @@ -400,12 +598,17 @@ where } } -fn evaluate_compare(lhs: &BoundValue, op: CompareOp, rhs: &BoundValue, row: &R) -> bool +fn evaluate_compare<'a, R>( + lhs: &'a BoundValue, + op: CompareOp, + rhs: &'a BoundValue, + ctx: &mut EvalContext<'a, R>, +) -> bool where R: RowAccessor + ?Sized, { - let left = eval_value(lhs, row); - let right = eval_value(rhs, row); + let left = eval_value_with_context(lhs, ctx); + let right = eval_value_with_context(rhs, ctx); match op { CompareOp::Eq => { @@ -430,32 +633,37 @@ where } } -fn evaluate_regex(value: &BoundValue, pattern: &RegexPattern, invert: bool, row: &R) -> bool +fn evaluate_regex<'a, R>( + value: &'a BoundValue, + pattern: &'a RegexPattern, + invert: bool, + ctx: &mut EvalContext<'a, R>, +) -> bool where R: RowAccessor + ?Sized, { let is_match = match pattern { RegexPattern::Static(regex) => match value { - BoundValue::Column(idx) => regex.is_match(row.get(*idx).unwrap_or("")), + BoundValue::Column(idx) => regex.is_match(ctx.row().get(*idx).unwrap_or("")), BoundValue::Columns(indices) => indices .iter() - .any(|idx| regex.is_match(row.get(*idx).unwrap_or(""))), + .any(|idx| regex.is_match(ctx.row().get(*idx).unwrap_or(""))), _ => { - let hay = eval_value(value, row); + let hay = eval_value_with_context(value, ctx); regex.is_match(hay.text.as_ref()) } }, RegexPattern::Dynamic(bound) => { - let pat_eval = eval_value(bound, row); + let pat_eval = eval_value_with_context(bound, ctx); let pattern_text = pat_eval.text.as_ref(); if let Ok(regex) = Regex::new(pattern_text) { match value { - BoundValue::Column(idx) => regex.is_match(row.get(*idx).unwrap_or("")), + BoundValue::Column(idx) => regex.is_match(ctx.row().get(*idx).unwrap_or("")), BoundValue::Columns(indices) => indices .iter() - .any(|idx| regex.is_match(row.get(*idx).unwrap_or(""))), + .any(|idx| regex.is_match(ctx.row().get(*idx).unwrap_or(""))), _ => { - let hay = eval_value(value, row); + let hay = eval_value_with_context(value, ctx); regex.is_match(hay.text.as_ref()) } } @@ -528,6 +736,55 @@ fn bind_value(value: ValueExpr, headers: &[String], no_header: bool) -> Result { + let mut bound_branches = Vec::with_capacity(branches.len()); + for (cond, result) in branches { + let bound_cond = bind_expression(cond, headers, no_header)?; + let bound_result = bind_value(result, headers, no_header)?; + bound_branches.push((bound_cond, bound_result)); + } + let bound_default = match default { + Some(expr) => Some(Box::new(bind_value(*expr, headers, no_header)?)), + None => None, + }; + Ok(BoundValue::CaseWhen { + branches: bound_branches, + default: bound_default, + }) + } + ValueExpr::Switch { + target, + branches, + default, + } => { + let bound_target = bind_value(*target, headers, no_header)?; + let mut bound_branches = Vec::with_capacity(branches.len()); + for (values, result) in branches { + let mut bound_values = Vec::with_capacity(values.len()); + for value in values { + bound_values.push(bind_value(value, headers, no_header)?); + } + let bound_result = bind_value(result, headers, no_header)?; + bound_branches.push((bound_values, bound_result)); + } + let bound_default = match default { + Some(expr) => Some(Box::new(bind_value(*expr, headers, no_header)?)), + None => None, + }; + Ok(BoundValue::Switch { + target: Box::new(bound_target), + branches: bound_branches, + default: bound_default, + }) + } + ValueExpr::RegexCall(value, pattern) => { + let bound_value = bind_value(*value, headers, no_header)?; + let bound_pattern = bind_regex_pattern(*pattern, headers, no_header)?; + Ok(BoundValue::RegexCall { + value: Box::new(bound_value), + pattern: bound_pattern, + }) + } } } @@ -570,6 +827,14 @@ fn numeric_eval(value: f64) -> EvalValue<'static> { } } +fn bool_eval(value: bool) -> EvalValue<'static> { + if value { + numeric_eval(1.0) + } else { + numeric_eval(0.0) + } +} + fn empty_eval<'a>() -> EvalValue<'a> { EvalValue { text: Cow::Owned(String::new()), @@ -729,6 +994,84 @@ mod tests { other => panic!("expected string literal, got {:?}", other), } } + + #[test] + fn case_when_selects_first_matching_branch() { + let value_expr = + parse_value_expression("case_when($1 > 5 -> \"high\", _ -> \"low\")").unwrap(); + let headers = vec!["score".to_string()]; + let bound = bind_value_expression(value_expr, &headers, false).unwrap(); + + let row_high = vec!["6".to_string()]; + assert_eq!(eval_value(&bound, &row_high).text.as_ref(), "high"); + + let row_low = vec!["2".to_string()]; + assert_eq!(eval_value(&bound, &row_low).text.as_ref(), "low"); + } + + #[test] + fn regex_call_populates_capture_groups() { + let value_expr = + parse_value_expression("case_when(re($1, \"^ERR(\\\\d+)$\") -> $1, _ -> \"nomatch\")") + .unwrap(); + let headers = vec!["sample".to_string()]; + let bound = bind_value_expression(value_expr, &headers, false).unwrap(); + + let matched = vec!["ERR123".to_string()]; + assert_eq!(eval_value(&bound, &matched).text.as_ref(), "123"); + + let unmatched = vec!["SRR55".to_string()]; + assert_eq!(eval_value(&bound, &unmatched).text.as_ref(), "nomatch"); + } + + #[test] + fn case_when_with_named_columns_and_regex_capture() { + let source = "case_when(\n re($sample_id, \"^ERR(\\\\d+)$\") -> $1,\n re($sample_id, \"^SRR\") -> \"SRA\",\n _ -> $sample_id\n )"; + let expr = parse_value_expression(source).unwrap(); + let headers = vec!["sample_id".to_string()]; + let bound = bind_value_expression(expr, &headers, false).unwrap(); + + let row_err = vec!["ERR1234".to_string()]; + assert_eq!(eval_value(&bound, &row_err).text.as_ref(), "1234"); + + let row_srr = vec!["SRR9000".to_string()]; + assert_eq!(eval_value(&bound, &row_srr).text.as_ref(), "SRA"); + + let row_other = vec!["OTHER".to_string()]; + assert_eq!(eval_value(&bound, &row_other).text.as_ref(), "OTHER"); + } + + #[test] + fn case_when_with_column_result_followed_by_column_condition() { + let source = "case_when($score < 0 -> \"neg\", is_na($score) -> $score, $score < 50 -> \"low\", _ -> \"high\")"; + parse_value_expression(source).unwrap(); + } + + #[test] + fn switch_maps_values_to_labels() { + let value_expr = parse_value_expression( + "switch($1, [\"DE\",\"FR\"] -> \"EU\", [\"US\",\"CA\"] -> \"NA\", _ -> \"Other\")", + ) + .unwrap(); + let headers = vec!["country".to_string()]; + let bound = bind_value_expression(value_expr, &headers, false).unwrap(); + + let row_eu = vec!["DE".to_string()]; + assert_eq!(eval_value(&bound, &row_eu).text.as_ref(), "EU"); + + let row_other = vec!["JP".to_string()]; + assert_eq!(eval_value(&bound, &row_other).text.as_ref(), "Other"); + } + + #[test] + fn len_function_counts_characters() { + let value_expr = parse_value_expression("len($1)").unwrap(); + let headers = vec!["text".to_string()]; + let bound = bind_value_expression(value_expr, &headers, false).unwrap(); + let row = vec!["hello".to_string()]; + let result = eval_value(&bound, &row); + assert_eq!(result.numeric, Some(5.0)); + } } impl<'a> Lexer<'a> { @@ -812,13 +1155,30 @@ impl<'a> Lexer<'a> { self.pos += 1; Ok(Some(Token::RParen)) } + b'[' => { + self.pos += 1; + Ok(Some(Token::LBracket)) + } + b']' => { + self.pos += 1; + Ok(Some(Token::RBracket)) + } + b',' => { + self.pos += 1; + Ok(Some(Token::Comma)) + } b'+' => { self.pos += 1; Ok(Some(Token::Plus)) } b'-' => { - self.pos += 1; - Ok(Some(Token::Minus)) + if self.peek_char(1) == Some(b'>') { + self.pos += 2; + Ok(Some(Token::Arrow)) + } else { + self.pos += 1; + Ok(Some(Token::Minus)) + } } b'*' => { self.pos += 1; @@ -875,13 +1235,18 @@ impl<'a> Lexer<'a> { continue; } if c.is_ascii_alphanumeric() || matches!(c, b'_' | b'.' | b',' | b':') { - if c == b',' || c == b':' { + if c == b',' { + if self.should_continue_selector_list(self.pos) { + has_range_syntax = true; + is_numeric = false; + self.pos += 1; + continue; + } else { + break; + } + } else if c == b':' { has_range_syntax = true; is_numeric = false; - } else if !c.is_ascii_digit() { - is_numeric = false; - } - if c == b',' || c == b':' { self.pos += 1; continue; } @@ -1033,6 +1398,141 @@ impl<'a> Lexer<'a> { None } + fn should_continue_selector_list(&self, comma_pos: usize) -> bool { + let mut idx = comma_pos + 1; + while idx < self.chars.len() && self.chars[idx].is_ascii_whitespace() { + idx += 1; + } + if idx >= self.chars.len() { + return false; + } + let start = self.chars[idx]; + if start == b'_' { + return false; + } + let Some(end) = self.selector_end(idx) else { + return false; + }; + if self.branch_arrow_ahead(end) { + return false; + } + true + } + + fn selector_end(&self, start: usize) -> Option { + let len = self.chars.len(); + if start >= len { + return None; + } + let mut pos = start; + match self.chars[pos] { + b'$' => { + pos += 1; + if pos < len && self.chars[pos] == b'{' { + pos += 1; + let mut escaped = false; + while pos < len { + let c = self.chars[pos]; + if escaped { + escaped = false; + } else if c == b'\\' { + escaped = true; + } else if c == b'}' { + pos += 1; + break; + } + pos += 1; + } + } else { + while pos < len + && (self.chars[pos].is_ascii_alphanumeric() + || matches!(self.chars[pos], b'_' | b'.')) + { + pos += 1; + } + } + } + b'-' => { + pos += 1; + if pos >= len || !self.chars[pos].is_ascii_digit() { + return None; + } + while pos < len && self.chars[pos].is_ascii_digit() { + pos += 1; + } + } + c if c.is_ascii_digit() => { + pos += 1; + while pos < len && self.chars[pos].is_ascii_digit() { + pos += 1; + } + } + _ => return None, + } + let mut end = pos; + while end < len && self.chars[end].is_ascii_whitespace() { + end += 1; + } + if end < len && self.chars[end] == b':' { + end += 1; + while end < len && self.chars[end].is_ascii_whitespace() { + end += 1; + } + if let Some(range_end) = self.selector_end(end) { + end = range_end; + } + } + Some(end) + } + + fn branch_arrow_ahead(&self, mut idx: usize) -> bool { + let len = self.chars.len(); + let mut depth = 0usize; + while idx < len { + let c = self.chars[idx]; + match c { + b'"' | b'\'' => { + idx = self.skip_quoted_literal(idx); + } + b'(' | b'[' => { + depth += 1; + idx += 1; + } + b')' | b']' => { + if depth == 0 { + return false; + } + depth -= 1; + idx += 1; + } + b',' if depth == 0 => return false, + b'-' if depth == 0 && idx + 1 < len && self.chars[idx + 1] == b'>' => return true, + _ => { + idx += 1; + } + } + } + false + } + + fn skip_quoted_literal(&self, start: usize) -> usize { + let len = self.chars.len(); + let quote = self.chars[start]; + let mut idx = start + 1; + while idx < len { + let c = self.chars[idx]; + if c == b'\\' { + idx += 2; + continue; + } + if c == quote { + return idx + 1; + } + idx += 1; + } + len + } + fn skip_whitespace(&mut self) { while self.pos < self.chars.len() && self.chars[self.pos].is_ascii_whitespace() { self.pos += 1; @@ -1256,11 +1756,29 @@ impl Parser { name ); } - let argument = self.parse_arith()?; - if !self.consume_token(TokenKind::RParen) { - bail!("missing ')' after function call"); + let lower = name.to_ascii_lowercase(); + if lower == "case_when" { + let expr = self.parse_case_when_function()?; + return Ok(expr); + } + if lower == "switch" { + let expr = self.parse_switch_function()?; + return Ok(expr); + } + if lower == "re" { + let mut args = self.parse_function_arguments()?; + if args.len() != 2 { + bail!("re() expects two arguments: value, pattern"); + } + let pattern = args.pop().unwrap(); + let value = args.pop().unwrap(); + return Ok(ValueExpr::RegexCall(Box::new(value), Box::new(pattern))); } if let Some(kind) = try_parse_aggregate_kind(&name)? { + let argument = self.parse_arith()?; + if !self.consume_token(TokenKind::RParen) { + bail!("missing ')' after function call"); + } let selectors = match argument { ValueExpr::Column(selector) => vec![selector], ValueExpr::Columns(list) => list, @@ -1274,6 +1792,11 @@ impl Parser { }; return Ok(ValueExpr::Aggregate(AggregateSpecExpr { kind, selectors })); } + let mut args = self.parse_function_arguments()?; + if args.len() != 1 { + bail!("function '{}' expects exactly one argument", name); + } + let argument = args.pop().unwrap(); let func = FunctionName::from_ident(&name)?; Ok(ValueExpr::Function(func, Box::new(argument))) } else { @@ -1292,6 +1815,210 @@ impl Parser { } } + fn parse_function_arguments(&mut self) -> Result> { + let mut args = Vec::new(); + if self.consume_token(TokenKind::RParen) { + return Ok(args); + } + loop { + let expr = self.parse_arith()?; + args.push(expr); + if self.consume_token(TokenKind::Comma) { + continue; + } else if self.consume_token(TokenKind::RParen) { + break; + } else { + bail!("expected ',' or ')' after function argument"); + } + } + Ok(args) + } + + fn parse_case_when_function(&mut self) -> Result { + let mut branches = Vec::new(); + let mut default = None; + if self.consume_token(TokenKind::RParen) { + bail!("case_when requires at least one branch"); + } + loop { + if let Some(Token::Ident(name)) = self.peek_token().cloned() { + if name == "_" { + self.pos += 1; + if !self.consume_token(TokenKind::Arrow) { + bail!("case_when default branch must use '->'"); + } + let result = self.parse_case_result_value()?; + if default.is_some() { + bail!("case_when default branch specified more than once"); + } + default = Some(Box::new(result)); + if !self.consume_token(TokenKind::RParen) { + bail!("case_when default branch must be last"); + } + break; + } + } + let condition = self.parse_case_condition()?; + if !self.consume_token(TokenKind::Arrow) { + bail!("case_when branches must use '->'"); + } + let result = self.parse_case_result_value()?; + branches.push((condition, result)); + if self.consume_token(TokenKind::Comma) { + continue; + } else if self.consume_token(TokenKind::RParen) { + break; + } else { + bail!("expected ',' or ')' after case_when branch"); + } + } + if branches.is_empty() && default.is_none() { + bail!("case_when requires at least one branch"); + } + Ok(ValueExpr::CaseWhen(branches, default)) + } + + fn parse_case_condition(&mut self) -> Result { + let start = self.pos; + let mut depth = 0; + let mut idx = self.pos; + while idx < self.tokens.len() { + match self.tokens[idx] { + Token::LParen | Token::LBracket => depth += 1, + Token::RParen | Token::RBracket => { + if depth == 0 { + break; + } else { + depth -= 1; + } + } + Token::Arrow if depth == 0 => { + let slice = self.tokens[start..idx].to_vec(); + let mut parser = Parser::new(slice); + let expr = parser.parse_expr()?; + if parser.has_more() { + bail!("unexpected token in case_when condition"); + } + self.pos = idx; + return Ok(expr); + } + _ => {} + } + idx += 1; + } + bail!("case_when branches must use '->'") + } + + fn parse_case_result_value(&mut self) -> Result { + let start = self.pos; + let mut depth = 0; + let mut idx = self.pos; + while idx < self.tokens.len() { + match self.tokens[idx] { + Token::LParen | Token::LBracket => depth += 1, + Token::RParen => { + if depth == 0 { + break; + } else { + depth -= 1; + } + } + Token::Comma if depth == 0 => { + break; + } + _ => {} + } + idx += 1; + } + if idx == start { + bail!("case_when result must not be empty"); + } + let slice = self.tokens[start..idx].to_vec(); + let mut parser = Parser::new(slice); + let value = parser.parse_arith()?; + if parser.has_more() { + bail!("unexpected token in case_when result"); + } + self.pos = idx; + Ok(value) + } + + fn parse_switch_function(&mut self) -> Result { + let target = self.parse_arith()?; + if !self.consume_token(TokenKind::Comma) { + bail!("switch() requires a comma after the target expression"); + } + let mut branches = Vec::new(); + let mut default = None; + if self.consume_token(TokenKind::RParen) { + bail!("switch requires at least one branch"); + } + loop { + if let Some(Token::Ident(name)) = self.peek_token().cloned() { + if name == "_" { + self.pos += 1; + if !self.consume_token(TokenKind::Arrow) { + bail!("switch default branch must use '->'"); + } + let result = self.parse_arith()?; + if default.is_some() { + bail!("switch default branch specified more than once"); + } + default = Some(Box::new(result)); + if !self.consume_token(TokenKind::RParen) { + bail!("switch default branch must be last"); + } + break; + } + } + let values = self.parse_switch_values()?; + if !self.consume_token(TokenKind::Arrow) { + bail!("switch branches must use '->'"); + } + let result = self.parse_arith()?; + branches.push((values, result)); + if self.consume_token(TokenKind::Comma) { + continue; + } else if self.consume_token(TokenKind::RParen) { + break; + } else { + bail!("expected ',' or ')' after switch branch"); + } + } + if branches.is_empty() && default.is_none() { + bail!("switch requires at least one branch"); + } + Ok(ValueExpr::Switch { + target: Box::new(target), + branches, + default, + }) + } + + fn parse_switch_values(&mut self) -> Result> { + if self.consume_token(TokenKind::LBracket) { + let mut values = Vec::new(); + if self.consume_token(TokenKind::RBracket) { + bail!("switch value list must not be empty"); + } + loop { + let value = self.parse_arith()?; + values.push(value); + if self.consume_token(TokenKind::Comma) { + continue; + } else if self.consume_token(TokenKind::RBracket) { + break; + } else { + bail!("expected ',' or ']' in switch value list"); + } + } + Ok(values) + } else { + let value = self.parse_arith()?; + Ok(vec![value]) + } + } + fn consume_compare(&mut self) -> Option { if let Some(Token::Compare(op)) = self.peek_token().cloned() { self.pos += 1; @@ -1308,6 +2035,10 @@ impl Parser { (TokenKind::Not, Some(Token::Not)) => true, (TokenKind::LParen, Some(Token::LParen)) => true, (TokenKind::RParen, Some(Token::RParen)) => true, + (TokenKind::LBracket, Some(Token::LBracket)) => true, + (TokenKind::RBracket, Some(Token::RBracket)) => true, + (TokenKind::Comma, Some(Token::Comma)) => true, + (TokenKind::Arrow, Some(Token::Arrow)) => true, (TokenKind::Plus, Some(Token::Plus)) => true, (TokenKind::Minus, Some(Token::Minus)) => true, (TokenKind::Star, Some(Token::Star)) => true, @@ -1350,6 +2081,10 @@ enum TokenKind { Not, LParen, RParen, + LBracket, + RBracket, + Comma, + Arrow, Plus, Minus, Star, diff --git a/src/mutate.rs b/src/mutate.rs index fb23305..3900bb7 100644 --- a/src/mutate.rs +++ b/src/mutate.rs @@ -652,6 +652,22 @@ mod tests { assert_eq!(row[0], "hello/world"); } + #[test] + fn case_when_expression_parses_with_named_columns() { + let headers = vec!["sample_id".to_string()]; + let expr = "label = case_when(\n re($sample_id, \"^ERR(\\\\d+)$\") -> $1,\n _ -> $sample_id\n )"; + let ops = parse_operations(&[expr.to_string()], &headers, false).unwrap(); + assert_eq!(ops.len(), 1); + } + + #[test] + fn case_when_expression_handles_adjacent_column_tokens() { + let headers = vec!["score".to_string()]; + let expr = "bucket = case_when($score < 0 -> \"neg\", is_na($score) -> $score, $score < 50 -> \"low\", $score < 80 -> \"mid\", _ -> \"high\")"; + let ops = parse_operations(&[expr.to_string()], &headers, false).unwrap(); + assert_eq!(ops.len(), 1); + } + #[test] fn parse_string_literal_preserves_regex_escapes() { let literal = parse_string_literal("\"\\\\.\\\\*\"").unwrap(); From 5e02a09919adeb4d1e8164cf869822c8f644c6c4 Mon Sep 17 00:00:00 2001 From: "Z.-L. Deng" Date: Thu, 9 Oct 2025 09:57:49 +0200 Subject: [PATCH 5/9] Fix evaluate_truthy regression and add coverage --- src/expression.rs | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/src/expression.rs b/src/expression.rs index 873f47a..eba1a5b 100644 --- a/src/expression.rs +++ b/src/expression.rs @@ -573,17 +573,6 @@ where } } -pub fn evaluate_truthy_with_context<'a, R>( - value: &'a BoundValue, - ctx: &mut EvalContext<'a, R>, -) -> bool -where - R: RowAccessor + ?Sized, -{ - let mut ctx = EvalContext::new(row); - evaluate_truthy_with_context(value, &mut ctx) -} - pub fn evaluate_truthy_with_context<'a, R>( value: &'a BoundValue, ctx: &mut EvalContext<'a, R>, @@ -609,6 +598,14 @@ where } } +pub fn evaluate_truthy<'a, R>(value: &'a BoundValue, row: &'a R) -> bool +where + R: RowAccessor + ?Sized, +{ + let mut ctx = EvalContext::new(row); + evaluate_truthy_with_context(value, &mut ctx) +} + fn evaluate_compare<'a, R>( lhs: &'a BoundValue, op: CompareOp, @@ -933,6 +930,19 @@ mod tests { assert!(evaluate(&bound, &record)); } + #[test] + fn evaluate_truthy_prefers_numeric_semantics_when_available() { + let value_expr = parse_value_expression("$1").unwrap(); + let headers = vec!["value".to_string()]; + let bound = bind_value_expression(value_expr, &headers, false).unwrap(); + + let truthy_row = vec!["5".to_string()]; + assert!(evaluate_truthy(&bound, &truthy_row)); + + let falsy_row = vec!["0".to_string()]; + assert!(!evaluate_truthy(&bound, &falsy_row)); + } + #[test] fn string_equality_remains_textual_when_not_numeric() { let expr = parse_expression("$1 == 0").unwrap(); From e428b1e44cfa6c7793f57bf5d7cd8ee7a46ab140 Mon Sep 17 00:00:00 2001 From: "Z.-L. Deng" Date: Thu, 9 Oct 2025 16:11:22 +0800 Subject: [PATCH 6/9] Update expression.rs --- src/expression.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/expression.rs b/src/expression.rs index eba1a5b..347f18c 100644 --- a/src/expression.rs +++ b/src/expression.rs @@ -598,6 +598,7 @@ where } } +#[allow(dead_code)] pub fn evaluate_truthy<'a, R>(value: &'a BoundValue, row: &'a R) -> bool where R: RowAccessor + ?Sized, From a11d26ebdc4a5709f39f3a51bb8134b7062afd4b Mon Sep 17 00:00:00 2001 From: "Z.-L. Deng" Date: Thu, 9 Oct 2025 12:37:59 +0200 Subject: [PATCH 7/9] Add file metadata injection to cut command --- src/common.rs | 47 +++++++++++++++++- src/cut.rs | 134 ++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 165 insertions(+), 16 deletions(-) diff --git a/src/common.rs b/src/common.rs index 09ee1f3..ffbed4c 100644 --- a/src/common.rs +++ b/src/common.rs @@ -7,12 +7,28 @@ use csv::ReaderBuilder; use flate2::read::MultiGzDecoder; use xz2::read::XzDecoder; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SpecialColumn { + FilePath, + FileBase, +} + +impl SpecialColumn { + pub fn default_header(self) -> &'static str { + match self { + SpecialColumn::FilePath => "__file__", + SpecialColumn::FileBase => "__base__", + } + } +} + #[derive(Debug, Clone)] pub enum ColumnSelector { Index(usize), FromEnd(usize), Name(String), Range(Option>, Option>), + Special(SpecialColumn), } pub fn parse_selector_list(spec: &str) -> Result> { @@ -86,6 +102,12 @@ pub fn resolve_selectors( let mut indices = Vec::with_capacity(selectors.len()); for selector in selectors { match selector { + ColumnSelector::Special(special) => { + bail!( + "special column '{}' cannot be resolved as a positional index", + special.default_header() + ); + } ColumnSelector::Index(_) | ColumnSelector::FromEnd(_) | ColumnSelector::Name(_) => { let index = resolve_selector_index(headers, selector, no_header)?; indices.push(index); @@ -266,6 +288,11 @@ fn parse_simple_selector(token: &str) -> Result { if let Some(literal) = parse_brace_literal(token)? { return Ok(ColumnSelector::Name(literal)); } + match token { + "__file__" => return Ok(ColumnSelector::Special(SpecialColumn::FilePath)), + "__base__" => return Ok(ColumnSelector::Special(SpecialColumn::FileBase)), + _ => {} + } if let Some(stripped) = token.strip_prefix('-') { if stripped.is_empty() { bail!("column selector '-' must include an index"); @@ -535,6 +562,10 @@ fn resolve_selector_index( .with_context(|| format!("column '{}' not found", name))?; Ok(index) } + ColumnSelector::Special(special) => bail!( + "special column '{}' not supported without column injection", + special.default_header() + ), ColumnSelector::Range(_, _) => { bail!("unexpected nested column range") } @@ -543,7 +574,10 @@ fn resolve_selector_index( #[cfg(test)] mod tests { - use super::{ColumnSelector, parse_selector_list, parse_single_selector, resolve_selectors}; + use super::{ + ColumnSelector, SpecialColumn, parse_selector_list, parse_single_selector, + resolve_selectors, + }; #[test] fn resolves_name_range() { @@ -657,4 +691,15 @@ mod tests { let err = parse_selector_list("`foo").unwrap_err(); assert!(err.to_string().contains("unterminated backtick")); } + + #[test] + fn distinguishes_injected_and_literal_file_columns() { + let selectors = parse_selector_list("__file__,{__file__},`__base__`").unwrap(); + assert!(matches!( + selectors[0], + ColumnSelector::Special(SpecialColumn::FilePath) + )); + assert!(matches!(selectors[1], ColumnSelector::Name(ref name) if name == "__file__")); + assert!(matches!(selectors[2], ColumnSelector::Name(ref name) if name == "__base__")); + } } diff --git a/src/cut.rs b/src/cut.rs index 50c5c44..4b2a5bc 100644 --- a/src/cut.rs +++ b/src/cut.rs @@ -1,12 +1,12 @@ use std::io::{self, BufWriter, Write}; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; -use anyhow::{Context, Result}; +use anyhow::{Context, Result, bail}; use clap::Args; use crate::common::{ - InputOptions, default_headers, parse_selector_list, reader_for_path, resolve_selectors, - should_skip_record, + ColumnSelector, InputOptions, SpecialColumn, default_headers, parse_selector_list, + reader_for_path, resolve_selectors, should_skip_record, }; #[derive(Args, Debug)] @@ -23,6 +23,10 @@ pub struct CutArgs { #[arg(short = 'f', long = "fields", value_name = "COLS", required = true)] pub fields: String, + /// Rename the injected file column when using `__file__` or `__base__` + #[arg(long = "file-col", visible_alias = "fc", value_name = "NAME")] + pub file_col: Option, + /// Treat the input as headerless (columns referenced by 1-based indices) #[arg(short = 'H', long = "no-header")] pub no_header: bool, @@ -55,6 +59,9 @@ pub fn run(args: CutArgs) -> Result<()> { let mut reader = reader_for_path(&args.file, args.no_header, &input_opts)?; let mut writer = BufWriter::new(io::stdout().lock()); + let file_info = FileInfo::from_path(&args.file); + let file_column_config = FileColumnConfig::new(args.file_col.as_deref()); + if args.no_header { let mut records = reader.records(); let first_record = loop { @@ -72,14 +79,14 @@ pub fn run(args: CutArgs) -> Result<()> { }; let expected_width = first_record.len(); let headers = default_headers(expected_width); - let indices = resolve_selectors(&headers, &selectors, true)?; - emit_record(&first_record, &indices, &mut writer)?; + let columns = build_cut_columns(&headers, &selectors, true)?; + emit_record(&first_record, &columns, &file_info, &mut writer)?; for record in records { let record = record.with_context(|| format!("failed reading from {:?}", args.file))?; if should_skip_record(&record, &input_opts, Some(expected_width)) { continue; } - emit_record(&record, &indices, &mut writer)?; + emit_record(&record, &columns, &file_info, &mut writer)?; } } else { let headers = reader @@ -88,12 +95,19 @@ pub fn run(args: CutArgs) -> Result<()> { .iter() .map(|s| s.to_string()) .collect::>(); - let indices = resolve_selectors(&headers, &selectors, false)?; + let columns = build_cut_columns(&headers, &selectors, false)?; let expected_width = headers.len(); - let header_fields: Vec<&str> = indices + let header_fields: Vec = columns .iter() - .map(|&idx| headers.get(idx).map(|s| s.as_str()).unwrap_or("")) + .map(|column| match column { + CutColumn::Index(idx) => headers + .get(*idx) + .map(|s| s.as_str()) + .unwrap_or("") + .to_string(), + CutColumn::Injected(special) => file_column_config.header_for(*special), + }) .collect(); if !header_fields.is_empty() { writeln!(writer, "{}", header_fields.join("\t"))?; @@ -104,7 +118,7 @@ pub fn run(args: CutArgs) -> Result<()> { if should_skip_record(&record, &input_opts, Some(expected_width)) { continue; } - emit_record(&record, &indices, &mut writer)?; + emit_record(&record, &columns, &file_info, &mut writer)?; } } @@ -114,12 +128,16 @@ pub fn run(args: CutArgs) -> Result<()> { fn emit_record( record: &csv::StringRecord, - indices: &[usize], + columns: &[CutColumn], + file_info: &FileInfo, writer: &mut BufWriter>, ) -> Result<()> { - let mut fields = Vec::with_capacity(indices.len()); - for &idx in indices { - fields.push(record.get(idx).unwrap_or("")); + let mut fields = Vec::with_capacity(columns.len()); + for column in columns { + match column { + CutColumn::Index(idx) => fields.push(record.get(*idx).unwrap_or("")), + CutColumn::Injected(special) => fields.push(file_info.value_for(*special)), + } } if !fields.is_empty() { writeln!(writer, "{}", fields.join("\t"))?; @@ -128,3 +146,89 @@ fn emit_record( } Ok(()) } + +fn build_cut_columns( + headers: &[String], + selectors: &[ColumnSelector], + no_header: bool, +) -> Result> { + let mut columns = Vec::new(); + for selector in selectors { + match selector { + ColumnSelector::Special(special) => columns.push(CutColumn::Injected(*special)), + ColumnSelector::Range(start, end) => { + if start + .as_deref() + .map_or(false, |sel| matches!(sel, ColumnSelector::Special(_))) + || end + .as_deref() + .map_or(false, |sel| matches!(sel, ColumnSelector::Special(_))) + { + bail!("special columns cannot be used within a range selector"); + } + let indices = resolve_selectors(headers, &[selector.clone()], no_header)?; + columns.extend(indices.into_iter().map(CutColumn::Index)); + } + _ => { + let indices = resolve_selectors(headers, &[selector.clone()], no_header)?; + columns.extend(indices.into_iter().map(CutColumn::Index)); + } + } + } + Ok(columns) +} + +#[derive(Clone)] +struct FileInfo { + path: String, + base: String, +} + +impl FileInfo { + fn from_path(path: &Path) -> Self { + if path == Path::new("-") { + return FileInfo { + path: "-".to_string(), + base: "-".to_string(), + }; + } + let path_str = path.to_string_lossy().into_owned(); + let base = path + .file_name() + .map(|s| s.to_string_lossy().into_owned()) + .unwrap_or_else(|| path_str.clone()); + FileInfo { + path: path_str, + base, + } + } + + fn value_for(&self, special: SpecialColumn) -> &str { + match special { + SpecialColumn::FilePath => self.path.as_str(), + SpecialColumn::FileBase => self.base.as_str(), + } + } +} + +struct FileColumnConfig<'a> { + rename: Option<&'a str>, +} + +impl<'a> FileColumnConfig<'a> { + fn new(rename: Option<&'a str>) -> Self { + FileColumnConfig { rename } + } + + fn header_for(&self, special: SpecialColumn) -> String { + match self.rename { + Some(name) => name.to_string(), + None => special.default_header().to_string(), + } + } +} + +enum CutColumn { + Index(usize), + Injected(SpecialColumn), +} From f48f32cdcadc97242e43ac7a64b626bf9f739962 Mon Sep 17 00:00:00 2001 From: "Z.-L. Deng" Date: Thu, 9 Oct 2025 13:01:51 +0200 Subject: [PATCH 8/9] Allow cut to process multiple files --- src/common.rs | 47 ++++++++- src/cut.rs | 260 +++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 252 insertions(+), 55 deletions(-) diff --git a/src/common.rs b/src/common.rs index 09ee1f3..ffbed4c 100644 --- a/src/common.rs +++ b/src/common.rs @@ -7,12 +7,28 @@ use csv::ReaderBuilder; use flate2::read::MultiGzDecoder; use xz2::read::XzDecoder; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SpecialColumn { + FilePath, + FileBase, +} + +impl SpecialColumn { + pub fn default_header(self) -> &'static str { + match self { + SpecialColumn::FilePath => "__file__", + SpecialColumn::FileBase => "__base__", + } + } +} + #[derive(Debug, Clone)] pub enum ColumnSelector { Index(usize), FromEnd(usize), Name(String), Range(Option>, Option>), + Special(SpecialColumn), } pub fn parse_selector_list(spec: &str) -> Result> { @@ -86,6 +102,12 @@ pub fn resolve_selectors( let mut indices = Vec::with_capacity(selectors.len()); for selector in selectors { match selector { + ColumnSelector::Special(special) => { + bail!( + "special column '{}' cannot be resolved as a positional index", + special.default_header() + ); + } ColumnSelector::Index(_) | ColumnSelector::FromEnd(_) | ColumnSelector::Name(_) => { let index = resolve_selector_index(headers, selector, no_header)?; indices.push(index); @@ -266,6 +288,11 @@ fn parse_simple_selector(token: &str) -> Result { if let Some(literal) = parse_brace_literal(token)? { return Ok(ColumnSelector::Name(literal)); } + match token { + "__file__" => return Ok(ColumnSelector::Special(SpecialColumn::FilePath)), + "__base__" => return Ok(ColumnSelector::Special(SpecialColumn::FileBase)), + _ => {} + } if let Some(stripped) = token.strip_prefix('-') { if stripped.is_empty() { bail!("column selector '-' must include an index"); @@ -535,6 +562,10 @@ fn resolve_selector_index( .with_context(|| format!("column '{}' not found", name))?; Ok(index) } + ColumnSelector::Special(special) => bail!( + "special column '{}' not supported without column injection", + special.default_header() + ), ColumnSelector::Range(_, _) => { bail!("unexpected nested column range") } @@ -543,7 +574,10 @@ fn resolve_selector_index( #[cfg(test)] mod tests { - use super::{ColumnSelector, parse_selector_list, parse_single_selector, resolve_selectors}; + use super::{ + ColumnSelector, SpecialColumn, parse_selector_list, parse_single_selector, + resolve_selectors, + }; #[test] fn resolves_name_range() { @@ -657,4 +691,15 @@ mod tests { let err = parse_selector_list("`foo").unwrap_err(); assert!(err.to_string().contains("unterminated backtick")); } + + #[test] + fn distinguishes_injected_and_literal_file_columns() { + let selectors = parse_selector_list("__file__,{__file__},`__base__`").unwrap(); + assert!(matches!( + selectors[0], + ColumnSelector::Special(SpecialColumn::FilePath) + )); + assert!(matches!(selectors[1], ColumnSelector::Name(ref name) if name == "__file__")); + assert!(matches!(selectors[2], ColumnSelector::Name(ref name) if name == "__base__")); + } } diff --git a/src/cut.rs b/src/cut.rs index 50c5c44..f162696 100644 --- a/src/cut.rs +++ b/src/cut.rs @@ -1,12 +1,12 @@ use std::io::{self, BufWriter, Write}; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; -use anyhow::{Context, Result}; +use anyhow::{Context, Result, bail}; use clap::Args; use crate::common::{ - InputOptions, default_headers, parse_selector_list, reader_for_path, resolve_selectors, - should_skip_record, + ColumnSelector, InputOptions, SpecialColumn, default_headers, parse_selector_list, + reader_for_path, resolve_selectors, should_skip_record, }; #[derive(Args, Debug)] @@ -15,14 +15,18 @@ use crate::common::{ long_about = "Pick columns by name or 1-based index. Combine comma-separated selectors with ranges (colA:colD or 2:6) and single fields in one spec. Defaults to header-aware mode; add -H for headerless input.\n\nExamples:\n tsvkit cut -f id,sample3,sample1 examples/profiles.tsv\n tsvkit cut -f 'Purity,sample:FN,F1' examples/profiles.tsv\n tsvkit cut -H -f 3,1 data.tsv" )] pub struct CutArgs { - /// Input TSV file (use '-' for stdin; supports gz/xz) - #[arg(value_name = "FILE", default_value = "-")] - pub file: PathBuf, + /// Input TSV file(s) (use '-' for stdin; supports gz/xz) + #[arg(value_name = "FILES", num_args = 0.., default_values = ["-"])] + pub files: Vec, /// Fields to select, using names, 1-based indices, ranges (`colA:colD`, `2:5`), or mixes. Comma-separated list. #[arg(short = 'f', long = "fields", value_name = "COLS", required = true)] pub fields: String, + /// Rename the injected file column when using `__file__` or `__base__` + #[arg(long = "file-col", visible_alias = "fc", value_name = "NAME")] + pub file_col: Option, + /// Treat the input as headerless (columns referenced by 1-based indices) #[arg(short = 'H', long = "no-header")] pub no_header: bool, @@ -52,74 +56,136 @@ pub fn run(args: CutArgs) -> Result<()> { args.ignore_empty_row, args.ignore_illegal_row, )?; - let mut reader = reader_for_path(&args.file, args.no_header, &input_opts)?; let mut writer = BufWriter::new(io::stdout().lock()); - if args.no_header { - let mut records = reader.records(); - let first_record = loop { - match records.next() { - Some(record) => { - let record = - record.with_context(|| format!("failed reading from {:?}", args.file))?; - if should_skip_record(&record, &input_opts, None) { - continue; - } - break record; + let file_column_config = FileColumnConfig::new(args.file_col.as_deref()); + let mut header_emitted = false; + + for path in &args.files { + let mut reader = reader_for_path(path, args.no_header, &input_opts)?; + let file_info = FileInfo::from_path(path); + + if args.no_header { + process_no_header_file( + &mut reader, + path, + &selectors, + &file_info, + &input_opts, + &mut writer, + )?; + } else { + process_header_file( + &mut reader, + path, + &selectors, + &file_info, + &file_column_config, + &input_opts, + &mut writer, + &mut header_emitted, + )?; + } + } + + writer.flush()?; + Ok(()) +} + +fn process_no_header_file( + reader: &mut csv::Reader>, + path: &Path, + selectors: &[ColumnSelector], + file_info: &FileInfo, + input_opts: &InputOptions, + writer: &mut BufWriter>, +) -> Result<()> { + let mut records = reader.records(); + let first_record = loop { + match records.next() { + Some(record) => { + let record = record.with_context(|| format!("failed reading from {:?}", path))?; + if should_skip_record(&record, input_opts, None) { + continue; } - None => return Ok(()), - } - }; - let expected_width = first_record.len(); - let headers = default_headers(expected_width); - let indices = resolve_selectors(&headers, &selectors, true)?; - emit_record(&first_record, &indices, &mut writer)?; - for record in records { - let record = record.with_context(|| format!("failed reading from {:?}", args.file))?; - if should_skip_record(&record, &input_opts, Some(expected_width)) { - continue; + break record; } - emit_record(&record, &indices, &mut writer)?; + None => return Ok(()), } - } else { - let headers = reader - .headers() - .with_context(|| format!("failed reading header from {:?}", args.file))? - .iter() - .map(|s| s.to_string()) - .collect::>(); - let indices = resolve_selectors(&headers, &selectors, false)?; - let expected_width = headers.len(); + }; + let expected_width = first_record.len(); + let headers = default_headers(expected_width); + let columns = build_cut_columns(&headers, selectors, true)?; + emit_record(&first_record, &columns, file_info, writer)?; + for record in records { + let record = record.with_context(|| format!("failed reading from {:?}", path))?; + if should_skip_record(&record, input_opts, Some(expected_width)) { + continue; + } + emit_record(&record, &columns, file_info, writer)?; + } + Ok(()) +} - let header_fields: Vec<&str> = indices +fn process_header_file( + reader: &mut csv::Reader>, + path: &Path, + selectors: &[ColumnSelector], + file_info: &FileInfo, + file_column_config: &FileColumnConfig, + input_opts: &InputOptions, + writer: &mut BufWriter>, + header_emitted: &mut bool, +) -> Result<()> { + let headers = reader + .headers() + .with_context(|| format!("failed reading header from {:?}", path))? + .iter() + .map(|s| s.to_string()) + .collect::>(); + let columns = build_cut_columns(&headers, selectors, false)?; + let expected_width = headers.len(); + + if !*header_emitted { + let header_fields: Vec = columns .iter() - .map(|&idx| headers.get(idx).map(|s| s.as_str()).unwrap_or("")) + .map(|column| match column { + CutColumn::Index(idx) => headers + .get(*idx) + .map(|s| s.as_str()) + .unwrap_or("") + .to_string(), + CutColumn::Injected(special) => file_column_config.header_for(*special), + }) .collect(); if !header_fields.is_empty() { writeln!(writer, "{}", header_fields.join("\t"))?; } + *header_emitted = true; + } - for record in reader.records() { - let record = record.with_context(|| format!("failed reading from {:?}", args.file))?; - if should_skip_record(&record, &input_opts, Some(expected_width)) { - continue; - } - emit_record(&record, &indices, &mut writer)?; + for record in reader.records() { + let record = record.with_context(|| format!("failed reading from {:?}", path))?; + if should_skip_record(&record, input_opts, Some(expected_width)) { + continue; } + emit_record(&record, &columns, file_info, writer)?; } - - writer.flush()?; Ok(()) } fn emit_record( record: &csv::StringRecord, - indices: &[usize], + columns: &[CutColumn], + file_info: &FileInfo, writer: &mut BufWriter>, ) -> Result<()> { - let mut fields = Vec::with_capacity(indices.len()); - for &idx in indices { - fields.push(record.get(idx).unwrap_or("")); + let mut fields = Vec::with_capacity(columns.len()); + for column in columns { + match column { + CutColumn::Index(idx) => fields.push(record.get(*idx).unwrap_or("")), + CutColumn::Injected(special) => fields.push(file_info.value_for(*special)), + } } if !fields.is_empty() { writeln!(writer, "{}", fields.join("\t"))?; @@ -128,3 +194,89 @@ fn emit_record( } Ok(()) } + +fn build_cut_columns( + headers: &[String], + selectors: &[ColumnSelector], + no_header: bool, +) -> Result> { + let mut columns = Vec::new(); + for selector in selectors { + match selector { + ColumnSelector::Special(special) => columns.push(CutColumn::Injected(*special)), + ColumnSelector::Range(start, end) => { + if start + .as_deref() + .map_or(false, |sel| matches!(sel, ColumnSelector::Special(_))) + || end + .as_deref() + .map_or(false, |sel| matches!(sel, ColumnSelector::Special(_))) + { + bail!("special columns cannot be used within a range selector"); + } + let indices = resolve_selectors(headers, &[selector.clone()], no_header)?; + columns.extend(indices.into_iter().map(CutColumn::Index)); + } + _ => { + let indices = resolve_selectors(headers, &[selector.clone()], no_header)?; + columns.extend(indices.into_iter().map(CutColumn::Index)); + } + } + } + Ok(columns) +} + +#[derive(Clone)] +struct FileInfo { + path: String, + base: String, +} + +impl FileInfo { + fn from_path(path: &Path) -> Self { + if path == Path::new("-") { + return FileInfo { + path: "-".to_string(), + base: "-".to_string(), + }; + } + let path_str = path.to_string_lossy().into_owned(); + let base = path + .file_name() + .map(|s| s.to_string_lossy().into_owned()) + .unwrap_or_else(|| path_str.clone()); + FileInfo { + path: path_str, + base, + } + } + + fn value_for(&self, special: SpecialColumn) -> &str { + match special { + SpecialColumn::FilePath => self.path.as_str(), + SpecialColumn::FileBase => self.base.as_str(), + } + } +} + +struct FileColumnConfig<'a> { + rename: Option<&'a str>, +} + +impl<'a> FileColumnConfig<'a> { + fn new(rename: Option<&'a str>) -> Self { + FileColumnConfig { rename } + } + + fn header_for(&self, special: SpecialColumn) -> String { + match self.rename { + Some(name) => name.to_string(), + None => special.default_header().to_string(), + } + } +} + +enum CutColumn { + Index(usize), + Injected(SpecialColumn), +} From 91861570b3df62debd5b9c6132ce021c53a28471 Mon Sep 17 00:00:00 2001 From: "Z.-L. Deng" Date: Thu, 16 Oct 2025 05:20:05 +0800 Subject: [PATCH 9/9] Update README.md escape "|" in table --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 3be8153..4962d49 100644 --- a/README.md +++ b/README.md @@ -148,7 +148,7 @@ The same expression language powers `filter -e`, `mutate -e name=EXPR`, and rege | `+ - * / ^` | Arithmetic operators (`^` is exponentiation, right-associative). | Numbers | | `== != < <= > >=` | Comparisons. | Numbers or strings | | `&` / `and` | Logical AND. | Booleans | -| `|` / `or` | Logical OR. | Booleans | +| `\|` / `or` | Logical OR. | Booleans | | `!` / `not` | Logical negation. | Booleans | | `~` | Regex match. Right-hand side can be literal text or a `$range`. | Strings | | `!~` | Regex does *not* match. | Strings | @@ -243,7 +243,7 @@ tsvkit filter -e '$group == "case" & $purity >= 0.94' examples/samples.tsv | Literals | `1.25`, `"case"` | Strings use double quotes; escape inner quotes with `\"`. | | Arithmetic | `($rna_ug - $dna_ug) / $rna_ug` | Standard precedence applies (parentheses for clarity). | | Comparisons | `$purity >= 0.9`, `$group != "control"` | Works on numeric or string data. | -| Logical | `($purity >= 0.9) & ($group == "case")` | `&`, `|`, and `!` (or `and`, `or`, `not`). | +| Logical | `($purity >= 0.9) & ($group == "case")` | `&`, `\|`, and `!` (or `and`, `or`, `not`). | | Numeric functions | `log2($total)`, `sqrt($reads)` | See [Expression language essentials](#expression-language-essentials). | | Row-wise aggregators | `sum($dna_ug:$rna_ug)`, `mode($1,$3)`, `countunique($gene:)` | Same catalog as [`summarize`](#summarize): totals, quantiles (`q*` / `p*`), variance/SD, products, entropy, argmin/argmax, membership stats. Works with ranges, lists, and open selectors. | | Regex match | `$tech ~ "sRNA"`, `$notes !~ "(?i)fail"` | Patterns follow Rust `regex` syntax. `(?i)` enables case-insensitive matching. |