diff --git a/bigframes/core/compile/sqlglot/expressions/comparison_ops.py b/bigframes/core/compile/sqlglot/expressions/comparison_ops.py index 8fda3b80dd..8c201f6a06 100644 --- a/bigframes/core/compile/sqlglot/expressions/comparison_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/comparison_ops.py @@ -16,11 +16,13 @@ import typing +import bigframes_vendored.sqlglot as sg import bigframes_vendored.sqlglot.expressions as sge import pandas as pd from bigframes import dtypes from bigframes import operations as ops +from bigframes.core.compile.sqlglot import sqlglot_ir from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler @@ -62,6 +64,10 @@ def _(expr: TypedExpr, op: ops.IsInOp) -> sge.Expression: @register_binary_op(ops.eq_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + if sqlglot_ir._is_null_literal(left.expr): + return sge.Is(this=right.expr, expression=sge.Null()) + if sqlglot_ir._is_null_literal(right.expr): + return sge.Is(this=left.expr, expression=sge.Null()) left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) return sge.EQ(this=left_expr, expression=right_expr) @@ -139,6 +145,17 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: @register_binary_op(ops.ne_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + if sqlglot_ir._is_null_literal(left.expr): + return sge.Is( + this=sge.paren(right.expr, copy=False), + expression=sg.not_(sge.Null(), copy=False), + ) + if sqlglot_ir._is_null_literal(right.expr): + return sge.Is( + this=sge.paren(left.expr, copy=False), + expression=sg.not_(sge.Null(), copy=False), + ) + left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) return sge.NEQ(this=left_expr, expression=right_expr) diff --git a/bigframes/core/compile/sqlglot/expressions/generic_ops.py b/bigframes/core/compile/sqlglot/expressions/generic_ops.py index 4a2a5fb213..ec0d0b3b34 100644 --- a/bigframes/core/compile/sqlglot/expressions/generic_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/generic_ops.py @@ -19,7 +19,7 @@ from bigframes import dtypes from bigframes import operations as ops -from bigframes.core.compile.sqlglot import sqlglot_types +from bigframes.core.compile.sqlglot import sqlglot_ir, sqlglot_types from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler @@ -101,11 +101,23 @@ def _(expr: TypedExpr) -> sge.Expression: def _(expr: TypedExpr, op: ops.MapOp) -> sge.Expression: if len(op.mappings) == 0: return expr.expr + + mappings = [ + ( + sqlglot_ir._literal(key, dtypes.is_compatible(key, expr.dtype)), + sqlglot_ir._literal(value, dtypes.is_compatible(value, expr.dtype)), + ) + for key, value in op.mappings + ] return sge.Case( - this=expr.expr, ifs=[ - sge.If(this=sge.convert(key), true=sge.convert(value)) - for key, value in op.mappings + sge.If( + this=sge.EQ(this=expr.expr, expression=key) + if not sqlglot_ir._is_null_literal(key) + else sge.Is(this=expr.expr, expression=sge.Null()), + true=value, + ) + for key, value in mappings ], default=expr.expr, ) diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index 9445b65e99..cefe983e24 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -642,6 +642,15 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select: return new_select_expr +def _is_null_literal(expr: sge.Expression) -> bool: + """Checks if the given expression is a NULL literal.""" + if isinstance(expr, sge.Null): + return True + if isinstance(expr, sge.Cast) and isinstance(expr.this, sge.Null): + return True + return False + + def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression: sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None if sqlglot_type is None: diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_eq_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_eq_numeric/out.sql index 9c7c19e61c..a21e008941 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_eq_numeric/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_eq_numeric/out.sql @@ -29,7 +29,7 @@ WITH `bfcte_0` AS ( `bfcol_16` AS `bfcol_26`, `bfcol_17` AS `bfcol_27`, `bfcol_18` AS `bfcol_28`, - `bfcol_15` = CAST(`bfcol_16` AS INT64) AS `bfcol_29` + `bfcol_15` IS NULL AS `bfcol_29` FROM `bfcte_2` ), `bfcte_4` AS ( SELECT @@ -40,15 +40,28 @@ WITH `bfcte_0` AS ( `bfcol_27` AS `bfcol_39`, `bfcol_28` AS `bfcol_40`, `bfcol_29` AS `bfcol_41`, - CAST(`bfcol_26` AS INT64) = `bfcol_25` AS `bfcol_42` + `bfcol_25` = CAST(`bfcol_26` AS INT64) AS `bfcol_42` FROM `bfcte_3` +), `bfcte_5` AS ( + SELECT + *, + `bfcol_36` AS `bfcol_50`, + `bfcol_37` AS `bfcol_51`, + `bfcol_38` AS `bfcol_52`, + `bfcol_39` AS `bfcol_53`, + `bfcol_40` AS `bfcol_54`, + `bfcol_41` AS `bfcol_55`, + `bfcol_42` AS `bfcol_56`, + CAST(`bfcol_38` AS INT64) = `bfcol_37` AS `bfcol_57` + FROM `bfcte_4` ) SELECT - `bfcol_36` AS `rowindex`, - `bfcol_37` AS `int64_col`, - `bfcol_38` AS `bool_col`, - `bfcol_39` AS `int_ne_int`, - `bfcol_40` AS `int_ne_1`, - `bfcol_41` AS `int_ne_bool`, - `bfcol_42` AS `bool_ne_int` -FROM `bfcte_4` \ No newline at end of file + `bfcol_50` AS `rowindex`, + `bfcol_51` AS `int64_col`, + `bfcol_52` AS `bool_col`, + `bfcol_53` AS `int_eq_int`, + `bfcol_54` AS `int_eq_1`, + `bfcol_55` AS `int_eq_null`, + `bfcol_56` AS `int_eq_bool`, + `bfcol_57` AS `bool_eq_int` +FROM `bfcte_5` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_ne_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_ne_numeric/out.sql index 417d24aa72..1a1ff6e44d 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_ne_numeric/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_ne_numeric/out.sql @@ -29,7 +29,9 @@ WITH `bfcte_0` AS ( `bfcol_16` AS `bfcol_26`, `bfcol_17` AS `bfcol_27`, `bfcol_18` AS `bfcol_28`, - `bfcol_15` <> CAST(`bfcol_16` AS INT64) AS `bfcol_29` + ( + `bfcol_15` + ) IS NOT NULL AS `bfcol_29` FROM `bfcte_2` ), `bfcte_4` AS ( SELECT @@ -40,15 +42,28 @@ WITH `bfcte_0` AS ( `bfcol_27` AS `bfcol_39`, `bfcol_28` AS `bfcol_40`, `bfcol_29` AS `bfcol_41`, - CAST(`bfcol_26` AS INT64) <> `bfcol_25` AS `bfcol_42` + `bfcol_25` <> CAST(`bfcol_26` AS INT64) AS `bfcol_42` FROM `bfcte_3` +), `bfcte_5` AS ( + SELECT + *, + `bfcol_36` AS `bfcol_50`, + `bfcol_37` AS `bfcol_51`, + `bfcol_38` AS `bfcol_52`, + `bfcol_39` AS `bfcol_53`, + `bfcol_40` AS `bfcol_54`, + `bfcol_41` AS `bfcol_55`, + `bfcol_42` AS `bfcol_56`, + CAST(`bfcol_38` AS INT64) <> `bfcol_37` AS `bfcol_57` + FROM `bfcte_4` ) SELECT - `bfcol_36` AS `rowindex`, - `bfcol_37` AS `int64_col`, - `bfcol_38` AS `bool_col`, - `bfcol_39` AS `int_ne_int`, - `bfcol_40` AS `int_ne_1`, - `bfcol_41` AS `int_ne_bool`, - `bfcol_42` AS `bool_ne_int` -FROM `bfcte_4` \ No newline at end of file + `bfcol_50` AS `rowindex`, + `bfcol_51` AS `int64_col`, + `bfcol_52` AS `bool_col`, + `bfcol_53` AS `int_ne_int`, + `bfcol_54` AS `int_ne_1`, + `bfcol_55` AS `int_ne_null`, + `bfcol_56` AS `int_ne_bool`, + `bfcol_57` AS `bool_ne_int` +FROM `bfcte_5` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_map/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_map/out.sql index 22628c6a4b..49eada2230 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_map/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_map/out.sql @@ -5,7 +5,13 @@ WITH `bfcte_0` AS ( ), `bfcte_1` AS ( SELECT *, - CASE `string_col` WHEN 'value1' THEN 'mapped1' ELSE `string_col` END AS `bfcol_1` + CASE + WHEN `string_col` = 'value1' + THEN 'mapped1' + WHEN `string_col` IS NULL + THEN 'UNKNOWN' + ELSE `string_col` + END AS `bfcol_1` FROM `bfcte_0` ) SELECT diff --git a/tests/unit/core/compile/sqlglot/expressions/test_comparison_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_comparison_ops.py index ea94bcae56..3c13bc798b 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_comparison_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_comparison_ops.py @@ -59,11 +59,12 @@ def test_eq_null_match(scalar_types_df: bpd.DataFrame, snapshot): def test_eq_numeric(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["int64_col", "bool_col"]] - bf_df["int_ne_int"] = bf_df["int64_col"] == bf_df["int64_col"] - bf_df["int_ne_1"] = bf_df["int64_col"] == 1 + bf_df["int_eq_int"] = bf_df["int64_col"] == bf_df["int64_col"] + bf_df["int_eq_1"] = bf_df["int64_col"] == 1 + bf_df["int_eq_null"] = bf_df["int64_col"] == pd.NA - bf_df["int_ne_bool"] = bf_df["int64_col"] == bf_df["bool_col"] - bf_df["bool_ne_int"] = bf_df["bool_col"] == bf_df["int64_col"] + bf_df["int_eq_bool"] = bf_df["int64_col"] == bf_df["bool_col"] + bf_df["bool_eq_int"] = bf_df["bool_col"] == bf_df["int64_col"] snapshot.assert_match(bf_df.sql, "out.sql") @@ -135,6 +136,7 @@ def test_ne_numeric(scalar_types_df: bpd.DataFrame, snapshot): bf_df["int_ne_int"] = bf_df["int64_col"] != bf_df["int64_col"] bf_df["int_ne_1"] = bf_df["int64_col"] != 1 + bf_df["int_ne_null"] = bf_df["int64_col"] != pd.NA bf_df["int_ne_bool"] = bf_df["int64_col"] != bf_df["bool_col"] bf_df["bool_ne_int"] = bf_df["bool_col"] != bf_df["int64_col"] diff --git a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py index 5657874eb5..03b517096e 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pandas as pd import pytest from bigframes import dtypes @@ -342,7 +343,11 @@ def test_map(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[[col_name]] sql = utils._apply_ops_to_sql( bf_df, - [ops.MapOp(mappings=(("value1", "mapped1"),)).as_expr(col_name)], + [ + ops.MapOp(mappings=(("value1", "mapped1"), (pd.NA, "UNKNOWN"))).as_expr( + col_name + ) + ], [col_name], )