Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions bigframes/core/compile/sqlglot/expressions/comparison_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

import typing

import bigframes_vendored.sqlglot as sg
import bigframes_vendored.sqlglot.expressions as sge
import pandas as pd

from bigframes import dtypes
from bigframes import operations as ops
from bigframes.core.compile.sqlglot import sqlglot_ir
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler

Expand Down Expand Up @@ -62,6 +64,10 @@ def _(expr: TypedExpr, op: ops.IsInOp) -> sge.Expression:

@register_binary_op(ops.eq_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
if sqlglot_ir._is_null_literal(left.expr):
return sge.Is(this=right.expr, expression=sge.Null())
if sqlglot_ir._is_null_literal(right.expr):
return sge.Is(this=left.expr, expression=sge.Null())
left_expr = _coerce_bool_to_int(left)
right_expr = _coerce_bool_to_int(right)
return sge.EQ(this=left_expr, expression=right_expr)
Expand Down Expand Up @@ -139,6 +145,17 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:

@register_binary_op(ops.ne_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
if sqlglot_ir._is_null_literal(left.expr):
return sge.Is(
this=sge.paren(right.expr, copy=False),
expression=sg.not_(sge.Null(), copy=False),
)
if sqlglot_ir._is_null_literal(right.expr):
return sge.Is(
this=sge.paren(left.expr, copy=False),
expression=sg.not_(sge.Null(), copy=False),
)

left_expr = _coerce_bool_to_int(left)
right_expr = _coerce_bool_to_int(right)
return sge.NEQ(this=left_expr, expression=right_expr)
Expand Down
20 changes: 16 additions & 4 deletions bigframes/core/compile/sqlglot/expressions/generic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from bigframes import dtypes
from bigframes import operations as ops
from bigframes.core.compile.sqlglot import sqlglot_types
from bigframes.core.compile.sqlglot import sqlglot_ir, sqlglot_types
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler

Expand Down Expand Up @@ -101,11 +101,23 @@ def _(expr: TypedExpr) -> sge.Expression:
def _(expr: TypedExpr, op: ops.MapOp) -> sge.Expression:
if len(op.mappings) == 0:
return expr.expr

mappings = [
(
sqlglot_ir._literal(key, dtypes.is_compatible(key, expr.dtype)),
sqlglot_ir._literal(value, dtypes.is_compatible(value, expr.dtype)),
)
for key, value in op.mappings
]
return sge.Case(
this=expr.expr,
ifs=[
sge.If(this=sge.convert(key), true=sge.convert(value))
for key, value in op.mappings
sge.If(
this=sge.EQ(this=expr.expr, expression=key)
if not sqlglot_ir._is_null_literal(key)
else sge.Is(this=expr.expr, expression=sge.Null()),
true=value,
)
for key, value in mappings
],
default=expr.expr,
)
Expand Down
9 changes: 9 additions & 0 deletions bigframes/core/compile/sqlglot/sqlglot_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,15 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:
return new_select_expr


def _is_null_literal(expr: sge.Expression) -> bool:
"""Checks if the given expression is a NULL literal."""
if isinstance(expr, sge.Null):
return True
if isinstance(expr, sge.Cast) and isinstance(expr.this, sge.Null):
return True
return False


def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None
if sqlglot_type is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ WITH `bfcte_0` AS (
`bfcol_16` AS `bfcol_26`,
`bfcol_17` AS `bfcol_27`,
`bfcol_18` AS `bfcol_28`,
`bfcol_15` = CAST(`bfcol_16` AS INT64) AS `bfcol_29`
`bfcol_15` IS NULL AS `bfcol_29`
FROM `bfcte_2`
), `bfcte_4` AS (
SELECT
Expand All @@ -40,15 +40,28 @@ WITH `bfcte_0` AS (
`bfcol_27` AS `bfcol_39`,
`bfcol_28` AS `bfcol_40`,
`bfcol_29` AS `bfcol_41`,
CAST(`bfcol_26` AS INT64) = `bfcol_25` AS `bfcol_42`
`bfcol_25` = CAST(`bfcol_26` AS INT64) AS `bfcol_42`
FROM `bfcte_3`
), `bfcte_5` AS (
SELECT
*,
`bfcol_36` AS `bfcol_50`,
`bfcol_37` AS `bfcol_51`,
`bfcol_38` AS `bfcol_52`,
`bfcol_39` AS `bfcol_53`,
`bfcol_40` AS `bfcol_54`,
`bfcol_41` AS `bfcol_55`,
`bfcol_42` AS `bfcol_56`,
CAST(`bfcol_38` AS INT64) = `bfcol_37` AS `bfcol_57`
FROM `bfcte_4`
)
SELECT
`bfcol_36` AS `rowindex`,
`bfcol_37` AS `int64_col`,
`bfcol_38` AS `bool_col`,
`bfcol_39` AS `int_ne_int`,
`bfcol_40` AS `int_ne_1`,
`bfcol_41` AS `int_ne_bool`,
`bfcol_42` AS `bool_ne_int`
FROM `bfcte_4`
`bfcol_50` AS `rowindex`,
`bfcol_51` AS `int64_col`,
`bfcol_52` AS `bool_col`,
`bfcol_53` AS `int_eq_int`,
`bfcol_54` AS `int_eq_1`,
`bfcol_55` AS `int_eq_null`,
`bfcol_56` AS `int_eq_bool`,
`bfcol_57` AS `bool_eq_int`
FROM `bfcte_5`
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ WITH `bfcte_0` AS (
`bfcol_16` AS `bfcol_26`,
`bfcol_17` AS `bfcol_27`,
`bfcol_18` AS `bfcol_28`,
`bfcol_15` <> CAST(`bfcol_16` AS INT64) AS `bfcol_29`
(
`bfcol_15`
) IS NOT NULL AS `bfcol_29`
FROM `bfcte_2`
), `bfcte_4` AS (
SELECT
Expand All @@ -40,15 +42,28 @@ WITH `bfcte_0` AS (
`bfcol_27` AS `bfcol_39`,
`bfcol_28` AS `bfcol_40`,
`bfcol_29` AS `bfcol_41`,
CAST(`bfcol_26` AS INT64) <> `bfcol_25` AS `bfcol_42`
`bfcol_25` <> CAST(`bfcol_26` AS INT64) AS `bfcol_42`
FROM `bfcte_3`
), `bfcte_5` AS (
SELECT
*,
`bfcol_36` AS `bfcol_50`,
`bfcol_37` AS `bfcol_51`,
`bfcol_38` AS `bfcol_52`,
`bfcol_39` AS `bfcol_53`,
`bfcol_40` AS `bfcol_54`,
`bfcol_41` AS `bfcol_55`,
`bfcol_42` AS `bfcol_56`,
CAST(`bfcol_38` AS INT64) <> `bfcol_37` AS `bfcol_57`
FROM `bfcte_4`
)
SELECT
`bfcol_36` AS `rowindex`,
`bfcol_37` AS `int64_col`,
`bfcol_38` AS `bool_col`,
`bfcol_39` AS `int_ne_int`,
`bfcol_40` AS `int_ne_1`,
`bfcol_41` AS `int_ne_bool`,
`bfcol_42` AS `bool_ne_int`
FROM `bfcte_4`
`bfcol_50` AS `rowindex`,
`bfcol_51` AS `int64_col`,
`bfcol_52` AS `bool_col`,
`bfcol_53` AS `int_ne_int`,
`bfcol_54` AS `int_ne_1`,
`bfcol_55` AS `int_ne_null`,
`bfcol_56` AS `int_ne_bool`,
`bfcol_57` AS `bool_ne_int`
FROM `bfcte_5`
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ WITH `bfcte_0` AS (
), `bfcte_1` AS (
SELECT
*,
CASE `string_col` WHEN 'value1' THEN 'mapped1' ELSE `string_col` END AS `bfcol_1`
CASE
WHEN `string_col` = 'value1'
THEN 'mapped1'
WHEN `string_col` IS NULL
THEN 'UNKNOWN'
ELSE `string_col`
END AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,12 @@ def test_eq_null_match(scalar_types_df: bpd.DataFrame, snapshot):
def test_eq_numeric(scalar_types_df: bpd.DataFrame, snapshot):
bf_df = scalar_types_df[["int64_col", "bool_col"]]

bf_df["int_ne_int"] = bf_df["int64_col"] == bf_df["int64_col"]
bf_df["int_ne_1"] = bf_df["int64_col"] == 1
bf_df["int_eq_int"] = bf_df["int64_col"] == bf_df["int64_col"]
bf_df["int_eq_1"] = bf_df["int64_col"] == 1
bf_df["int_eq_null"] = bf_df["int64_col"] == pd.NA

bf_df["int_ne_bool"] = bf_df["int64_col"] == bf_df["bool_col"]
bf_df["bool_ne_int"] = bf_df["bool_col"] == bf_df["int64_col"]
bf_df["int_eq_bool"] = bf_df["int64_col"] == bf_df["bool_col"]
bf_df["bool_eq_int"] = bf_df["bool_col"] == bf_df["int64_col"]

snapshot.assert_match(bf_df.sql, "out.sql")

Expand Down Expand Up @@ -135,6 +136,7 @@ def test_ne_numeric(scalar_types_df: bpd.DataFrame, snapshot):

bf_df["int_ne_int"] = bf_df["int64_col"] != bf_df["int64_col"]
bf_df["int_ne_1"] = bf_df["int64_col"] != 1
bf_df["int_ne_null"] = bf_df["int64_col"] != pd.NA

bf_df["int_ne_bool"] = bf_df["int64_col"] != bf_df["bool_col"]
bf_df["bool_ne_int"] = bf_df["bool_col"] != bf_df["int64_col"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pandas as pd
import pytest

from bigframes import dtypes
Expand Down Expand Up @@ -342,7 +343,11 @@ def test_map(scalar_types_df: bpd.DataFrame, snapshot):
bf_df = scalar_types_df[[col_name]]
sql = utils._apply_ops_to_sql(
bf_df,
[ops.MapOp(mappings=(("value1", "mapped1"),)).as_expr(col_name)],
[
ops.MapOp(mappings=(("value1", "mapped1"), (pd.NA, "UNKNOWN"))).as_expr(
col_name
)
],
[col_name],
)

Expand Down
Loading