Skip to content
2 changes: 1 addition & 1 deletion docs/source/contributor-guide/spark_expressions_support.md
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@
- [x] contains
- [ ] decode
- [ ] elt
- [ ] encode
- [x] encode
- [x] endswith
- [ ] find_in_set
- [ ] format_number
Expand Down
1 change: 1 addition & 0 deletions docs/source/user-guide/latest/expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ of expressions that be disabled.
| Concat |
| ConcatWs |
| Contains |
| Encode |
| EndsWith |
| InitCap |
| Left |
Expand Down
31 changes: 30 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/strings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ trait CommonStringExprs {
binding: Boolean): Option[Expr] = {
charset match {
case Literal(str, DataTypes.StringType)
if str.toString.toLowerCase(Locale.ROOT) == "utf-8" =>
if str != null && str.toString.toLowerCase(Locale.ROOT) == "utf-8" =>
// decode(col, 'utf-8') can be treated as a cast with "try" eval mode that puts nulls
// for invalid strings.
// Left child is the binary expression.
Expand All @@ -495,4 +495,33 @@ trait CommonStringExprs {
None
}
}

/**
 * Serializes Spark's `encode(value, charset)` expression to protobuf.
 *
 * Only the UTF-8 charset is natively supported: Spark stores strings as
 * UTF8String (valid UTF-8 bytes), so `encode(col, 'utf-8')` is
 * byte-equivalent to `cast(col AS binary)` and is serialized as a cast.
 * Any other charset (or a non-literal / null charset) returns None so the
 * expression falls back to Spark's JVM implementation.
 *
 * @param expr    the original Spark expression (tagged with fallback info on failure)
 * @param charset expression producing the charset name; must be a non-null
 *                string literal equal to "utf-8" (case-insensitive)
 * @param value   the string expression whose bytes are produced
 * @param inputs  input attributes used for binding references
 * @param binding whether attribute references should be bound
 * @return the serialized expression, or None to fall back to Spark
 */
def stringEncode(
    expr: Expression,
    charset: Expression,
    value: Expression,
    inputs: Seq[Attribute],
    binding: Boolean): Option[Expr] = {
  charset match {
    case Literal(str, DataTypes.StringType)
        if str != null && str.toString.toLowerCase(Locale.ROOT) == "utf-8" =>
      // encode(col, 'utf-8') is byte-equivalent to cast(string AS binary)
      // because Spark's UTF8String already holds valid UTF-8 bytes.
      exprToProtoInternal(value, inputs, binding) match {
        case Some(valueProto) =>
          CometCast.castToProto(
            expr,
            None,
            DataTypes.BinaryType,
            valueProto,
            CometEvalMode.LEGACY)
        case None =>
          // Child could not be serialized; record the reason and fall back.
          withInfo(expr, value)
          None
      }
    case _ =>
      withInfo(expr, "Comet only supports encoding with 'utf-8'.")
      None
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ trait CometExprShim extends CommonStringExprs {
case s: StringDecode =>
// Right child is the encoding expression.
stringDecode(expr, s.charset, s.bin, inputs, binding)

case e: Encode =>
stringEncode(expr, e.charset, e.value, inputs, binding)
case _ => None
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ trait CometExprShim extends CommonStringExprs {
case s: StringDecode =>
// Right child is the encoding expression.
stringDecode(expr, s.charset, s.bin, inputs, binding)

case e: Encode =>
stringEncode(expr, e.charset, e.value, inputs, binding)
case expr @ ToPrettyString(child, timeZoneId) =>
val castSupported = CometCast.isSupported(
child.dataType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataTypes
import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.expressions.{CometCast, CometEvalMode}
import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible, SupportLevel}
import org.apache.comet.serde.{Compatible, ExprOuterClass, Incompatible, SupportLevel}
import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr}
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, supportedScalarSortElementType}

/**
* `CometExprShim` acts as a shim for parsing expressions from different Spark versions.
*/
trait CometExprShim extends CommonStringExprs {
trait CometExprShim extends ShimCometExprs {
protected def evalMode(c: Cast): CometEvalMode.Value =
CometEvalModeUtil.fromSparkEvalMode(c.evalMode)

Expand Down Expand Up @@ -168,8 +168,7 @@ trait CometExprShim extends CommonStringExprs {
childExpr)
optExprWithInfo(mapSortExpr, ms, ms.child)
}

case _ => None
case _ => sparkExprToProto(expr, inputs, binding)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataTypes
import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.expressions.{CometCast, CometEvalMode}
import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible, SupportLevel}
import org.apache.comet.serde.{Compatible, ExprOuterClass, Incompatible, SupportLevel}
import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr}
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, supportedScalarSortElementType}

/**
* `CometExprShim` acts as a shim for parsing expressions from different Spark versions.
*/
trait CometExprShim extends CommonStringExprs {
trait CometExprShim extends ShimCometExprs {
protected def evalMode(c: Cast): CometEvalMode.Value =
CometEvalModeUtil.fromSparkEvalMode(c.evalMode)

Expand Down Expand Up @@ -92,6 +92,19 @@ trait CometExprShim extends CommonStringExprs {
val Seq(bin, charset, _, _) = s.arguments
stringDecode(expr, charset, bin, inputs, binding)

case s: StaticInvoke
if s.staticObject == classOf[Encode] &&
s.dataType.isInstanceOf[BinaryType] &&
s.functionName == "encode" &&
s.arguments.size == 4 &&
s.inputTypes == Seq(
StringTypeWithCollation(supportsTrimCollation = true),
StringTypeWithCollation(supportsTrimCollation = true),
BooleanType,
BooleanType) =>
val Seq(value, charset, _, _) = s.arguments
stringEncode(expr, charset, value, inputs, binding)

case expr @ ToPrettyString(child, timeZoneId) =>
val castSupported = CometCast.isSupported(
child.dataType,
Expand Down Expand Up @@ -168,7 +181,7 @@ trait CometExprShim extends CommonStringExprs {
optExprWithInfo(mapSortExpr, ms, ms.child)
}

case _ => None
case _ => sparkExprToProto(expr, inputs, binding)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataTypes
import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.expressions.{CometCast, CometEvalMode}
import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible, SupportLevel}
import org.apache.comet.serde.{Compatible, ExprOuterClass, Incompatible, SupportLevel}
import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr}
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, supportedScalarSortElementType}

/**
* `CometExprShim` acts as a shim for parsing expressions from different Spark versions.
*/
trait CometExprShim extends CommonStringExprs {
trait CometExprShim extends ShimCometExprs {
protected def evalMode(c: Cast): CometEvalMode.Value =
CometEvalModeUtil.fromSparkEvalMode(c.evalMode)

Expand Down Expand Up @@ -167,8 +167,7 @@ trait CometExprShim extends CommonStringExprs {
childExpr)
optExprWithInfo(mapSortExpr, ms, ms.child)
}

case _ => None
case _ => sparkExprToProto(expr, inputs, binding)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet.shims

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types.{BinaryType, BooleanType}

import org.apache.comet.serde.CommonStringExprs
import org.apache.comet.serde.ExprOuterClass.Expr

/**
 * Version-specific expression serde shim.
 *
 * Handles Catalyst expressions whose shape differs across Spark versions. In
 * newer Spark versions `encode` is RuntimeReplaceable: the analyzer rewrites it
 * into a `StaticInvoke` of `Encode.encode`, so it must be matched structurally
 * here rather than as an `Encode` expression node.
 */
trait ShimCometExprs extends CommonStringExprs {

  /**
   * Attempts to serialize a version-specific Spark expression to protobuf.
   *
   * @param expr    the candidate Catalyst expression
   * @param inputs  input attributes used for binding references
   * @param binding whether attribute references should be bound
   * @return the serialized expression, or None if this shim does not handle it
   */
  protected def sparkExprToProto(
      expr: Expression,
      inputs: Seq[Attribute],
      binding: Boolean): Option[Expr] = {
    expr match {
      // encode(str, 'utf-8') -> cast(string AS binary) — Arrow's Utf8->Binary
      // is a zero-copy reinterpret, matching Spark's UTF8String.getBytes() exactly.
      // NOTE(review): the guard pins the exact StaticInvoke shape the analyzer
      // produces — target object, return type, method name, 4-argument arity
      // (value, charset, legacy flags), and collation-aware input types — so
      // any other invocation safely falls through to Spark.
      case s: StaticInvoke
        if s.staticObject == classOf[Encode] &&
          s.dataType.isInstanceOf[BinaryType] &&
          s.functionName == "encode" &&
          s.arguments.size == 4 &&
          s.inputTypes == Seq(
            StringTypeWithCollation(supportsTrimCollation = true),
            StringTypeWithCollation(supportsTrimCollation = true),
            BooleanType,
            BooleanType) =>
        // First argument is the string value, second is the charset literal;
        // the trailing two (legacy behavior flags) are ignored.
        val Seq(value, charset, _, _) = s.arguments
        stringEncode(expr, charset, value, inputs, binding)

      case _ => None
    }
  }
}
61 changes: 61 additions & 0 deletions spark/src/test/resources/sql-tests/expressions/string/encode.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing,
-- software distributed under the License is distributed on an
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-- KIND, either express or implied. See the License for the
-- specific language governing permissions and limitations
-- under the License.

-- Tests for the SQL `encode(str, charset)` function.
--
-- Spark 3.x: Encode is a BinaryExpression(value, charset).
-- Spark 4.x+: Encode is RuntimeReplaceable; the analyzer rewrites it to
-- StaticInvoke(classOf[Encode], BinaryType, "encode", ...)

-- Base table covering ASCII, empty string, multi-byte UTF-8, and NULL inputs.
statement
CREATE TABLE test_encode_utf8(s string) USING parquet

statement
INSERT INTO test_encode_utf8 VALUES ('hello'), ('world'), (''), ('café'), (NULL)

-- Lower-case charset literal: the common case, runs natively in Comet.
query
SELECT encode(s, 'utf-8') FROM test_encode_utf8

-- Upper-case charset literal exercises toLowerCase normalization
query
SELECT encode(s, 'UTF-8') FROM test_encode_utf8

-- Mixed-case charset literal exercises toLowerCase normalization
query
SELECT encode(s, 'Utf-8') FROM test_encode_utf8

-- Foldable literal arguments, including an explicit NULL cast.
query
SELECT encode('hello', 'utf-8'), encode('', 'utf-8'), encode(CAST(NULL AS STRING), 'utf-8')

-- Different languages (French, Japanese) cover 2- and 3-byte UTF-8 sequences
query
SELECT encode('café', 'utf-8'), encode('日本語', 'utf-8')

-- Non-UTF-8 charsets are unsupported and must fall back to the Spark JVM
statement
CREATE TABLE test_encode_charset_safe(s string) USING parquet

statement
INSERT INTO test_encode_charset_safe VALUES ('hello'), ('world'), (''), (NULL)

query expect_fallback(Comet only supports encoding with 'utf-8'.)
SELECT encode(s, 'UTF-16BE') FROM test_encode_charset_safe

query expect_fallback(Comet only supports encoding with 'utf-8'.)
SELECT encode(s, 'US-ASCII') FROM test_encode_charset_safe

query expect_fallback(Comet only supports encoding with 'utf-8'.)
SELECT encode(s, 'ISO-8859-1') FROM test_encode_charset_safe