Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,25 @@ struct JoinParameters {
pub join_type: DFJoinType,
}

/// Rewrites `expr` so that a timezone-carrying timestamp type becomes a
/// timezone-less one: when `expr` resolves to `Timestamp(_, Some(_))` under
/// `schema`, it is wrapped in a metadata-only cast to `Timestamp(_, None)`.
///
/// This is needed because DataFusion's `SortMergeJoinExec` comparator handles
/// only timezone-less timestamps, whereas Spark's `TimestampType` arrives as
/// `Timestamp(µs, "UTC")`. Since the time unit is unchanged, ordering is
/// preserved by the cast. Any non-timestamp (or already timezone-less)
/// expression is returned untouched.
fn strip_timestamp_tz(
    expr: Arc<dyn PhysicalExpr>,
    schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
    if let DataType::Timestamp(unit, Some(_)) = expr.data_type(schema)? {
        let cast = CastExpr::new(expr, DataType::Timestamp(unit, None), None);
        Ok(Arc::new(cast))
    } else {
        Ok(expr)
    }
}

#[derive(Default)]
pub struct BinaryExprOptions {
pub is_integral_div: bool,
Expand Down Expand Up @@ -1727,10 +1746,23 @@ impl PhysicalPlanner {
let left = Arc::clone(&join_params.left.native_plan);
let right = Arc::clone(&join_params.right.native_plan);

let left_schema = left.schema();
let right_schema = right.schema();
let join_on = join_params
.join_on
.into_iter()
.map(|(l, r)| {
Ok((
strip_timestamp_tz(l, left_schema.as_ref())?,
strip_timestamp_tz(r, right_schema.as_ref())?,
))
})
.collect::<Result<Vec<_>, ExecutionError>>()?;

let join = Arc::new(SortMergeJoinExec::try_new(
Arc::clone(&left),
Arc::clone(&right),
join_params.join_on,
join_on,
join_params.join_filter,
join_params.join_type,
sort_options,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, TimestampNTZType}
import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, TimestampNTZType, TimestampType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.io.ChunkedByteBuffer
Expand Down Expand Up @@ -2270,7 +2270,7 @@ object CometSortMergeJoinExec extends CometOperatorSerde[SortMergeJoinExec] {
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
_: DoubleType | _: StringType | _: DateType | _: DecimalType | _: BooleanType =>
true
case TimestampNTZType => true
case TimestampNTZType | _: TimestampType => true
case _ => false
}

Expand Down
154 changes: 146 additions & 8 deletions spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.scalatest.Tag
import org.apache.spark.sql.CometTestBase
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec}
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometSortMergeJoinExec}
import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
import org.apache.spark.sql.internal.SQLConf

Expand Down Expand Up @@ -55,21 +55,159 @@ class CometJoinSuite extends CometTestBase {
.toSeq)
}

test("SortMergeJoin with unsupported key type should fall back to Spark") {
test("SortMergeJoin with TimestampType key runs natively") {
withSQLConf(
SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu",
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
withTable("t1", "t2") {
sql("CREATE TABLE t1(name STRING, time TIMESTAMP) USING PARQUET")
sql("INSERT OVERWRITE t1 VALUES('a', timestamp'2019-01-01 11:11:11')")
sql(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if table t1 is written with one session timezone and table t2 is written with another session timezone.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a test to cover this case

"INSERT OVERWRITE t1 VALUES " +
"('a', timestamp'2019-01-01 11:11:11'), " +
"('b', timestamp'2020-05-05 05:05:05')")

sql("CREATE TABLE t2(name STRING, time TIMESTAMP) USING PARQUET")
sql("INSERT OVERWRITE t2 VALUES('a', timestamp'2019-01-01 11:11:11')")
sql(
"INSERT OVERWRITE t2 VALUES " +
"('a', timestamp'2019-01-01 11:11:11'), " +
"('c', timestamp'2021-07-07 07:07:07')")

checkSparkAnswerAndOperator(
sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time"),
Seq(classOf[CometSortMergeJoinExec]))
}
}
}

val df = sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time")
val (sparkPlan, cometPlan) = checkSparkAnswer(df)
assert(sparkPlan.canonicalized === cometPlan.canonicalized)
test("SortMergeJoin with TimestampType key supports outer joins") {
withSQLConf(
SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu",
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
withTable("t1", "t2") {
sql("CREATE TABLE t1(id INT, time TIMESTAMP) USING PARQUET")
sql(
"INSERT OVERWRITE t1 VALUES " +
"(1, timestamp'2019-01-01 11:11:11'), " +
"(2, timestamp'2020-05-05 05:05:05'), " +
"(3, timestamp'2021-07-07 07:07:07')")

sql("CREATE TABLE t2(id INT, time TIMESTAMP) USING PARQUET")
sql(
"INSERT OVERWRITE t2 VALUES " +
"(10, timestamp'2019-01-01 11:11:11'), " +
"(20, timestamp'2022-02-02 02:02:02')")

for (joinType <- Seq("LEFT OUTER", "RIGHT OUTER", "FULL OUTER")) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

checkSparkAnswerAndOperator(
sql(s"SELECT * FROM t1 $joinType JOIN t2 ON t1.time = t2.time"),
Seq(classOf[CometSortMergeJoinExec]))
}
}
}
}

