diff --git a/docs/source/contributor-guide/spark_expressions_support.md b/docs/source/contributor-guide/spark_expressions_support.md index cb27a439b9..5d563b94ba 100644 --- a/docs/source/contributor-guide/spark_expressions_support.md +++ b/docs/source/contributor-guide/spark_expressions_support.md @@ -519,7 +519,7 @@ - [x] contains - [ ] decode - [ ] elt -- [ ] encode +- [x] encode - [x] endswith - [ ] find_in_set - [ ] format_number diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 668081d257..dac3969b73 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -63,6 +63,7 @@ of expressions that be disabled. | Concat | | ConcatWs | | Contains | +| Encode | | EndsWith | | InitCap | | Left | diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index aec4b19111..e09b134ed9 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -479,7 +479,7 @@ trait CommonStringExprs { binding: Boolean): Option[Expr] = { charset match { case Literal(str, DataTypes.StringType) - if str.toString.toLowerCase(Locale.ROOT) == "utf-8" => + if str != null && str.toString.toLowerCase(Locale.ROOT) == "utf-8" => // decode(col, 'utf-8') can be treated as a cast with "try" eval mode that puts nulls // for invalid strings. // Left child is the binary expression. @@ -495,4 +495,33 @@ trait CommonStringExprs { None } } + + def stringEncode( + expr: Expression, + charset: Expression, + value: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + charset match { + case Literal(str, DataTypes.StringType) + if str != null && str.toString.toLowerCase(Locale.ROOT) == "utf-8" => + // encode(col, 'utf-8') is byte-equivalent to cast(string AS binary) + // because Spark's UTF8String already holds valid UTF-8 bytes. + val strExpr = exprToProtoInternal(value, inputs, binding) + if (strExpr.isDefined) { + CometCast.castToProto( + expr, + None, + DataTypes.BinaryType, + strExpr.get, + CometEvalMode.LEGACY) + } else { + withInfo(expr, value) + None + } + case _ => + withInfo(expr, "Comet only supports encoding with 'utf-8'.") + None + } + } } diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala index f80a8909f6..09be02d8e0 100644 --- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala @@ -43,7 +43,8 @@ trait CometExprShim extends CommonStringExprs { case s: StringDecode => // Right child is the encoding expression. stringDecode(expr, s.charset, s.bin, inputs, binding) - + case e: Encode => + stringEncode(expr, e.charset, e.value, inputs, binding) case _ => None } } diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala index d3e3270700..2b095249fd 100644 --- a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala @@ -46,7 +46,8 @@ trait CometExprShim extends CommonStringExprs { case s: StringDecode => // Right child is the encoding expression. stringDecode(expr, s.charset, s.bin, inputs, binding) - + case e: Encode => + stringEncode(expr, e.charset, e.value, inputs, binding) case expr @ ToPrettyString(child, timeZoneId) => val castSupported = CometCast.isSupported( child.dataType, diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index 3d5b34bfd2..085ee8cc9e 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -30,14 +30,14 @@ import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataTypes import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.expressions.{CometCast, CometEvalMode} -import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible, SupportLevel} +import org.apache.comet.serde.{Compatible, ExprOuterClass, Incompatible, SupportLevel} import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, supportedScalarSortElementType} /** * `CometExprShim` acts as a shim for parsing expressions from different Spark versions. */ -trait CometExprShim extends CommonStringExprs { +trait CometExprShim extends ShimCometExprs { protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalModeUtil.fromSparkEvalMode(c.evalMode) @@ -168,8 +168,7 @@ trait CometExprShim extends CommonStringExprs { childExpr) optExprWithInfo(mapSortExpr, ms, ms.child) } - - case _ => None + case _ => sparkExprToProto(expr, inputs, binding) } } } diff --git a/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala index 5e906a0d83..e8b2b94acb 100644 --- a/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala @@ -30,14 +30,14 @@ import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataTypes import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.expressions.{CometCast, CometEvalMode} -import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible, SupportLevel} +import org.apache.comet.serde.{Compatible, ExprOuterClass, Incompatible, SupportLevel} import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, supportedScalarSortElementType} /** * `CometExprShim` acts as a shim for parsing expressions from different Spark versions. */ -trait CometExprShim extends CommonStringExprs { +trait CometExprShim extends ShimCometExprs { protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalModeUtil.fromSparkEvalMode(c.evalMode) @@ -92,6 +92,19 @@ trait CometExprShim extends CommonStringExprs { val Seq(bin, charset, _, _) = s.arguments stringDecode(expr, charset, bin, inputs, binding) + case s: StaticInvoke + if s.staticObject == classOf[Encode] && + s.dataType.isInstanceOf[BinaryType] && + s.functionName == "encode" && + s.arguments.size == 4 && + s.inputTypes == Seq( + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true), + BooleanType, + BooleanType) => + val Seq(value, charset, _, _) = s.arguments + stringEncode(expr, charset, value, inputs, binding) + case expr @ ToPrettyString(child, timeZoneId) => val castSupported = CometCast.isSupported( child.dataType, @@ -168,7 +181,7 @@ trait CometExprShim extends CommonStringExprs { optExprWithInfo(mapSortExpr, ms, ms.child) } - case _ => None + case _ => sparkExprToProto(expr, inputs, binding) } } } diff --git a/spark/src/main/spark-4.2/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.2/org/apache/comet/shims/CometExprShim.scala index 5e906a0d83..ee0c8d6810 100644 --- a/spark/src/main/spark-4.2/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.2/org/apache/comet/shims/CometExprShim.scala @@ -30,14 +30,14 @@ import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataTypes import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.expressions.{CometCast, CometEvalMode} -import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible, SupportLevel} +import org.apache.comet.serde.{Compatible, ExprOuterClass, Incompatible, SupportLevel} import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, supportedScalarSortElementType} /** * `CometExprShim` acts as a shim for parsing expressions from different Spark versions. */ -trait CometExprShim extends CommonStringExprs { +trait CometExprShim extends ShimCometExprs { protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalModeUtil.fromSparkEvalMode(c.evalMode) @@ -167,8 +167,7 @@ trait CometExprShim extends CommonStringExprs { childExpr) optExprWithInfo(mapSortExpr, ms, ms.child) } - - case _ => None + case _ => sparkExprToProto(expr, inputs, binding) } } } diff --git a/spark/src/main/spark-4.x/org/apache/comet/shims/ShimCometExprs.scala b/spark/src/main/spark-4.x/org/apache/comet/shims/ShimCometExprs.scala new file mode 100644 index 0000000000..af43b9e76f --- /dev/null +++ b/spark/src/main/spark-4.x/org/apache/comet/shims/ShimCometExprs.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke +import org.apache.spark.sql.internal.types.StringTypeWithCollation +import org.apache.spark.sql.types.{BinaryType, BooleanType} + +import org.apache.comet.serde.CommonStringExprs +import org.apache.comet.serde.ExprOuterClass.Expr + +trait ShimCometExprs extends CommonStringExprs { + + protected def sparkExprToProto( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + expr match { + // encode(str, 'utf-8') -> cast(string AS binary) — Arrow's Utf8->Binary + // is a zero-copy reinterpret, matching Spark's UTF8String.getBytes() exactly. + case s: StaticInvoke + if s.staticObject == classOf[Encode] && + s.dataType.isInstanceOf[BinaryType] && + s.functionName == "encode" && + s.arguments.size == 4 && + s.inputTypes == Seq( + StringTypeWithCollation(supportsTrimCollation = true), + StringTypeWithCollation(supportsTrimCollation = true), + BooleanType, + BooleanType) => + val Seq(value, charset, _, _) = s.arguments + stringEncode(expr, charset, value, inputs, binding) + + case _ => None + } + } +} diff --git a/spark/src/test/resources/sql-tests/expressions/string/encode.sql b/spark/src/test/resources/sql-tests/expressions/string/encode.sql new file mode 100644 index 0000000000..2a6a142426 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/encode.sql @@ -0,0 +1,61 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Tests for the SQL `encode(str, charset)` function. +-- +-- Spark 3.x: Encode is a BinaryExpression(value, charset). +-- Spark 4.x+: Encode is RuntimeReplaceable; the analyzer rewrites it to +-- StaticInvoke(classOf[Encode], BinaryType, "encode", ...) + +statement +CREATE TABLE test_encode_utf8(s string) USING parquet + +statement +INSERT INTO test_encode_utf8 VALUES ('hello'), ('world'), (''), ('café'), (NULL) + +query +SELECT encode(s, 'utf-8') FROM test_encode_utf8 + +query +SELECT encode(s, 'UTF-8') FROM test_encode_utf8 + +-- Mixed-case charset literal exercises toLowerCase normalization +query +SELECT encode(s, 'Utf-8') FROM test_encode_utf8 + +query +SELECT encode('hello', 'utf-8'), encode('', 'utf-8'), encode(CAST(NULL AS STRING), 'utf-8') + +-- Different language(French, Japanese) +query +SELECT encode('café', 'utf-8'), encode('日本語', 'utf-8') + +-- non-UTF-8 falls back to Spark JVM +statement +CREATE TABLE test_encode_charset_safe(s string) USING parquet + +statement +INSERT INTO test_encode_charset_safe VALUES ('hello'), ('world'), (''), (NULL) + +query expect_fallback(Comet only supports encoding with 'utf-8'.) +SELECT encode(s, 'UTF-16BE') FROM test_encode_charset_safe + +query expect_fallback(Comet only supports encoding with 'utf-8'.) +SELECT encode(s, 'US-ASCII') FROM test_encode_charset_safe + +query expect_fallback(Comet only supports encoding with 'utf-8'.) +SELECT encode(s, 'ISO-8859-1') FROM test_encode_charset_safe \ No newline at end of file