diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 89bb58d7dd..647e86d28a 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -23,6 +23,7 @@ from bigframes.core import window_spec import bigframes.core.compile.sqlglot.aggregations.op_registration as reg from bigframes.core.compile.sqlglot.aggregations.windows import apply_window_if_present +from bigframes.core.compile.sqlglot.expressions import constants import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr import bigframes.core.compile.sqlglot.sqlglot_ir as ir from bigframes.operations import aggregations as agg_ops @@ -44,9 +45,13 @@ def _( column: typed_expr.TypedExpr, window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: - # BQ will return null for empty column, result would be false in pandas. - result = apply_window_if_present(sge.func("LOGICAL_AND", column.expr), window) - return sge.func("IFNULL", result, sge.true()) + expr = column.expr + if column.dtype != dtypes.BOOL_DTYPE: + expr = sge.NEQ(this=expr, expression=sge.convert(0)) + expr = apply_window_if_present(sge.func("LOGICAL_AND", expr), window) + + # BQ will return null for empty column, result would be true in pandas. + return sge.func("COALESCE", expr, sge.convert(True)) @UNARY_OP_REGISTRATION.register(agg_ops.AnyOp) @@ -56,6 +61,8 @@ def _( window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: expr = column.expr + if column.dtype != dtypes.BOOL_DTYPE: + expr = sge.NEQ(this=expr, expression=sge.convert(0)) expr = apply_window_if_present(sge.func("LOGICAL_OR", expr), window) # BQ will return null for empty column, result would be false in pandas. 
@@ -326,6 +333,15 @@ def _( unit=sge.Identifier(this="MICROSECOND"), ) + if column.dtype == dtypes.DATE_DTYPE: + date_diff = sge.DateDiff( + this=column.expr, expression=shifted, unit=sge.Identifier(this="DAY") + ) + return sge.Cast( + this=sge.Floor(this=date_diff * constants._DAY_TO_MICROSECONDS), + to="INT64", + ) + raise TypeError(f"Cannot perform diff on type {column.dtype}") @@ -410,24 +426,28 @@ def _( column: typed_expr.TypedExpr, window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: + expr = column.expr + if column.dtype == dtypes.BOOL_DTYPE: + expr = sge.Cast(this=expr, to="INT64") + # Need to short-circuit as log with zeroes is illegal sql - is_zero = sge.EQ(this=column.expr, expression=sge.convert(0)) + is_zero = sge.EQ(this=expr, expression=sge.convert(0)) # There is no product sql aggregate function, so must implement as a sum of logs, and then - # apply power after. Note, log and power base must be equal! This impl uses natural log. + # apply power after. Note, log and power base must be equal! This impl uses base 2. - logs = ( - sge.Case() - .when(is_zero, sge.convert(0)) - .else_(sge.func("LN", sge.func("ABS", column.expr))) + logs = sge.If( + this=is_zero, + true=sge.convert(0), + false=sge.func("LOG", sge.convert(2), sge.func("ABS", expr)), ) logs_sum = apply_window_if_present(sge.func("SUM", logs), window) - magnitude = sge.func("EXP", logs_sum) + magnitude = sge.func("POWER", sge.convert(2), logs_sum) # Can't determine sign from logs, so have to determine parity of count of negative inputs is_negative = ( sge.Case() .when( - sge.LT(this=sge.func("SIGN", column.expr), expression=sge.convert(0)), + sge.EQ(this=sge.func("SIGN", expr), expression=sge.convert(-1)), sge.convert(1), ) .else_(sge.convert(0)) @@ -445,11 +465,7 @@ def _( .else_( sge.Mul( this=magnitude, - expression=sge.If( - this=sge.EQ(this=negative_count_parity, expression=sge.convert(1)), - true=sge.convert(-1), - false=sge.convert(1), - ), + expression=sge.func("POWER", sge.convert(-1), negative_count_parity), ) ) ) @@ -499,14 +515,18 @@ def _(
column: typed_expr.TypedExpr, window: typing.Optional[window_spec.WindowSpec] = None, ) -> sge.Expression: - # TODO: Support interpolation argument - # TODO: Support percentile_disc - result: sge.Expression = sge.func("PERCENTILE_CONT", column.expr, sge.convert(op.q)) + expr = column.expr + if column.dtype == dtypes.BOOL_DTYPE: + expr = sge.Cast(this=expr, to="INT64") + + result: sge.Expression = sge.func("PERCENTILE_CONT", expr, sge.convert(op.q)) if window is None: - # PERCENTILE_CONT is a navigation function, not an aggregate function, so it always needs an OVER clause. + # PERCENTILE_CONT is a navigation function, not an aggregate function, + # so it always needs an OVER clause. result = sge.Window(this=result) else: result = apply_window_if_present(result, window) + if op.should_floor_result: result = sge.Cast(this=sge.func("FLOOR", result), to="INT64") return result diff --git a/bigframes/core/compile/sqlglot/expressions/constants.py b/bigframes/core/compile/sqlglot/expressions/constants.py index f383306292..5ba4a72279 100644 --- a/bigframes/core/compile/sqlglot/expressions/constants.py +++ b/bigframes/core/compile/sqlglot/expressions/constants.py @@ -20,6 +20,7 @@ _NAN = sge.Cast(this=sge.convert("NaN"), to="FLOAT64") _INF = sge.Cast(this=sge.convert("Infinity"), to="FLOAT64") _NEG_INF = sge.Cast(this=sge.convert("-Infinity"), to="FLOAT64") +_DAY_TO_MICROSECONDS = sge.convert(86400000000) # Approx Highest number you can pass in to EXP function and get a valid FLOAT64 result # FLOAT64 has 11 exponent bits, so max values is about 2**(2**10) diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all/out.sql index d31b21f56b..0be2fea80b 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all/out.sql +++ 
b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all/out.sql @@ -1,12 +1,15 @@ WITH `bfcte_0` AS ( SELECT - `bool_col` + `bool_col`, + `int64_col` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` ), `bfcte_1` AS ( SELECT - COALESCE(LOGICAL_AND(`bool_col`), TRUE) AS `bfcol_1` + COALESCE(LOGICAL_AND(`bool_col`), TRUE) AS `bfcol_2`, + COALESCE(LOGICAL_AND(`int64_col` <> 0), TRUE) AS `bfcol_3` FROM `bfcte_0` ) SELECT - `bfcol_1` AS `bool_col` + `bfcol_2` AS `bool_col`, + `bfcol_3` AS `int64_col` FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all/window_partition_out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all/window_partition_out.sql deleted file mode 100644 index 23357817c1..0000000000 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all/window_partition_out.sql +++ /dev/null @@ -1,14 +0,0 @@ -WITH `bfcte_0` AS ( - SELECT - `bool_col`, - `string_col` - FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` -), `bfcte_1` AS ( - SELECT - *, - COALESCE(LOGICAL_AND(`bool_col`) OVER (PARTITION BY `string_col`), TRUE) AS `bfcol_2` - FROM `bfcte_0` -) -SELECT - `bfcol_2` AS `agg_bool` -FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all/window_out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all_w_window/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all/window_out.sql rename to tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all_w_window/out.sql diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_any/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_any/out.sql 
index 03b0d5c151..ae62e22e36 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_any/out.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_any/out.sql @@ -1,12 +1,15 @@ WITH `bfcte_0` AS ( SELECT - `bool_col` + `bool_col`, + `int64_col` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` ), `bfcte_1` AS ( SELECT - COALESCE(LOGICAL_OR(`bool_col`), FALSE) AS `bfcol_1` + COALESCE(LOGICAL_OR(`bool_col`), FALSE) AS `bfcol_2`, + COALESCE(LOGICAL_OR(`int64_col` <> 0), FALSE) AS `bfcol_3` FROM `bfcte_0` ) SELECT - `bfcol_1` AS `bool_col` + `bfcol_2` AS `bool_col`, + `bfcol_3` AS `int64_col` FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_any/window_out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_any_w_window/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_any/window_out.sql rename to tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_any_w_window/out.sql diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff_w_date/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff_w_date/out.sql new file mode 100644 index 0000000000..4f1729d2e2 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_diff_w_date/out.sql @@ -0,0 +1,15 @@ +WITH `bfcte_0` AS ( + SELECT + `date_col` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CAST(FLOOR( + DATE_DIFF(`date_col`, LAG(`date_col`, 1) OVER (ORDER BY `date_col` ASC NULLS LAST), DAY) * 86400000000 + ) AS INT64) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `diff_date` +FROM `bfcte_1` \ No newline at end of file diff --git 
a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_product/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_product/out.sql index bec1527137..94ca21988e 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_product/out.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_product/out.sql @@ -7,7 +7,7 @@ WITH `bfcte_0` AS ( CASE WHEN LOGICAL_OR(`int64_col` = 0) THEN 0 - ELSE EXP(SUM(CASE WHEN `int64_col` = 0 THEN 0 ELSE LN(ABS(`int64_col`)) END)) * IF(MOD(SUM(CASE WHEN SIGN(`int64_col`) < 0 THEN 1 ELSE 0 END), 2) = 1, -1, 1) + ELSE POWER(2, SUM(IF(`int64_col` = 0, 0, LOG(ABS(`int64_col`), 2)))) * POWER(-1, MOD(SUM(CASE WHEN SIGN(`int64_col`) = -1 THEN 1 ELSE 0 END), 2)) END AS `bfcol_1` FROM `bfcte_0` ) diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_product/window_partition_out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_product/window_partition_out.sql index 9c1650222a..c5f12f7009 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_product/window_partition_out.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_product/window_partition_out.sql @@ -9,15 +9,15 @@ WITH `bfcte_0` AS ( CASE WHEN LOGICAL_OR(`int64_col` = 0) OVER (PARTITION BY `string_col`) THEN 0 - ELSE EXP( - SUM(CASE WHEN `int64_col` = 0 THEN 0 ELSE LN(ABS(`int64_col`)) END) OVER (PARTITION BY `string_col`) - ) * IF( + ELSE POWER( + 2, + SUM(IF(`int64_col` = 0, 0, LOG(ABS(`int64_col`), 2))) OVER (PARTITION BY `string_col`) + ) * POWER( + -1, MOD( - SUM(CASE WHEN SIGN(`int64_col`) < 0 THEN 1 ELSE 0 END) OVER (PARTITION BY `string_col`), + SUM(CASE WHEN SIGN(`int64_col`) = -1 THEN 1 ELSE 0 END) OVER (PARTITION BY `string_col`), 2 - ) = 1, - -1, - 1 + ) ) END AS `bfcol_2` FROM `bfcte_0` diff --git 
a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_quantile/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_quantile/out.sql index b79d8d381f..e337356d96 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_quantile/out.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_quantile/out.sql @@ -1,14 +1,17 @@ WITH `bfcte_0` AS ( SELECT + `bool_col`, `int64_col` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` ), `bfcte_1` AS ( SELECT - PERCENTILE_CONT(`int64_col`, 0.5) OVER () AS `bfcol_1`, - CAST(FLOOR(PERCENTILE_CONT(`int64_col`, 0.5) OVER ()) AS INT64) AS `bfcol_2` + PERCENTILE_CONT(`int64_col`, 0.5) OVER () AS `bfcol_4`, + PERCENTILE_CONT(CAST(`bool_col` AS INT64), 0.5) OVER () AS `bfcol_5`, + CAST(FLOOR(PERCENTILE_CONT(`int64_col`, 0.5) OVER ()) AS INT64) AS `bfcol_6` FROM `bfcte_0` ) SELECT - `bfcol_1` AS `quantile`, - `bfcol_2` AS `quantile_floor` + `bfcol_4` AS `int64`, + `bfcol_5` AS `bool`, + `bfcol_6` AS `int64_w_floor` FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index c15d70478a..d9bfb1f5f3 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -63,41 +63,47 @@ def _apply_unary_window_op( def test_all(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["bool_col", "int64_col"]] + ops_map = { + "bool_col": agg_ops.AllOp().as_expr("bool_col"), + "int64_col": agg_ops.AllOp().as_expr("int64_col"), + } + sql = _apply_unary_agg_ops(bf_df, list(ops_map.values()), list(ops_map.keys())) + + snapshot.assert_match(sql, "out.sql") + + +def test_all_w_window(scalar_types_df: bpd.DataFrame, snapshot): col_name = "bool_col" bf_df = 
scalar_types_df[[col_name]] agg_expr = agg_ops.AllOp().as_expr(col_name) - sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name]) - - snapshot.assert_match(sql, "out.sql") # Window tests window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),)) sql_window = _apply_unary_window_op(bf_df, agg_expr, window, "agg_bool") - snapshot.assert_match(sql_window, "window_out.sql") - - bf_df_str = scalar_types_df[[col_name, "string_col"]] - window_partition = window_spec.WindowSpec( - grouping_keys=(expression.deref("string_col"),), - ordering=(ordering.descending_over(col_name),), - ) - sql_window_partition = _apply_unary_window_op( - bf_df_str, agg_expr, window_partition, "agg_bool" - ) - snapshot.assert_match(sql_window_partition, "window_partition_out.sql") + snapshot.assert_match(sql_window, "out.sql") def test_any(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["bool_col", "int64_col"]] + ops_map = { + "bool_col": agg_ops.AnyOp().as_expr("bool_col"), + "int64_col": agg_ops.AnyOp().as_expr("int64_col"), + } + sql = _apply_unary_agg_ops(bf_df, list(ops_map.values()), list(ops_map.keys())) + + snapshot.assert_match(sql, "out.sql") + + +def test_any_w_window(scalar_types_df: bpd.DataFrame, snapshot): col_name = "bool_col" bf_df = scalar_types_df[[col_name]] agg_expr = agg_ops.AnyOp().as_expr(col_name) - sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name]) - - snapshot.assert_match(sql, "out.sql") # Window tests window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),)) sql_window = _apply_unary_window_op(bf_df, agg_expr, window, "agg_bool") - snapshot.assert_match(sql_window, "window_out.sql") + snapshot.assert_match(sql_window, "out.sql") def test_approx_quartiles(scalar_types_df: bpd.DataFrame, snapshot): @@ -247,6 +253,17 @@ def test_diff_w_datetime(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_diff_w_date(scalar_types_df: bpd.DataFrame, snapshot): + col_name 
= "date_col" + bf_df_date = scalar_types_df[[col_name]] + window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),)) + op = agg_exprs.UnaryAggregation( + agg_ops.DiffOp(periods=1), expression.deref(col_name) + ) + sql = _apply_unary_window_op(bf_df_date, op, window, "diff_date") + snapshot.assert_match(sql, "out.sql") + + def test_diff_w_timestamp(scalar_types_df: bpd.DataFrame, snapshot): col_name = "timestamp_col" bf_df_timestamp = scalar_types_df[[col_name]] @@ -474,12 +491,12 @@ def test_qcut(scalar_types_df: bpd.DataFrame, snapshot): def test_quantile(scalar_types_df: bpd.DataFrame, snapshot): - col_name = "int64_col" - bf_df = scalar_types_df[[col_name]] + bf_df = scalar_types_df[["int64_col", "bool_col"]] agg_ops_map = { - "quantile": agg_ops.QuantileOp(q=0.5).as_expr(col_name), - "quantile_floor": agg_ops.QuantileOp(q=0.5, should_floor_result=True).as_expr( - col_name + "int64": agg_ops.QuantileOp(q=0.5).as_expr("int64_col"), + "bool": agg_ops.QuantileOp(q=0.5).as_expr("bool_col"), + "int64_w_floor": agg_ops.QuantileOp(q=0.5, should_floor_result=True).as_expr( + "int64_col" ), } sql = _apply_unary_agg_ops(