// Verifies that a sort-merge join keyed on BOTH a STRING and a TIMESTAMP column
// executes natively in Comet (plan must contain CometSortMergeJoinExec) and
// returns the same rows as Spark.
test("SortMergeJoin with composite (string, timestamp) key runs natively") {
withSQLConf(
// Disable broadcast joins and prefer SMJ so the planner is forced to pick
// SortMergeJoinExec rather than a broadcast/hash join.
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
withTable("t1", "t2") {
sql("CREATE TABLE t1(name STRING, time TIMESTAMP) USING PARQUET")
sql(
"INSERT OVERWRITE t1 VALUES " +
"('a', timestamp'2019-01-01 11:11:11'), " +
"('b', timestamp'2019-01-01 11:11:11'), " +
"('a', timestamp'2020-05-05 05:05:05')")

sql("CREATE TABLE t2(name STRING, time TIMESTAMP) USING PARQUET")
sql(
"INSERT OVERWRITE t2 VALUES " +
"('a', timestamp'2019-01-01 11:11:11'), " +
"('b', timestamp'2020-05-05 05:05:05'), " +
"('a', timestamp'2020-05-05 05:05:05')")

// Rows that share a timestamp but not a name (and vice versa) must NOT
// match — equality is required on both key columns.
checkSparkAnswerAndOperator(
sql(
"SELECT * FROM t1 JOIN t2 " +
"ON t1.name = t2.name AND t1.time = t2.time"),
Seq(classOf[CometSortMergeJoinExec]))
}
}
}

// Verifies Spark's NULL-key semantics are preserved when the sort-merge join on a
// TIMESTAMP key runs natively in Comet: NULL keys never match in an inner join,
// and surface as unmatched rows in a full outer join.
test("SortMergeJoin with nullable TimestampType key runs natively") {
withSQLConf(
// Disable broadcast joins and prefer SMJ so SortMergeJoinExec is chosen.
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
withTable("t1", "t2") {
sql("CREATE TABLE t1(id INT, time TIMESTAMP) USING PARQUET")
sql(
"INSERT OVERWRITE t1 VALUES " +
"(1, timestamp'2019-01-01 11:11:11'), " +
"(2, CAST(NULL AS TIMESTAMP)), " +
"(3, timestamp'2020-05-05 05:05:05')")

sql("CREATE TABLE t2(id INT, time TIMESTAMP) USING PARQUET")
sql(
"INSERT OVERWRITE t2 VALUES " +
"(10, timestamp'2019-01-01 11:11:11'), " +
"(20, CAST(NULL AS TIMESTAMP)), " +
"(30, timestamp'2022-02-02 02:02:02')")

// Inner join: NULL = NULL must not match in Spark semantics.
checkSparkAnswerAndOperator(
sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time"),
Seq(classOf[CometSortMergeJoinExec]))

// Full outer join: NULL-keyed rows from both sides surface as unmatched.
checkSparkAnswerAndOperator(
sql("SELECT * FROM t1 FULL OUTER JOIN t2 ON t1.time = t2.time"),
Seq(classOf[CometSortMergeJoinExec]))
}
}
}

test("SortMergeJoin with TimestampType key across mixed write-time session timezones") {
// TimestampType is an instant (UTC microseconds); only the parsing of literal
// strings depends on the session timezone. Writing each side under a different
// session zone with wall-clock literals that resolve to the same UTC instant
// must still produce a join match.
withSQLConf(
// Disable broadcast joins and prefer SMJ so SortMergeJoinExec is chosen.
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
withTable("t1", "t2") {
// t1 written in America/Los_Angeles. 03:11:11 -0800 == 11:11:11 UTC.
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Los_Angeles") {
sql("CREATE TABLE t1(name STRING, time TIMESTAMP) USING PARQUET")
sql(
"INSERT OVERWRITE t1 VALUES " +
"('a', timestamp'2019-01-01 03:11:11'), " +
"('b', timestamp'2020-05-04 22:05:05')")
}

// t2 written in Asia/Tokyo. 20:11:11 +0900 == 11:11:11 UTC, so the 'a' and
// 'a2' rows share a UTC instant with t1's 'a' row.
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Tokyo") {
sql("CREATE TABLE t2(name STRING, time TIMESTAMP) USING PARQUET")
sql(
"INSERT OVERWRITE t2 VALUES " +
"('a', timestamp'2019-01-01 20:11:11'), " +
"('c', timestamp'2021-07-07 16:07:07')")
}

// Read at a third session timezone to confirm the equality is on the
// stored UTC instant rather than the displayed wall-clock value.
// The plan must still contain CometSortMergeJoinExec (native execution).
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
checkSparkAnswerAndOperator(
sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time"),
Seq(classOf[CometSortMergeJoinExec]))
}
}
}
}
Expand Down
Loading