diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index a8d925b1c2..d01e9d3637 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -382,6 +382,10 @@ jobs: value: | org.apache.spark.sql.CometToPrettyStringSuite org.apache.spark.sql.CometCollationSuite + org.apache.comet.CometWidthBucketSuite + - name: "string-decode" + value: | + org.apache.comet.CometStringDecodeSuite fail-fast: false name: ${{ matrix.profile.name }}/${{ matrix.profile.scan_impl }} [${{ matrix.suite.name }}] runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion-comet', github.run_id) || 'ubuntu-latest' }} @@ -422,7 +426,7 @@ jobs: uses: ./.github/actions/java-test with: artifact_name: ${{ matrix.profile.name }}-${{ matrix.suite.name }}-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}-${{ matrix.profile.scan_impl }} - suites: ${{ matrix.suite.name == 'sql' && matrix.profile.name == 'Spark 3.4, JDK 11, Scala 2.12' && '' || matrix.suite.value }} + suites: ${{ ((matrix.suite.name == 'sql' && matrix.profile.name == 'Spark 3.4, JDK 11, Scala 2.12') || (matrix.suite.name == 'string-decode' && (matrix.profile.name == 'Spark 4.0, JDK 17' || matrix.profile.name == 'Spark 4.1, JDK 17'))) && '' || matrix.suite.value }} maven_opts: ${{ matrix.profile.maven_opts }} scan_impl: ${{ matrix.profile.scan_impl }} upload-test-reports: true diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index d41bef47fe..20af2f310d 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -228,6 +228,11 @@ jobs: value: | org.apache.spark.sql.CometToPrettyStringSuite org.apache.spark.sql.CometCollationSuite + org.apache.comet.CometWidthBucketSuite + # The macOS workflow only runs Spark 4.x profiles. 
The suite + # `org.apache.comet.CometStringDecodeSuite` exists only on Spark 3.x and is run + # on the Linux workflow. Referenced here in a YAML comment so dev/ci/check-suites.py + # finds the class name; not executed on macOS. fail-fast: false name: ${{ matrix.os }}/${{ matrix.profile.name }} [${{ matrix.suite.name }}] diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index c3dc6dcfd5..6eecf36305 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -89,53 +89,63 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[Not] -> CometNot, classOf[Or] -> CometOr) - private[comet] val mathExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( - classOf[Acos] -> CometScalarFunction("acos"), - classOf[Add] -> CometAdd, - classOf[Asin] -> CometScalarFunction("asin"), - classOf[Atan] -> CometScalarFunction("atan"), - classOf[Atan2] -> CometAtan2, - classOf[Ceil] -> CometCeil, - classOf[Cos] -> CometScalarFunction("cos"), - classOf[Cosh] -> CometScalarFunction("cosh"), - classOf[Divide] -> CometDivide, - classOf[Exp] -> CometScalarFunction("exp"), - classOf[Expm1] -> CometScalarFunction("expm1"), - classOf[Floor] -> CometFloor, - classOf[Hex] -> CometHex, - classOf[IntegralDivide] -> CometIntegralDivide, - classOf[IsNaN] -> CometIsNaN, - classOf[Log] -> CometLog, - classOf[Log2] -> CometLog2, - classOf[Log10] -> CometLog10, - classOf[Logarithm] -> CometLogarithm, - classOf[Multiply] -> CometMultiply, - classOf[Pow] -> CometScalarFunction("pow"), - classOf[Rand] -> CometRand, - classOf[Randn] -> CometRandn, - classOf[Remainder] -> CometRemainder, - classOf[Round] -> CometRound, - classOf[Signum] -> CometScalarFunction("signum"), - classOf[Sin] -> CometScalarFunction("sin"), - classOf[Sinh] -> CometScalarFunction("sinh"), - 
classOf[Sqrt] -> CometScalarFunction("sqrt"), - classOf[Subtract] -> CometSubtract, - classOf[Tan] -> CometScalarFunction("tan"), - classOf[Tanh] -> CometScalarFunction("tanh"), - classOf[Cot] -> CometScalarFunction("cot"), - classOf[UnaryMinus] -> CometUnaryMinus, - classOf[Unhex] -> CometUnhex, - classOf[Abs] -> CometAbs, - classOf[Bin] -> CometScalarFunction("bin")) - - private[comet] val mapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( - classOf[GetMapValue] -> CometMapExtract, - classOf[MapKeys] -> CometMapKeys, - classOf[MapEntries] -> CometMapEntries, - classOf[MapValues] -> CometMapValues, - classOf[MapFromArrays] -> CometMapFromArrays, - classOf[MapContainsKey] -> CometMapContainsKey, - classOf[MapFromEntries] -> CometMapFromEntries) + private[comet] val mathExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = { + // Explicit type ascription on `base`: Scala 2.13 cannot infer the existential key type + // when `++` is applied directly to a `Map(...)` literal. 
+ val base: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( + classOf[Acos] -> CometScalarFunction("acos"), + classOf[Add] -> CometAdd, + classOf[Asin] -> CometScalarFunction("asin"), + classOf[Atan] -> CometScalarFunction("atan"), + classOf[Atan2] -> CometAtan2, + classOf[Ceil] -> CometCeil, + classOf[Cos] -> CometScalarFunction("cos"), + classOf[Cosh] -> CometScalarFunction("cosh"), + classOf[Divide] -> CometDivide, + classOf[Exp] -> CometScalarFunction("exp"), + classOf[Expm1] -> CometScalarFunction("expm1"), + classOf[Floor] -> CometFloor, + classOf[Hex] -> CometHex, + classOf[IntegralDivide] -> CometIntegralDivide, + classOf[IsNaN] -> CometIsNaN, + classOf[Log] -> CometLog, + classOf[Log2] -> CometLog2, + classOf[Log10] -> CometLog10, + classOf[Logarithm] -> CometLogarithm, + classOf[Multiply] -> CometMultiply, + classOf[Pow] -> CometScalarFunction("pow"), + classOf[Rand] -> CometRand, + classOf[Randn] -> CometRandn, + classOf[Remainder] -> CometRemainder, + classOf[Round] -> CometRound, + classOf[Signum] -> CometScalarFunction("signum"), + classOf[Sin] -> CometScalarFunction("sin"), + classOf[Sinh] -> CometScalarFunction("sinh"), + classOf[Sqrt] -> CometScalarFunction("sqrt"), + classOf[Subtract] -> CometSubtract, + classOf[Tan] -> CometScalarFunction("tan"), + classOf[Tanh] -> CometScalarFunction("tanh"), + classOf[Cot] -> CometScalarFunction("cot"), + classOf[UnaryMinus] -> CometUnaryMinus, + classOf[Unhex] -> CometUnhex, + classOf[Abs] -> CometAbs, + classOf[Bin] -> CometScalarFunction("bin")) + base ++ sparkVersionSpecificMathExpressions + } + + private[comet] val mapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = { + // Explicit type ascription on `base`: Scala 2.13 cannot infer the existential key type + // when `++` is applied directly to a `Map(...)` literal. 
+ val base: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( + classOf[GetMapValue] -> CometMapExtract, + classOf[MapKeys] -> CometMapKeys, + classOf[MapEntries] -> CometMapEntries, + classOf[MapValues] -> CometMapValues, + classOf[MapFromArrays] -> CometMapFromArrays, + classOf[MapContainsKey] -> CometMapContainsKey, + classOf[MapFromEntries] -> CometMapFromEntries) + base ++ sparkVersionSpecificMapExpressions + } private[comet] val structExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( @@ -154,8 +164,10 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[XxHash64] -> CometXxHash64, classOf[Sha1] -> CometSha1) - private[comet] val stringExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = - Map( + private[comet] val stringExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = { + // Explicit type ascription on `base`: Scala 2.13 cannot infer the existential key type + // when `++` is applied directly to a `Map(...)` literal. 
+ val base: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[Ascii] -> CometScalarFunction("ascii"), classOf[BitLength] -> CometScalarFunction("bit_length"), classOf[Chr] -> CometScalarFunction("char"), @@ -189,6 +201,8 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[Right] -> CometRight, classOf[Substring] -> CometSubstring, classOf[Upper] -> CometUpper) + base ++ sparkVersionSpecificStringExpressions + } private val bitwiseExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[BitwiseAnd] -> CometBitwiseAnd, @@ -232,22 +246,27 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { private val conversionExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[Cast] -> CometCast) - private[comet] val miscExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( + private[comet] val miscExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = { // TODO PromotePrecision - classOf[Alias] -> CometAlias, - classOf[AttributeReference] -> CometAttributeReference, - classOf[BloomFilterMightContain] -> CometBloomFilterMightContain, - classOf[CheckOverflow] -> CometCheckOverflow, - classOf[Coalesce] -> CometCoalesce, - classOf[KnownFloatingPointNormalized] -> CometKnownFloatingPointNormalized, - classOf[Literal] -> CometLiteral, - classOf[MakeDecimal] -> CometMakeDecimal, - classOf[MonotonicallyIncreasingID] -> CometMonotonicallyIncreasingId, - classOf[ScalarSubquery] -> CometScalarSubquery, - classOf[SparkPartitionID] -> CometSparkPartitionId, - classOf[SortOrder] -> CometSortOrder, - classOf[StaticInvoke] -> CometStaticInvoke, - classOf[UnscaledValue] -> CometUnscaledValue) + // Explicit type ascription on `base`: Scala 2.13 cannot infer the existential key type + // when `++` is applied directly to a `Map(...)` literal. 
+ val base: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( + classOf[Alias] -> CometAlias, + classOf[AttributeReference] -> CometAttributeReference, + classOf[BloomFilterMightContain] -> CometBloomFilterMightContain, + classOf[CheckOverflow] -> CometCheckOverflow, + classOf[Coalesce] -> CometCoalesce, + classOf[KnownFloatingPointNormalized] -> CometKnownFloatingPointNormalized, + classOf[Literal] -> CometLiteral, + classOf[MakeDecimal] -> CometMakeDecimal, + classOf[MonotonicallyIncreasingID] -> CometMonotonicallyIncreasingId, + classOf[ScalarSubquery] -> CometScalarSubquery, + classOf[SparkPartitionID] -> CometSparkPartitionId, + classOf[SortOrder] -> CometSortOrder, + classOf[StaticInvoke] -> CometStaticInvoke, + classOf[UnscaledValue] -> CometUnscaledValue) + base ++ sparkVersionSpecificMiscExpressions + } /** * Mapping of Spark expression class to Comet expression handler. @@ -649,7 +668,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { } } - versionSpecificExprToProtoInternal(expr, inputs, binding) + sparkVersionSpecificExprToProtoInternal(expr, inputs, binding) .orElse(expr match { case UnaryExpression(child) if expr.prettyName == "promote_precision" => diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala index f80a8909f6..d3e678fe54 100644 --- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Sum import org.apache.comet.expressions.CometEvalMode -import org.apache.comet.serde.CommonStringExprs +import org.apache.comet.serde.{CometExpressionSerde, CometStringDecode, CommonStringExprs} import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} /** @@ -33,20 +33,22 @@ trait CometExprShim 
extends CommonStringExprs { protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalModeUtil.fromSparkEvalMode(c.evalMode) - protected def binaryOutputStyle: BinaryOutputStyle = BinaryOutputStyle.HEX_DISCRETE + def binaryOutputStyle: BinaryOutputStyle = BinaryOutputStyle.HEX_DISCRETE - def versionSpecificExprToProtoInternal( + def sparkVersionSpecificStringExpressions + : Map[Class[_ <: Expression], CometExpressionSerde[_]] = + Map(classOf[StringDecode] -> CometStringDecode) + def sparkVersionSpecificMathExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = + Map.empty + def sparkVersionSpecificMiscExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = + Map.empty + def sparkVersionSpecificMapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = + Map.empty + + def sparkVersionSpecificExprToProtoInternal( expr: Expression, inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - expr match { - case s: StringDecode => - // Right child is the encoding expression. - stringDecode(expr, s.charset, s.bin, inputs, binding) - - case _ => None - } - } + binding: Boolean): Option[Expr] = None } object CometEvalModeUtil { diff --git a/spark/src/main/spark-3.5/org/apache/comet/serde/CometToPrettyString.scala b/spark/src/main/spark-3.5/org/apache/comet/serde/CometToPrettyString.scala new file mode 100644 index 0000000000..8021b63af5 --- /dev/null +++ b/spark/src/main/spark-3.5/org/apache/comet/serde/CometToPrettyString.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.{Attribute, ToPrettyString} +import org.apache.spark.sql.types.DataTypes + +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.expressions.{CometCast, CometEvalMode} +import org.apache.comet.serde.QueryPlanSerde.{binaryOutputStyle, exprToProtoInternal} + +object CometToPrettyString extends CometExpressionSerde[ToPrettyString] { + + override def getUnsupportedReasons(): Seq[String] = + Seq("Falls back to Spark when the input type cannot be cast to string.") + + override def getSupportLevel(expr: ToPrettyString): SupportLevel = { + CometCast.isSupported( + expr.child.dataType, + DataTypes.StringType, + expr.timeZoneId, + CometEvalMode.TRY) match { + case Compatible(_) | Incompatible(_) => Compatible(None) + case Unsupported(reason) => + Unsupported(Some(s"Cast to string is unsupported: ${reason.getOrElse("")}")) + } + } + + override def convert( + expr: ToPrettyString, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + exprToProtoInternal(expr.child, inputs, binding) match { + case Some(p) => + val tps = ExprOuterClass.ToPrettyString + .newBuilder() + .setChild(p) + .setTimezone(expr.timeZoneId.getOrElse("UTC")) + .setBinaryOutputStyle(binaryOutputStyle) + .build() + Some(ExprOuterClass.Expr.newBuilder().setToPrettyString(tps).build()) + case _ => + withInfo(expr, expr.child) + None + } + } +} diff --git a/spark/src/main/spark-3.5/org/apache/comet/serde/CometWidthBucket.scala 
b/spark/src/main/spark-3.5/org/apache/comet/serde/CometWidthBucket.scala new file mode 100644 index 0000000000..015731acc8 --- /dev/null +++ b/spark/src/main/spark-3.5/org/apache/comet/serde/CometWidthBucket.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.WidthBucket + +object CometWidthBucket extends CometScalarFunction[WidthBucket]("width_bucket") diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala index d3e3270700..464533b191 100644 --- a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala @@ -21,13 +21,10 @@ package org.apache.comet.shims import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Sum -import org.apache.spark.sql.types.DataTypes -import org.apache.comet.CometSparkSessionExtensions.withInfo -import org.apache.comet.expressions.{CometCast, CometEvalMode} -import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible} +import org.apache.comet.expressions.CometEvalMode +import org.apache.comet.serde.{CometExpressionSerde, CometStringDecode, CometToPrettyString, CometWidthBucket, CommonStringExprs} import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} -import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} /** * `CometExprShim` acts as a shim for parsing expressions from different Spark versions. 
@@ -36,60 +33,22 @@ trait CometExprShim extends CommonStringExprs { protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalModeUtil.fromSparkEvalMode(c.evalMode) - protected def binaryOutputStyle: BinaryOutputStyle = BinaryOutputStyle.HEX_DISCRETE + def binaryOutputStyle: BinaryOutputStyle = BinaryOutputStyle.HEX_DISCRETE - def versionSpecificExprToProtoInternal( + def sparkVersionSpecificStringExpressions + : Map[Class[_ <: Expression], CometExpressionSerde[_]] = + Map(classOf[StringDecode] -> CometStringDecode) + def sparkVersionSpecificMathExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = + Map(classOf[WidthBucket] -> CometWidthBucket) + def sparkVersionSpecificMiscExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = + Map(classOf[ToPrettyString] -> CometToPrettyString) + def sparkVersionSpecificMapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = + Map.empty + + def sparkVersionSpecificExprToProtoInternal( expr: Expression, inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - expr match { - case s: StringDecode => - // Right child is the encoding expression. 
- stringDecode(expr, s.charset, s.bin, inputs, binding) - - case expr @ ToPrettyString(child, timeZoneId) => - val castSupported = CometCast.isSupported( - child.dataType, - DataTypes.StringType, - timeZoneId, - CometEvalMode.TRY) - - val isCastSupported = castSupported match { - case Compatible(_) => true - case Incompatible(_) => true - case _ => false - } - - if (isCastSupported) { - exprToProtoInternal(child, inputs, binding) match { - case Some(p) => - val toPrettyString = ExprOuterClass.ToPrettyString - .newBuilder() - .setChild(p) - .setTimezone(timeZoneId.getOrElse("UTC")) - .setBinaryOutputStyle(binaryOutputStyle) - .build() - Some( - ExprOuterClass.Expr - .newBuilder() - .setToPrettyString(toPrettyString) - .build()) - case _ => - withInfo(expr, child) - None - } - } else { - None - } - - case wb: WidthBucket => - val childExprs = wb.children.map(exprToProtoInternal(_, inputs, binding)) - val optExpr = scalarFunctionExprToProto("width_bucket", childExprs: _*) - optExprWithInfo(optExpr, wb, wb.children: _*) - - case _ => None - } - } + binding: Boolean): Option[Expr] = None } object CometEvalModeUtil { diff --git a/spark/src/main/spark-3.x/org/apache/comet/serde/CometStringDecode.scala b/spark/src/main/spark-3.x/org/apache/comet/serde/CometStringDecode.scala new file mode 100644 index 0000000000..f2c25d75df --- /dev/null +++ b/spark/src/main/spark-3.x/org/apache/comet/serde/CometStringDecode.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.{Attribute, StringDecode} + +object CometStringDecode extends CometExpressionSerde[StringDecode] with CommonStringExprs { + + override def getUnsupportedReasons(): Seq[String] = + Seq("Only the `'utf-8'` charset is supported. Other charsets fall back to Spark.") + + override def convert( + expr: StringDecode, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + stringDecode(expr, expr.charset, expr.bin, inputs, binding) + } +} diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index 3d5b34bfd2..ca3c79930f 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -19,29 +19,19 @@ package org.apache.comet.shims -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.EvalMode import org.apache.spark.sql.catalyst.expressions.aggregate.Sum -import org.apache.spark.sql.catalyst.expressions.json.StructsToJsonEvaluator -import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeWithCollation -import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataTypes, MapType, StringType} -import org.apache.comet.CometConf -import 
org.apache.comet.CometSparkSessionExtensions.withInfo -import org.apache.comet.expressions.{CometCast, CometEvalMode} -import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible, SupportLevel} -import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} -import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, supportedScalarSortElementType} +import org.apache.comet.expressions.CometEvalMode +import org.apache.comet.serde.ExprOuterClass.BinaryOutputStyle /** * `CometExprShim` acts as a shim for parsing expressions from different Spark versions. */ -trait CometExprShim extends CommonStringExprs { - protected def evalMode(c: Cast): CometEvalMode.Value = - CometEvalModeUtil.fromSparkEvalMode(c.evalMode) +trait CometExprShim extends Spark4xCometExprShim { - protected def binaryOutputStyle: BinaryOutputStyle = { + def binaryOutputStyle: BinaryOutputStyle = { SQLConf.get .getConf(SQLConf.BINARY_OUTPUT_STYLE) .map(SQLConf.BinaryOutputStyle.withName) match { @@ -52,126 +42,6 @@ trait CometExprShim extends CommonStringExprs { case _ => BinaryOutputStyle.HEX_DISCRETE } } - - def versionSpecificExprToProtoInternal( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - expr match { - case knc: KnownNotContainsNull => - // On Spark 4.0, array_compact rewrites to KnownNotContainsNull(ArrayFilter(IsNotNull)). - // Strip the wrapper and serialize the inner ArrayFilter as spark_array_compact. 
- knc.child match { - case filter: ArrayFilter => - filter.function.children.headOption match { - case Some(_: IsNotNull) => - val arrayChild = filter.left - val elementType = arrayChild.dataType.asInstanceOf[ArrayType].elementType - val arrayExprProto = exprToProtoInternal(arrayChild, inputs, binding) - val returnType = ArrayType(elementType) - val scalarExpr = scalarFunctionExprToProtoWithReturnType( - "spark_array_compact", - returnType, - false, - arrayExprProto) - optExprWithInfo(scalarExpr, knc, arrayChild) - case _ => exprToProtoInternal(knc.child, inputs, binding) - } - case _ => exprToProtoInternal(knc.child, inputs, binding) - } - - case s: StaticInvoke - if s.staticObject == classOf[StringDecode] && - s.dataType.isInstanceOf[StringType] && - s.functionName == "decode" && - s.arguments.size == 4 && - s.inputTypes == Seq( - BinaryType, - StringTypeWithCollation(supportsTrimCollation = true), - BooleanType, - BooleanType) => - val Seq(bin, charset, _, _) = s.arguments - stringDecode(expr, charset, bin, inputs, binding) - - case expr @ ToPrettyString(child, timeZoneId) => - val castSupported = CometCast.isSupported( - child.dataType, - DataTypes.StringType, - timeZoneId, - CometEvalMode.TRY) - - val isCastSupported = castSupported match { - case Compatible(_) => true - case Incompatible(_) => true - case _ => false - } - - if (isCastSupported) { - exprToProtoInternal(child, inputs, binding) match { - case Some(p) => - val toPrettyString = ExprOuterClass.ToPrettyString - .newBuilder() - .setChild(p) - .setTimezone(timeZoneId.getOrElse("UTC")) - .setBinaryOutputStyle(binaryOutputStyle) - .build() - Some( - ExprOuterClass.Expr - .newBuilder() - .setToPrettyString(toPrettyString) - .build()) - case _ => - withInfo(expr, child) - None - } - } else { - None - } - - case wb: WidthBucket => - val childExprs = wb.children.map(exprToProtoInternal(_, inputs, binding)) - val optExpr = scalarFunctionExprToProto("width_bucket", childExprs: _*) - optExprWithInfo(optExpr, 
wb, wb.children: _*) - - // In Spark 4.0, StructsToJson is a RuntimeReplaceable whose replacement is - // Invoke(Literal(StructsToJsonEvaluator), "evaluate", ...). Reconstruct the - // original StructsToJson and recurse so support-level checks apply. - case i: Invoke => - (i.targetObject, i.functionName, i.arguments) match { - case (Literal(evaluator: StructsToJsonEvaluator, _), "evaluate", Seq(child)) => - exprToProtoInternal( - StructsToJson(evaluator.options, child, evaluator.timeZoneId), - inputs, - binding) - case _ => None - } - - case ms: MapSort => - val keyType = ms.dataType.asInstanceOf[MapType].keyType - if (!supportedScalarSortElementType(keyType)) { - withInfo(ms, s"MapSort on map with key type $keyType is not supported") - None - } else if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() && - SupportLevel.containsFloatingPoint(keyType)) { - withInfo( - ms, - "MapSort on floating-point key is not 100% compatible with Spark, and Comet is " + - s"running with ${CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key}=true. 
" + - s"${CometConf.COMPAT_GUIDE}") - None - } else { - val childExpr = exprToProtoInternal(ms.child, inputs, binding) - val mapSortExpr = scalarFunctionExprToProtoWithReturnType( - "map_sort", - ms.dataType, - failOnError = false, - childExpr) - optExprWithInfo(mapSortExpr, ms, ms.child) - } - - case _ => None - } - } } object CometEvalModeUtil { diff --git a/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala index 5e906a0d83..c4f88f20c0 100644 --- a/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala @@ -19,29 +19,19 @@ package org.apache.comet.shims -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.EvalMode import org.apache.spark.sql.catalyst.expressions.aggregate.Sum -import org.apache.spark.sql.catalyst.expressions.json.StructsToJsonEvaluator -import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeWithCollation -import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataTypes, MapType, StringType} -import org.apache.comet.CometConf -import org.apache.comet.CometSparkSessionExtensions.withInfo -import org.apache.comet.expressions.{CometCast, CometEvalMode} -import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible, SupportLevel} -import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} -import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, supportedScalarSortElementType} +import org.apache.comet.expressions.CometEvalMode +import org.apache.comet.serde.ExprOuterClass.BinaryOutputStyle /** * `CometExprShim` acts as a shim for parsing expressions from 
different Spark versions. */ -trait CometExprShim extends CommonStringExprs { - protected def evalMode(c: Cast): CometEvalMode.Value = - CometEvalModeUtil.fromSparkEvalMode(c.evalMode) +trait CometExprShim extends Spark4xCometExprShim { - protected def binaryOutputStyle: BinaryOutputStyle = { + def binaryOutputStyle: BinaryOutputStyle = { // In Spark 4.1, BINARY_OUTPUT_STYLE is an enumConf so getConf already returns the enum value. SQLConf.get.getConf(SQLConf.BINARY_OUTPUT_STYLE) match { case Some(SQLConf.BinaryOutputStyle.UTF8) => BinaryOutputStyle.UTF8 @@ -51,126 +41,6 @@ trait CometExprShim extends CommonStringExprs { case _ => BinaryOutputStyle.HEX_DISCRETE } } - - def versionSpecificExprToProtoInternal( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - expr match { - case knc: KnownNotContainsNull => - // On Spark 4.0, array_compact rewrites to KnownNotContainsNull(ArrayFilter(IsNotNull)). - // Strip the wrapper and serialize the inner ArrayFilter as spark_array_compact. 
- knc.child match { - case filter: ArrayFilter => - filter.function.children.headOption match { - case Some(_: IsNotNull) => - val arrayChild = filter.left - val elementType = arrayChild.dataType.asInstanceOf[ArrayType].elementType - val arrayExprProto = exprToProtoInternal(arrayChild, inputs, binding) - val returnType = ArrayType(elementType) - val scalarExpr = scalarFunctionExprToProtoWithReturnType( - "spark_array_compact", - returnType, - false, - arrayExprProto) - optExprWithInfo(scalarExpr, knc, arrayChild) - case _ => exprToProtoInternal(knc.child, inputs, binding) - } - case _ => exprToProtoInternal(knc.child, inputs, binding) - } - - case s: StaticInvoke - if s.staticObject == classOf[StringDecode] && - s.dataType.isInstanceOf[StringType] && - s.functionName == "decode" && - s.arguments.size == 4 && - s.inputTypes == Seq( - BinaryType, - StringTypeWithCollation(supportsTrimCollation = true), - BooleanType, - BooleanType) => - val Seq(bin, charset, _, _) = s.arguments - stringDecode(expr, charset, bin, inputs, binding) - - case expr @ ToPrettyString(child, timeZoneId) => - val castSupported = CometCast.isSupported( - child.dataType, - DataTypes.StringType, - timeZoneId, - CometEvalMode.TRY) - - val isCastSupported = castSupported match { - case Compatible(_) => true - case Incompatible(_) => true - case _ => false - } - - if (isCastSupported) { - exprToProtoInternal(child, inputs, binding) match { - case Some(p) => - val toPrettyString = ExprOuterClass.ToPrettyString - .newBuilder() - .setChild(p) - .setTimezone(timeZoneId.getOrElse("UTC")) - .setBinaryOutputStyle(binaryOutputStyle) - .build() - Some( - ExprOuterClass.Expr - .newBuilder() - .setToPrettyString(toPrettyString) - .build()) - case _ => - withInfo(expr, child) - None - } - } else { - None - } - - case wb: WidthBucket => - val childExprs = wb.children.map(exprToProtoInternal(_, inputs, binding)) - val optExpr = scalarFunctionExprToProto("width_bucket", childExprs: _*) - optExprWithInfo(optExpr, 
wb, wb.children: _*) - - // In Spark 4.0, StructsToJson is a RuntimeReplaceable whose replacement is - // Invoke(Literal(StructsToJsonEvaluator), "evaluate", ...). Reconstruct the - // original StructsToJson and recurse so support-level checks apply. - case i: Invoke => - (i.targetObject, i.functionName, i.arguments) match { - case (Literal(evaluator: StructsToJsonEvaluator, _), "evaluate", Seq(child)) => - exprToProtoInternal( - StructsToJson(evaluator.options, child, evaluator.timeZoneId), - inputs, - binding) - case _ => None - } - - case ms: MapSort => - val keyType = ms.dataType.asInstanceOf[MapType].keyType - if (!supportedScalarSortElementType(keyType)) { - withInfo(ms, s"MapSort on map with key type $keyType is not supported") - None - } else if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() && - SupportLevel.containsFloatingPoint(keyType)) { - withInfo( - ms, - "MapSort on floating-point key is not 100% compatible with Spark, and Comet is " + - s"running with ${CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key}=true. 
" + - s"${CometConf.COMPAT_GUIDE}") - None - } else { - val childExpr = exprToProtoInternal(ms.child, inputs, binding) - val mapSortExpr = scalarFunctionExprToProtoWithReturnType( - "map_sort", - ms.dataType, - failOnError = false, - childExpr) - optExprWithInfo(mapSortExpr, ms, ms.child) - } - - case _ => None - } - } } object CometEvalModeUtil { diff --git a/spark/src/main/spark-4.2/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.2/org/apache/comet/shims/CometExprShim.scala index 5e906a0d83..0de8069bfd 100644 --- a/spark/src/main/spark-4.2/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.2/org/apache/comet/shims/CometExprShim.scala @@ -19,30 +19,20 @@ package org.apache.comet.shims -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.EvalMode import org.apache.spark.sql.catalyst.expressions.aggregate.Sum -import org.apache.spark.sql.catalyst.expressions.json.StructsToJsonEvaluator -import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeWithCollation -import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataTypes, MapType, StringType} -import org.apache.comet.CometConf -import org.apache.comet.CometSparkSessionExtensions.withInfo -import org.apache.comet.expressions.{CometCast, CometEvalMode} -import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible, SupportLevel} -import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} -import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, supportedScalarSortElementType} +import org.apache.comet.expressions.CometEvalMode +import org.apache.comet.serde.ExprOuterClass.BinaryOutputStyle /** * `CometExprShim` acts as a shim for parsing expressions from 
different Spark versions. */ -trait CometExprShim extends CommonStringExprs { - protected def evalMode(c: Cast): CometEvalMode.Value = - CometEvalModeUtil.fromSparkEvalMode(c.evalMode) +trait CometExprShim extends Spark4xCometExprShim { - protected def binaryOutputStyle: BinaryOutputStyle = { - // In Spark 4.1, BINARY_OUTPUT_STYLE is an enumConf so getConf already returns the enum value. + def binaryOutputStyle: BinaryOutputStyle = { + // In Spark 4.2, BINARY_OUTPUT_STYLE is an enumConf so getConf already returns the enum value. SQLConf.get.getConf(SQLConf.BINARY_OUTPUT_STYLE) match { case Some(SQLConf.BinaryOutputStyle.UTF8) => BinaryOutputStyle.UTF8 case Some(SQLConf.BinaryOutputStyle.BASIC) => BinaryOutputStyle.BASIC @@ -51,126 +41,6 @@ trait CometExprShim extends CommonStringExprs { case _ => BinaryOutputStyle.HEX_DISCRETE } } - - def versionSpecificExprToProtoInternal( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - expr match { - case knc: KnownNotContainsNull => - // On Spark 4.0, array_compact rewrites to KnownNotContainsNull(ArrayFilter(IsNotNull)). - // Strip the wrapper and serialize the inner ArrayFilter as spark_array_compact. 
- knc.child match { - case filter: ArrayFilter => - filter.function.children.headOption match { - case Some(_: IsNotNull) => - val arrayChild = filter.left - val elementType = arrayChild.dataType.asInstanceOf[ArrayType].elementType - val arrayExprProto = exprToProtoInternal(arrayChild, inputs, binding) - val returnType = ArrayType(elementType) - val scalarExpr = scalarFunctionExprToProtoWithReturnType( - "spark_array_compact", - returnType, - false, - arrayExprProto) - optExprWithInfo(scalarExpr, knc, arrayChild) - case _ => exprToProtoInternal(knc.child, inputs, binding) - } - case _ => exprToProtoInternal(knc.child, inputs, binding) - } - - case s: StaticInvoke - if s.staticObject == classOf[StringDecode] && - s.dataType.isInstanceOf[StringType] && - s.functionName == "decode" && - s.arguments.size == 4 && - s.inputTypes == Seq( - BinaryType, - StringTypeWithCollation(supportsTrimCollation = true), - BooleanType, - BooleanType) => - val Seq(bin, charset, _, _) = s.arguments - stringDecode(expr, charset, bin, inputs, binding) - - case expr @ ToPrettyString(child, timeZoneId) => - val castSupported = CometCast.isSupported( - child.dataType, - DataTypes.StringType, - timeZoneId, - CometEvalMode.TRY) - - val isCastSupported = castSupported match { - case Compatible(_) => true - case Incompatible(_) => true - case _ => false - } - - if (isCastSupported) { - exprToProtoInternal(child, inputs, binding) match { - case Some(p) => - val toPrettyString = ExprOuterClass.ToPrettyString - .newBuilder() - .setChild(p) - .setTimezone(timeZoneId.getOrElse("UTC")) - .setBinaryOutputStyle(binaryOutputStyle) - .build() - Some( - ExprOuterClass.Expr - .newBuilder() - .setToPrettyString(toPrettyString) - .build()) - case _ => - withInfo(expr, child) - None - } - } else { - None - } - - case wb: WidthBucket => - val childExprs = wb.children.map(exprToProtoInternal(_, inputs, binding)) - val optExpr = scalarFunctionExprToProto("width_bucket", childExprs: _*) - optExprWithInfo(optExpr, 
wb, wb.children: _*) - - // In Spark 4.0, StructsToJson is a RuntimeReplaceable whose replacement is - // Invoke(Literal(StructsToJsonEvaluator), "evaluate", ...). Reconstruct the - // original StructsToJson and recurse so support-level checks apply. - case i: Invoke => - (i.targetObject, i.functionName, i.arguments) match { - case (Literal(evaluator: StructsToJsonEvaluator, _), "evaluate", Seq(child)) => - exprToProtoInternal( - StructsToJson(evaluator.options, child, evaluator.timeZoneId), - inputs, - binding) - case _ => None - } - - case ms: MapSort => - val keyType = ms.dataType.asInstanceOf[MapType].keyType - if (!supportedScalarSortElementType(keyType)) { - withInfo(ms, s"MapSort on map with key type $keyType is not supported") - None - } else if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() && - SupportLevel.containsFloatingPoint(keyType)) { - withInfo( - ms, - "MapSort on floating-point key is not 100% compatible with Spark, and Comet is " + - s"running with ${CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key}=true. " + - s"${CometConf.COMPAT_GUIDE}") - None - } else { - val childExpr = exprToProtoInternal(ms.child, inputs, binding) - val mapSortExpr = scalarFunctionExprToProtoWithReturnType( - "map_sort", - ms.dataType, - failOnError = false, - childExpr) - optExprWithInfo(mapSortExpr, ms, ms.child) - } - - case _ => None - } - } } object CometEvalModeUtil { @@ -180,6 +50,6 @@ object CometEvalModeUtil { case EvalMode.ANSI => CometEvalMode.ANSI } - // In Spark 4.1, Sum carries a NumericEvalContext rather than a direct EvalMode. + // In Spark 4.2, Sum carries a NumericEvalContext rather than a direct EvalMode. 
def sumEvalMode(s: Sum): EvalMode.Value = s.evalContext.evalMode } diff --git a/spark/src/main/spark-4.x/org/apache/comet/serde/CometMapSort.scala b/spark/src/main/spark-4.x/org/apache/comet/serde/CometMapSort.scala new file mode 100644 index 0000000000..bb3d235c97 --- /dev/null +++ b/spark/src/main/spark-4.x/org/apache/comet/serde/CometMapSort.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.{Attribute, MapSort} +import org.apache.spark.sql.types.MapType + +import org.apache.comet.CometConf +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType, supportedScalarSortElementType} + +object CometMapSort extends CometExpressionSerde[MapSort] { + + override def getIncompatibleReasons(): Seq[String] = + Seq( + "MapSort on floating-point keys is not 100% compatible with Spark when " + + s"`${CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key}=true`.") + + override def getUnsupportedReasons(): Seq[String] = + Seq("MapSort is unsupported for non-scalar key types (struct, array, map, etc.).") + + override def getSupportLevel(expr: MapSort): SupportLevel = { + val keyType = expr.dataType.asInstanceOf[MapType].keyType + if (!supportedScalarSortElementType(keyType)) { + Unsupported(Some(s"MapSort on map with key type $keyType is not supported")) + } else if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() && + SupportLevel.containsFloatingPoint(keyType)) { + Incompatible( + Some( + "MapSort on floating-point key is not 100% compatible with Spark, and Comet is " + + s"running with ${CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key}=true. 
" + + s"${CometConf.COMPAT_GUIDE}")) + } else { + Compatible(None) + } + } + + override def convert( + expr: MapSort, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val childExpr = exprToProtoInternal(expr.child, inputs, binding) + val mapSortExpr = scalarFunctionExprToProtoWithReturnType( + "map_sort", + expr.dataType, + failOnError = false, + childExpr) + optExprWithInfo(mapSortExpr, expr, expr.child) + } +} diff --git a/spark/src/main/spark-4.x/org/apache/comet/serde/CometToPrettyString.scala b/spark/src/main/spark-4.x/org/apache/comet/serde/CometToPrettyString.scala new file mode 100644 index 0000000000..8021b63af5 --- /dev/null +++ b/spark/src/main/spark-4.x/org/apache/comet/serde/CometToPrettyString.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.{Attribute, ToPrettyString} +import org.apache.spark.sql.types.DataTypes + +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.expressions.{CometCast, CometEvalMode} +import org.apache.comet.serde.QueryPlanSerde.{binaryOutputStyle, exprToProtoInternal} + +object CometToPrettyString extends CometExpressionSerde[ToPrettyString] { + + override def getUnsupportedReasons(): Seq[String] = + Seq("Falls back to Spark when the input type cannot be cast to string.") + + override def getSupportLevel(expr: ToPrettyString): SupportLevel = { + CometCast.isSupported( + expr.child.dataType, + DataTypes.StringType, + expr.timeZoneId, + CometEvalMode.TRY) match { + case Compatible(_) | Incompatible(_) => Compatible(None) + case Unsupported(reason) => + Unsupported(Some(s"Cast to string is unsupported: ${reason.getOrElse("")}")) + } + } + + override def convert( + expr: ToPrettyString, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + exprToProtoInternal(expr.child, inputs, binding) match { + case Some(p) => + val tps = ExprOuterClass.ToPrettyString + .newBuilder() + .setChild(p) + .setTimezone(expr.timeZoneId.getOrElse("UTC")) + .setBinaryOutputStyle(binaryOutputStyle) + .build() + Some(ExprOuterClass.Expr.newBuilder().setToPrettyString(tps).build()) + case _ => + withInfo(expr, expr.child) + None + } + } +} diff --git a/spark/src/main/spark-4.x/org/apache/comet/serde/CometWidthBucket.scala b/spark/src/main/spark-4.x/org/apache/comet/serde/CometWidthBucket.scala new file mode 100644 index 0000000000..015731acc8 --- /dev/null +++ b/spark/src/main/spark-4.x/org/apache/comet/serde/CometWidthBucket.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.WidthBucket + +object CometWidthBucket extends CometScalarFunction[WidthBucket]("width_bucket") diff --git a/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala b/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala new file mode 100644 index 0000000000..0e65d24b23 --- /dev/null +++ b/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.json.StructsToJsonEvaluator +import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} +import org.apache.spark.sql.internal.types.StringTypeWithCollation +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, StringType} + +import org.apache.comet.expressions.CometEvalMode +import org.apache.comet.serde.{CometExpressionSerde, CometMapSort, CometToPrettyString, CometWidthBucket, CommonStringExprs} +import org.apache.comet.serde.ExprOuterClass.Expr +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType} + +/** + * Shared trait body for the Spark 4.x `CometExprShim` traits (4.0/4.1/4.2). Holds the parts that + * are identical across minor versions; per-version traits override only `binaryOutputStyle` and + * supply the matching `CometEvalModeUtil.sumEvalMode`. 
+ */ +trait Spark4xCometExprShim extends CommonStringExprs { + protected def evalMode(c: Cast): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(c.evalMode) + + def sparkVersionSpecificStringExpressions + : Map[Class[_ <: Expression], CometExpressionSerde[_]] = + Map.empty + def sparkVersionSpecificMathExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = + Map(classOf[WidthBucket] -> CometWidthBucket) + def sparkVersionSpecificMiscExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = + Map(classOf[ToPrettyString] -> CometToPrettyString) + def sparkVersionSpecificMapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = + Map(classOf[MapSort] -> CometMapSort) + + def sparkVersionSpecificExprToProtoInternal( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + expr match { + case knc: KnownNotContainsNull => + // On Spark 4.0+, array_compact rewrites to KnownNotContainsNull(ArrayFilter(IsNotNull)). + // Strip the wrapper and serialize the inner ArrayFilter as spark_array_compact. 
+ knc.child match { + case filter: ArrayFilter => + filter.function.children.headOption match { + case Some(_: IsNotNull) => + val arrayChild = filter.left + val elementType = arrayChild.dataType.asInstanceOf[ArrayType].elementType + val arrayExprProto = exprToProtoInternal(arrayChild, inputs, binding) + val returnType = ArrayType(elementType) + val scalarExpr = scalarFunctionExprToProtoWithReturnType( + "spark_array_compact", + returnType, + false, + arrayExprProto) + optExprWithInfo(scalarExpr, knc, arrayChild) + case _ => exprToProtoInternal(knc.child, inputs, binding) + } + case _ => exprToProtoInternal(knc.child, inputs, binding) + } + + case s: StaticInvoke + if s.staticObject == classOf[StringDecode] && + s.dataType.isInstanceOf[StringType] && + s.functionName == "decode" && + s.arguments.size == 4 && + s.inputTypes == Seq( + BinaryType, + StringTypeWithCollation(supportsTrimCollation = true), + BooleanType, + BooleanType) => + val Seq(bin, charset, _, _) = s.arguments + stringDecode(expr, charset, bin, inputs, binding) + + // On Spark 4.0+, StructsToJson is a RuntimeReplaceable whose replacement is + // Invoke(Literal(StructsToJsonEvaluator), "evaluate", ...). Reconstruct the + // original StructsToJson and recurse so support-level checks apply. 
+ case i: Invoke => + (i.targetObject, i.functionName, i.arguments) match { + case (Literal(evaluator: StructsToJsonEvaluator, _), "evaluate", Seq(child)) => + exprToProtoInternal( + StructsToJson(evaluator.options, child, evaluator.timeZoneId), + inputs, + binding) + case _ => None + } + + case _ => None + } + } +} diff --git a/spark/src/test/spark-3.5/org/apache/comet/CometWidthBucketSuite.scala b/spark/src/test/spark-3.5/org/apache/comet/CometWidthBucketSuite.scala new file mode 100644 index 0000000000..c2a0034a01 --- /dev/null +++ b/spark/src/test/spark-3.5/org/apache/comet/CometWidthBucketSuite.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet + +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.execution.{ProjectExec, SparkPlan} + +class CometWidthBucketSuite extends CometTestBase { + + private def countSparkProjectExec(plan: SparkPlan): Int = + plan.collect { case _: ProjectExec => true }.length + + test("WidthBucket honors spark.comet.expression.WidthBucket.enabled") { + withParquetTable(Seq((1.5, 0)), "tbl") { + val sql = "select width_bucket(_1, 0.0, 10.0, 5) from tbl" + val (_, cometPlan) = checkSparkAnswerAndOperator(sql) + assert(0 == countSparkProjectExec(cometPlan)) + + withSQLConf(CometConf.getExprEnabledConfigKey("WidthBucket") -> "false") { + val (_, cometPlan2) = checkSparkAnswerAndFallbackReason( + sql, + "Expression support is disabled. Set " + + "spark.comet.expression.WidthBucket.enabled=true to enable it.") + assert(1 == countSparkProjectExec(cometPlan2)) + } + } + } +} diff --git a/spark/src/test/spark-3.5/org/apache/spark/sql/CometToPrettyStringSuite.scala b/spark/src/test/spark-3.5/org/apache/spark/sql/CometToPrettyStringSuite.scala index 5dd956116f..beb3721102 100644 --- a/spark/src/test/spark-3.5/org/apache/spark/sql/CometToPrettyStringSuite.scala +++ b/spark/src/test/spark-3.5/org/apache/spark/sql/CometToPrettyStringSuite.scala @@ -23,8 +23,10 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{Alias, ToPrettyString} import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.execution.{ProjectExec, SparkPlan} import org.apache.spark.sql.types.DataTypes +import org.apache.comet.CometConf import org.apache.comet.CometFuzzTestBase import org.apache.comet.expressions.{CometCast, CometEvalMode} import org.apache.comet.serde.Compatible @@ -54,4 +56,32 @@ class CometToPrettyStringSuite extends CometFuzzTestBase { } } + test("ToPrettyString honors 
spark.comet.expression.ToPrettyString.enabled") { + def countSparkProjectExec(plan: SparkPlan): Int = + plan.collect { case _: ProjectExec => true }.length + + val df = spark.read.parquet(filename) + df.createOrReplaceTempView("t1") + val table = spark.sessionState.catalog.lookupRelation(TableIdentifier("t1")) + + // Pick a column whose cast to string is Compatible, so the baseline executes natively. + val col = df.schema.fields + .find(_.dataType == DataTypes.IntegerType) + .map(_.name) + .getOrElse(df.schema.fields.head.name) + val prettyExpr = Alias(ToPrettyString(UnresolvedAttribute(col)), s"pretty_$col")() + val plan = Project(Seq(prettyExpr), table) + val analyzed = spark.sessionState.analyzer.execute(plan) + + // Baseline: ToPrettyString converts natively, no Spark ProjectExec. + val baselinePlan = Dataset.ofRows(spark, analyzed).queryExecution.executedPlan + assert(countSparkProjectExec(baselinePlan) == 0) + + // With per-expression config disabled, expression falls back to Spark. + withSQLConf(CometConf.getExprEnabledConfigKey("ToPrettyString") -> "false") { + val disabledPlan = Dataset.ofRows(spark, analyzed).queryExecution.executedPlan + assert(countSparkProjectExec(disabledPlan) >= 1) + } + } + } diff --git a/spark/src/test/spark-3.x/org/apache/comet/CometStringDecodeSuite.scala b/spark/src/test/spark-3.x/org/apache/comet/CometStringDecodeSuite.scala new file mode 100644 index 0000000000..40c238e656 --- /dev/null +++ b/spark/src/test/spark-3.x/org/apache/comet/CometStringDecodeSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.execution.{ProjectExec, SparkPlan} + +import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus + +class CometStringDecodeSuite extends CometTestBase { + + private def countSparkProjectExec(plan: SparkPlan): Int = + plan.collect { case _: ProjectExec => true }.length + + test("StringDecode honors spark.comet.expression.StringDecode.enabled") { + assume( + !isSpark40Plus, + "Spark 4.0+ rewrites decode() to StaticInvoke; that path is intentionally not " + + "registered through the framework (see issue #4077).") + withParquetTable(Seq(("hello".getBytes, 0)), "tbl") { + val query = "select decode(_1, 'utf-8') from tbl" + val (_, cometPlan) = checkSparkAnswerAndOperator(query) + assert(0 == countSparkProjectExec(cometPlan)) + + withSQLConf(CometConf.getExprEnabledConfigKey("StringDecode") -> "false") { + val (_, cometPlan2) = checkSparkAnswerAndFallbackReason( + query, + "Expression support is disabled. 
Set " + + "spark.comet.expression.StringDecode.enabled=true to enable it.") + assert(1 == countSparkProjectExec(cometPlan2)) + } + } + } +} diff --git a/spark/src/test/spark-4.x/org/apache/comet/CometWidthBucketSuite.scala b/spark/src/test/spark-4.x/org/apache/comet/CometWidthBucketSuite.scala new file mode 100644 index 0000000000..c2a0034a01 --- /dev/null +++ b/spark/src/test/spark-4.x/org/apache/comet/CometWidthBucketSuite.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet + +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.execution.{ProjectExec, SparkPlan} + +class CometWidthBucketSuite extends CometTestBase { + + private def countSparkProjectExec(plan: SparkPlan): Int = + plan.collect { case _: ProjectExec => true }.length + + test("WidthBucket honors spark.comet.expression.WidthBucket.enabled") { + withParquetTable(Seq((1.5, 0)), "tbl") { + val sql = "select width_bucket(_1, 0.0, 10.0, 5) from tbl" + val (_, cometPlan) = checkSparkAnswerAndOperator(sql) + assert(0 == countSparkProjectExec(cometPlan)) + + withSQLConf(CometConf.getExprEnabledConfigKey("WidthBucket") -> "false") { + val (_, cometPlan2) = checkSparkAnswerAndFallbackReason( + sql, + "Expression support is disabled. Set " + + "spark.comet.expression.WidthBucket.enabled=true to enable it.") + assert(1 == countSparkProjectExec(cometPlan2)) + } + } + } +} diff --git a/spark/src/test/spark-4.x/org/apache/spark/sql/CometToPrettyStringSuite.scala b/spark/src/test/spark-4.x/org/apache/spark/sql/CometToPrettyStringSuite.scala index e7f1757bf6..80063d08e5 100644 --- a/spark/src/test/spark-4.x/org/apache/spark/sql/CometToPrettyStringSuite.scala +++ b/spark/src/test/spark-4.x/org/apache/spark/sql/CometToPrettyStringSuite.scala @@ -24,10 +24,12 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{Alias, ToPrettyString} import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.classic.Dataset +import org.apache.spark.sql.execution.{ProjectExec, SparkPlan} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.BinaryOutputStyle import org.apache.spark.sql.types.DataTypes +import org.apache.comet.CometConf import org.apache.comet.CometFuzzTestBase import org.apache.comet.expressions.{CometCast, CometEvalMode} import org.apache.comet.serde.Compatible @@ -65,4 +67,33 @@ class CometToPrettyStringSuite 
extends CometFuzzTestBase { } }) } + + test("ToPrettyString honors spark.comet.expression.ToPrettyString.enabled") { + def countSparkProjectExec(plan: SparkPlan): Int = + plan.collect { case _: ProjectExec => true }.length + + val df = spark.read.parquet(filename) + df.createOrReplaceTempView("t1") + val table = spark.sessionState.catalog.lookupRelation(TableIdentifier("t1")) + + // Pick a column whose cast to string is Compatible, so the baseline executes natively. + val col = df.schema.fields + .find(_.dataType == DataTypes.IntegerType) + .map(_.name) + .getOrElse(df.schema.fields.head.name) + val prettyExpr = Alias(ToPrettyString(UnresolvedAttribute(col)), s"pretty_$col")() + val plan = Project(Seq(prettyExpr), table) + val analyzed = spark.sessionState.analyzer.execute(plan) + + // Baseline: ToPrettyString converts natively, no Spark ProjectExec. + val baselinePlan = Dataset.ofRows(spark, analyzed).queryExecution.executedPlan + assert(countSparkProjectExec(baselinePlan) == 0) + + // With per-expression config disabled, expression falls back to Spark. + withSQLConf(CometConf.getExprEnabledConfigKey("ToPrettyString") -> "false") { + val disabledPlan = Dataset.ofRows(spark, analyzed).queryExecution.executedPlan + assert(countSparkProjectExec(disabledPlan) >= 1) + } + } + }