diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 3842148a43..8ec790ffb4 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -66,6 +66,7 @@ of expressions that be disabled. | InitCap | | Left | | Length | +| Levenshtein | | Like | | Lower | | OctetLength | diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 0957868a60..32f108fc54 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -196,6 +196,10 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(spark_map_sort); make_comet_scalar_udf!("spark_map_sort", func, without data_type) } + "levenshtein" => { + let func = Arc::new(crate::string_funcs::spark_levenshtein); + make_comet_scalar_udf!("levenshtein", func, without data_type) + } _ => registry.udf(fun_name).map_err(|e| { DataFusionError::Execution(format!( "Function {fun_name} not found in the registry: {e}", diff --git a/native/spark-expr/src/string_funcs/levenshtein.rs b/native/spark-expr/src/string_funcs/levenshtein.rs new file mode 100644 index 0000000000..9d2c48c7b7 --- /dev/null +++ b/native/spark-expr/src/string_funcs/levenshtein.rs @@ -0,0 +1,230 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Levenshtein distance expression implementation. +//! +//! Computes the Levenshtein edit distance between two strings, +//! matching Apache Spark's `levenshtein(str1, str2)` semantics. + +use arrow::array::{as_string_array, Array, ArrayRef, Int32Array}; +use datafusion::common::{DataFusionError, Result, ScalarValue}; +use datafusion::physical_plan::ColumnarValue; +use std::sync::Arc; + +/// Computes the Levenshtein edit distance between two UTF-8 strings. +/// +/// This uses the standard dynamic programming algorithm with O(min(m,n)) space. +fn levenshtein_distance(s: &str, t: &str) -> i32 { + let s_chars: Vec<char> = s.chars().collect(); + let t_chars: Vec<char> = t.chars().collect(); + let m = s_chars.len(); + let n = t_chars.len(); + + // Optimization: if one string is empty, distance is the length of the other + if m == 0 { + return n as i32; + } + if n == 0 { + return m as i32; + } + + // Use the shorter string for the "column" to minimize space usage + let (s_chars, t_chars, m, n) = if m > n { + (t_chars, s_chars, n, m) + } else { + (s_chars, t_chars, m, n) + }; + + // Previous and current row of distances + let mut prev = vec![0i32; m + 1]; + let mut curr = vec![0i32; m + 1]; + + // Initialize base case: distance from empty string + for (i, val) in prev.iter_mut().enumerate() { + *val = i as i32; + } + + for j in 1..=n { + curr[0] = j as i32; + for i in 1..=m { + let cost = if s_chars[i - 1] == t_chars[j - 1] { + 0 + } else { + 1 + }; + curr[i] = (prev[i] + 1) // deletion + .min(curr[i - 1] + 1) // insertion + .min(prev[i - 1] + cost); // 
substitution + } + std::mem::swap(&mut prev, &mut curr); + } + + prev[m] +} + +/// Spark-compatible levenshtein scalar function. +/// +/// Accepts two or three arguments: +/// - `levenshtein(str1, str2)` → edit distance +/// - `levenshtein(str1, str2, threshold)` → edit distance if <= threshold, else -1 +/// +/// NULL inputs produce NULL outputs. NULL threshold produces NULL output. +pub fn spark_levenshtein(args: &[ColumnarValue]) -> Result<ColumnarValue> { + if args.len() < 2 || args.len() > 3 { + return Err(DataFusionError::Internal(format!( + "levenshtein requires 2 or 3 arguments, got {}", + args.len() + ))); + } + + // Extract optional threshold (3rd argument must be a scalar Int32) + let threshold: Option<i32> = if args.len() == 3 { + match &args[2] { + ColumnarValue::Scalar(ScalarValue::Int32(t)) => match t { + Some(val) => Some(*val), + None => return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None))), + }, + _ => { + return Err(DataFusionError::Internal( + "levenshtein threshold must be an Int32 scalar".to_string(), + )); + } + } + } else { + None + }; + + // Expand scalars to arrays for uniform processing + let len = args + .iter() + .take(2) + .find_map(|arg| match arg { + ColumnarValue::Array(a) => Some(a.len()), + _ => None, + }) + .unwrap_or(1); + + let left = args[0].clone().into_array(len)?; + let right = args[1].clone().into_array(len)?; + + let left_arr = as_string_array(&left); + let right_arr = as_string_array(&right); + + let result: Int32Array = left_arr + .iter() + .zip(right_arr.iter()) + .map(|(l, r)| match (l, r) { + (Some(l), Some(r)) => { + let dist = levenshtein_distance(l, r); + match threshold { + Some(t) if dist > t => Some(-1), + _ => Some(dist), + } + } + _ => None, // NULL propagation + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::StringArray; + + #[test] + fn test_levenshtein_basic() { + assert_eq!(levenshtein_distance("", ""), 0); + 
assert_eq!(levenshtein_distance("abc", ""), 3); + assert_eq!(levenshtein_distance("", "abc"), 3); + assert_eq!(levenshtein_distance("abc", "abc"), 0); + assert_eq!(levenshtein_distance("kitten", "sitting"), 3); + assert_eq!(levenshtein_distance("frog", "fog"), 1); + } + + #[test] + fn test_levenshtein_unicode() { + // Spark counts character-level (not byte-level) edit distance + assert_eq!(levenshtein_distance("你好", "你坏"), 1); + assert_eq!(levenshtein_distance("abc", "äbc"), 1); + } + + #[test] + fn test_spark_levenshtein_nulls() { + let left = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("abc"), + None, + Some("hello"), + ]))); + let right = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("adc"), + Some("test"), + None, + ]))); + + let result = spark_levenshtein(&[left, right]).unwrap(); + match result { + ColumnarValue::Array(arr) => { + let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap(); + assert_eq!(int_arr.value(0), 1); // abc -> adc = 1 + assert!(int_arr.is_null(1)); // NULL -> test = NULL + assert!(int_arr.is_null(2)); // hello -> NULL = NULL + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_spark_levenshtein_with_threshold() { + let left = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("kitten"), + Some("abc"), + Some("frog"), + ]))); + let right = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("sitting"), + Some("adc"), + Some("fog"), + ]))); + let threshold = ColumnarValue::Scalar(ScalarValue::Int32(Some(2))); + + let result = spark_levenshtein(&[left, right, threshold]).unwrap(); + match result { + ColumnarValue::Array(arr) => { + let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap(); + assert_eq!(int_arr.value(0), -1); // kitten->sitting=3 > 2, return -1 + assert_eq!(int_arr.value(1), 1); // abc->adc=1 <= 2, return 1 + assert_eq!(int_arr.value(2), 1); // frog->fog=1 <= 2, return 1 + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn 
test_spark_levenshtein_null_threshold() { + let left = ColumnarValue::Array(Arc::new(StringArray::from(vec![Some("abc")]))); + let right = ColumnarValue::Array(Arc::new(StringArray::from(vec![Some("adc")]))); + let threshold = ColumnarValue::Scalar(ScalarValue::Int32(None)); + + let result = spark_levenshtein(&[left, right, threshold]).unwrap(); + match result { + ColumnarValue::Scalar(ScalarValue::Int32(None)) => {} // NULL threshold -> NULL + _ => panic!("Expected NULL scalar result for NULL threshold"), + } + } +} diff --git a/native/spark-expr/src/string_funcs/mod.rs b/native/spark-expr/src/string_funcs/mod.rs index bb785bdb44..ed871ef0da 100644 --- a/native/spark-expr/src/string_funcs/mod.rs +++ b/native/spark-expr/src/string_funcs/mod.rs @@ -17,10 +17,12 @@ mod contains; mod get_json_object; +mod levenshtein; mod split; mod substring; pub use contains::SparkContains; pub use get_json_object::spark_get_json_object; +pub use levenshtein::spark_levenshtein; pub use split::spark_split; pub use substring::SubstringExpr; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 448c2c2cb3..03e24e46cb 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -167,6 +167,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[GetJsonObject] -> CometGetJsonObject, classOf[InitCap] -> CometInitCap, classOf[Length] -> CometLength, + classOf[Levenshtein] -> CometLevenshtein, classOf[Like] -> CometLike, classOf[Lower] -> CometLower, classOf[OctetLength] -> CometScalarFunction("octet_length"), diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index 968fe8cd69..4aef72fe5c 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ 
b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -21,8 +21,8 @@ package org.apache.comet.serde import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper} -import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, IsNull, Left, Length, Levenshtein, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper} +import org.apache.spark.sql.types.{BinaryType, DataTypes, IntegerType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.comet.CometConf @@ -84,6 +84,30 @@ object CometLength extends CometScalarFunction[Length]("length") { } } +object CometLevenshtein extends CometExpressionSerde[Levenshtein] { + + override def getUnsupportedReasons(): Seq[String] = Seq( + "Non-default collation (non-UTF8_BINARY) is not supported") + + override def getSupportLevel(expr: Levenshtein): SupportLevel = { + expr.children.headOption match { + case Some(child) if QueryPlanSerde.isStringCollationType(child.dataType) => + Unsupported(Some("Levenshtein with non-default collation is not supported")) + case _ => Compatible() + } + } + + override def convert( + expr: Levenshtein, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + val childExprs = expr.children.map(exprToProtoInternal(_, inputs, binding)) + val optExpr = + scalarFunctionExprToProtoWithReturnType("levenshtein", IntegerType, false, childExprs: _*) + optExprWithInfo(optExpr, expr, expr.children: _*) + } +} + object CometInitCap extends CometScalarFunction[InitCap]("initcap") { override def getIncompatibleReasons(): Seq[String] = 
Seq( diff --git a/spark/src/test/resources/sql-tests/expressions/string/levenshtein.sql b/spark/src/test/resources/sql-tests/expressions/string/levenshtein.sql new file mode 100644 index 0000000000..6ae8225cbe --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/levenshtein.sql @@ -0,0 +1,47 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +statement +CREATE TABLE test_levenshtein(s1 string, s2 string) USING parquet + +statement +INSERT INTO test_levenshtein VALUES ('kitten', 'sitting'), ('frog', 'fog'), ('abc', 'abc'), ('', 'hello'), ('hello', ''), ('', ''), (NULL, 'test'), ('hello', NULL), (NULL, NULL) + +-- column arguments +query +SELECT levenshtein(s1, s2) FROM test_levenshtein + +-- column + literal +query +SELECT levenshtein(s1, 'abc') FROM test_levenshtein + +-- literal + column +query +SELECT levenshtein('kitten', s2) FROM test_levenshtein + +-- literal + literal +query +SELECT levenshtein('kitten', 'sitting'), levenshtein('frog', 'fog'), levenshtein('', ''), levenshtein(NULL, 'a') + +-- identical strings +query +SELECT levenshtein(s1, s1) FROM test_levenshtein + +-- unicode characters +query +SELECT levenshtein('café', 'cafe'), levenshtein('你好', '你坏') + diff --git a/spark/src/test/resources/sql-tests/expressions/string/levenshtein_threshold.sql b/spark/src/test/resources/sql-tests/expressions/string/levenshtein_threshold.sql new file mode 100644 index 0000000000..a89435416d --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/levenshtein_threshold.sql @@ -0,0 +1,40 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +-- MinSparkVersion: 3.5 + +statement +CREATE TABLE test_levenshtein(s1 string, s2 string) USING parquet + +statement +INSERT INTO test_levenshtein VALUES ('kitten', 'sitting'), ('frog', 'fog'), ('abc', 'abc'), ('', 'hello'), ('hello', ''), ('', ''), (NULL, 'test'), ('hello', NULL), (NULL, NULL) + +-- three argument version with threshold +query +SELECT levenshtein('kitten', 'sitting', 2), levenshtein('kitten', 'sitting', 3), levenshtein('kitten', 'sitting', 4) + +-- threshold with column arguments +query +SELECT levenshtein(s1, s2, 2) FROM test_levenshtein + +-- threshold edge cases +query +SELECT levenshtein('abc', 'abc', 0), levenshtein('abc', 'adc', 0), levenshtein('', '', 0) + +-- threshold with NULL +query +SELECT levenshtein('abc', 'adc', NULL), levenshtein(NULL, 'test', 2) diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index 44513efaa0..1d474a9284 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.{CometTestBase, DataFrame} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataTypes, StructField, StructType} +import org.apache.comet.CometSparkSessionExtensions.isSpark35Plus import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator} class CometStringExpressionSuite extends CometTestBase { @@ -707,4 +708,46 @@ class CometStringExpressionSuite extends CometTestBase { // scalastyle:on } + test("levenshtein") { + val data = + Seq(("kitten", "sitting"), ("frog", "fog"), ("abc", "abc"), ("", "hello"), ("hello", "")) + + withParquetTable(data, "tbl") { + checkSparkAnswerAndOperator("SELECT levenshtein(_1, _2) FROM tbl") + } + } + + test("levenshtein with nulls") { + val table = "levenshtein_null_test" + withTable(table) { + sql(s"CREATE TABLE $table(s1 STRING, s2 
STRING) USING parquet") + sql( + s"INSERT INTO $table VALUES " + + "('abc', 'adc'), (NULL, 'test'), ('hello', NULL), (NULL, NULL)") + checkSparkAnswerAndOperator(s"SELECT levenshtein(s1, s2) FROM $table") + } + } + + test("levenshtein with unicode") { + val data = Seq( + ("\u4f60\u597d", "\u4f60\u574f"), + ("caf\u00e9", "cafe"), + ("\ud83d\ude00", "\ud83d\ude01")) + + withParquetTable(data, "tbl") { + checkSparkAnswerAndOperator("SELECT levenshtein(_1, _2) FROM tbl") + } + } + + test("levenshtein with threshold") { + assume(isSpark35Plus, "levenshtein with threshold requires Spark 3.5+") + val data = Seq(("kitten", "sitting"), ("frog", "fog"), ("abc", "abc"), ("hello", "world")) + + withParquetTable(data, "tbl") { + checkSparkAnswerAndOperator("SELECT levenshtein(_1, _2, 2) FROM tbl") + checkSparkAnswerAndOperator("SELECT levenshtein(_1, _2, 0) FROM tbl") + checkSparkAnswerAndOperator("SELECT levenshtein(_1, _2, 10) FROM tbl") + } + } + } diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala index c7c750aed6..a26b6687ca 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala @@ -57,6 +57,10 @@ object CometStringExpressionBenchmark extends CometBenchmarkBase { StringExprConfig("initCap", "select initCap(c1) from parquetV1Table"), StringExprConfig("instr", "select instr(c1, '123') from parquetV1Table"), StringExprConfig("length", "select length(c1) from parquetV1Table"), + StringExprConfig("levenshtein", "select levenshtein(c1, 'test') from parquetV1Table"), + StringExprConfig( + "levenshtein_threshold", + "select levenshtein(c1, 'test', 3) from parquetV1Table"), StringExprConfig("like", "select c1 like '%123%' from parquetV1Table"), StringExprConfig("lower", "select lower(c1) from 
parquetV1Table"), StringExprConfig("lpad", "select lpad(c1, 150, 'x') from parquetV1Table"),