58 changes: 39 additions & 19 deletions bigframes/core/compile/sqlglot/aggregations/unary_compiler.py
@@ -23,6 +23,7 @@
from bigframes.core import window_spec
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
from bigframes.core.compile.sqlglot.aggregations.windows import apply_window_if_present
from bigframes.core.compile.sqlglot.expressions import constants
import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
from bigframes.operations import aggregations as agg_ops
@@ -44,9 +45,13 @@ def _(
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    # BQ will return null for empty column, result would be false in pandas.
    result = apply_window_if_present(sge.func("LOGICAL_AND", column.expr), window)
    return sge.func("IFNULL", result, sge.true())
    expr = column.expr
    if column.dtype != dtypes.BOOL_DTYPE:
        expr = sge.NEQ(this=expr, expression=sge.convert(0))
    expr = apply_window_if_present(sge.func("LOGICAL_AND", expr), window)

    # BQ will return null for empty column, result would be true in pandas.
    return sge.func("COALESCE", expr, sge.convert(True))


@UNARY_OP_REGISTRATION.register(agg_ops.AnyOp)
@@ -56,6 +61,8 @@ def _(
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    expr = column.expr
    if column.dtype != dtypes.BOOL_DTYPE:
        expr = sge.NEQ(this=expr, expression=sge.convert(0))
    expr = apply_window_if_present(sge.func("LOGICAL_OR", expr), window)

    # BQ will return null for empty column, result would be false in pandas.
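Note: the COALESCE fallbacks above mirror pandas' reduction identities: `all()` over an empty or all-NULL selection is vacuously True and `any()` is False, whereas BigQuery's LOGICAL_AND/LOGICAL_OR return NULL for empty input. The `<> 0` comparison coerces numeric columns to booleans the way pandas truthiness does. A minimal pandas sketch of the semantics being matched (illustrative only, not part of the change):

```python
import pandas as pd

# Empty reductions: pandas returns the identity element, not NULL/NaN.
empty = pd.Series([], dtype="boolean")
print(empty.all())  # True  -> hence COALESCE(LOGICAL_AND(...), TRUE)
print(empty.any())  # False -> hence COALESCE(LOGICAL_OR(...), FALSE)

# Numeric columns reduce by truthiness, i.e. x != 0.
nums = pd.Series([2, 0, 5], dtype="Int64")
print(nums.all())  # False (contains a zero)
print(nums.any())  # True
```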
@@ -326,6 +333,15 @@ def _(
        unit=sge.Identifier(this="MICROSECOND"),
    )

    if column.dtype == dtypes.DATE_DTYPE:
        date_diff = sge.DateDiff(
            this=column.expr, expression=shifted, unit=sge.Identifier(this="DAY")
        )
        return sge.Cast(
            this=sge.Floor(this=date_diff * constants._DAY_TO_MICROSECONDS),
            to="INT64",
        )

    raise TypeError(f"Cannot perform diff on type {column.dtype}")
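The DATE branch reuses the microsecond convention of the DATETIME/TIMESTAMP branches: DATE_DIFF yields whole days, which are scaled to microseconds so diff results share one unit across temporal types. A small sketch of the arithmetic, using pandas to show the target semantics (illustrative only, not part of the change):

```python
import pandas as pd

# One day in microseconds, matching the new constants._DAY_TO_MICROSECONDS.
MICROS_PER_DAY = 24 * 60 * 60 * 1_000_000
assert MICROS_PER_DAY == 86_400_000_000

# pandas diff() on dates produces timedeltas; two calendar days apart
# corresponds to 2 * MICROS_PER_DAY microseconds.
dates = pd.Series(pd.to_datetime(["2024-01-01", "2024-01-03", "2024-01-04"]))
micros = dates.diff().dt.total_seconds() * 1_000_000
print(micros.tolist())  # [nan, 172800000000.0, 86400000000.0]
```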


@@ -410,24 +426,28 @@ def _(
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    expr = column.expr
    if column.dtype == dtypes.BOOL_DTYPE:
        expr = sge.Cast(this=expr, to="INT64")

    # Need to short-circuit as log with zeroes is illegal sql
    is_zero = sge.EQ(this=column.expr, expression=sge.convert(0))
    is_zero = sge.EQ(this=expr, expression=sge.convert(0))

    # There is no product sql aggregate function, so must implement as a sum of logs, and then
    # apply power after. Note, log and power base must be equal! This impl uses base-2 logs.
    logs = (
        sge.Case()
        .when(is_zero, sge.convert(0))
        .else_(sge.func("LN", sge.func("ABS", column.expr)))
    logs = sge.If(
        this=is_zero,
        true=sge.convert(0),
        false=sge.func("LOG", sge.convert(2), sge.func("ABS", expr)),
    )
    logs_sum = apply_window_if_present(sge.func("SUM", logs), window)
    magnitude = sge.func("EXP", logs_sum)
    magnitude = sge.func("POWER", sge.convert(2), logs_sum)

    # Can't determine sign from logs, so have to determine parity of count of negative inputs
    is_negative = (
        sge.Case()
        .when(
            sge.LT(this=sge.func("SIGN", column.expr), expression=sge.convert(0)),
            sge.EQ(this=sge.func("SIGN", expr), expression=sge.convert(-1)),
            sge.convert(1),
        )
        .else_(sge.convert(0))
@@ -445,11 +465,7 @@
        .else_(
            sge.Mul(
                this=magnitude,
                expression=sge.If(
                    this=sge.EQ(this=negative_count_parity, expression=sge.convert(1)),
                    true=sge.convert(-1),
                    false=sge.convert(1),
                ),
                expression=sge.func("POWER", sge.convert(-1), negative_count_parity),
            )
        )
    )
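With the log and power base both fixed at 2, the magnitude/sign decomposition reads as: product = (-1)^(count of negatives mod 2) * 2^(Σ log2|x|), with any zero collapsing the result to 0. A plain-Python sketch of the same trick (illustrative only, not BigFrames API):

```python
import math

def product_via_logs(xs):
    """Mirror of the generated SQL: a sum of base-2 logs plus a sign taken
    from the parity of negative inputs; zeros short-circuit because LOG(0)
    is an error in SQL."""
    if any(x == 0 for x in xs):
        return 0.0
    log_sum = sum(math.log2(abs(x)) for x in xs)  # SUM(LOG(ABS(x), 2))
    magnitude = 2.0 ** log_sum                    # POWER(2, logs_sum)
    negatives = sum(1 for x in xs if x < 0)
    return (-1) ** (negatives % 2) * magnitude    # POWER(-1, MOD(n, 2))

print(product_via_logs([3, -4, 2]))  # ~ -24.0 (floating point, not exact)
```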
@@ -499,14 +515,18 @@ def _(
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    # TODO: Support interpolation argument
    # TODO: Support percentile_disc
    result: sge.Expression = sge.func("PERCENTILE_CONT", column.expr, sge.convert(op.q))
    expr = column.expr
    if column.dtype == dtypes.BOOL_DTYPE:
        expr = sge.Cast(this=expr, to="INT64")

    result: sge.Expression = sge.func("PERCENTILE_CONT", expr, sge.convert(op.q))
    if window is None:
        # PERCENTILE_CONT is a navigation function, not an aggregate function, so it always needs an OVER clause.
        # PERCENTILE_CONT is a navigation function, not an aggregate function,
        # so it always needs an OVER clause.
        result = sge.Window(this=result)
    else:
        result = apply_window_if_present(result, window)

    if op.should_floor_result:
        result = sge.Cast(this=sge.func("FLOOR", result), to="INT64")
    return result
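Two details worth noting here: PERCENTILE_CONT in BigQuery is an analytic (navigation) function, so even an unwindowed quantile must be wrapped in `OVER ()`; and the INT64 cast lets boolean columns take the same path, matching pandas, which quantiles booleans on their 0/1 values. A short sketch of the pandas behavior being targeted (illustrative only):

```python
import pandas as pd

# pandas computes quantiles of booleans as 0/1 values; the compiled SQL
# reproduces this with PERCENTILE_CONT(CAST(`bool_col` AS INT64), q).
bools = pd.Series([True, False, True, True]).astype("int64")
print(bools.quantile(0.5))  # 1.0 (median of [0, 1, 1, 1])
```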
1 change: 1 addition & 0 deletions bigframes/core/compile/sqlglot/expressions/constants.py
@@ -20,6 +20,7 @@
_NAN = sge.Cast(this=sge.convert("NaN"), to="FLOAT64")
_INF = sge.Cast(this=sge.convert("Infinity"), to="FLOAT64")
_NEG_INF = sge.Cast(this=sge.convert("-Infinity"), to="FLOAT64")
_DAY_TO_MICROSECONDS = sge.convert(86400000000)

# Approx highest number you can pass in to the EXP function and get a valid FLOAT64 result
# FLOAT64 has 11 exponent bits, so the max value is about 2**(2**10)
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
WITH `bfcte_0` AS (
SELECT
`bool_col`
`bool_col`,
`int64_col`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
COALESCE(LOGICAL_AND(`bool_col`), TRUE) AS `bfcol_1`
COALESCE(LOGICAL_AND(`bool_col`), TRUE) AS `bfcol_2`,
COALESCE(LOGICAL_AND(`int64_col` <> 0), TRUE) AS `bfcol_3`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `bool_col`
`bfcol_2` AS `bool_col`,
`bfcol_3` AS `int64_col`
FROM `bfcte_1`

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
WITH `bfcte_0` AS (
SELECT
`bool_col`
`bool_col`,
`int64_col`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
COALESCE(LOGICAL_OR(`bool_col`), FALSE) AS `bfcol_1`
COALESCE(LOGICAL_OR(`bool_col`), FALSE) AS `bfcol_2`,
COALESCE(LOGICAL_OR(`int64_col` <> 0), FALSE) AS `bfcol_3`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `bool_col`
`bfcol_2` AS `bool_col`,
`bfcol_3` AS `int64_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
WITH `bfcte_0` AS (
SELECT
`date_col`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
CAST(FLOOR(
DATE_DIFF(`date_col`, LAG(`date_col`, 1) OVER (ORDER BY `date_col` ASC NULLS LAST), DAY) * 86400000000
) AS INT64) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `diff_date`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@ WITH `bfcte_0` AS (
CASE
WHEN LOGICAL_OR(`int64_col` = 0)
THEN 0
ELSE EXP(SUM(CASE WHEN `int64_col` = 0 THEN 0 ELSE LN(ABS(`int64_col`)) END)) * IF(MOD(SUM(CASE WHEN SIGN(`int64_col`) < 0 THEN 1 ELSE 0 END), 2) = 1, -1, 1)
ELSE POWER(2, SUM(IF(`int64_col` = 0, 0, LOG(ABS(`int64_col`), 2)))) * POWER(-1, MOD(SUM(CASE WHEN SIGN(`int64_col`) = -1 THEN 1 ELSE 0 END), 2))
END AS `bfcol_1`
FROM `bfcte_0`
)
Original file line number Diff line number Diff line change
@@ -9,15 +9,15 @@ WITH `bfcte_0` AS (
CASE
WHEN LOGICAL_OR(`int64_col` = 0) OVER (PARTITION BY `string_col`)
THEN 0
ELSE EXP(
SUM(CASE WHEN `int64_col` = 0 THEN 0 ELSE LN(ABS(`int64_col`)) END) OVER (PARTITION BY `string_col`)
) * IF(
ELSE POWER(
2,
SUM(IF(`int64_col` = 0, 0, LOG(ABS(`int64_col`), 2))) OVER (PARTITION BY `string_col`)
) * POWER(
-1,
MOD(
SUM(CASE WHEN SIGN(`int64_col`) < 0 THEN 1 ELSE 0 END) OVER (PARTITION BY `string_col`),
SUM(CASE WHEN SIGN(`int64_col`) = -1 THEN 1 ELSE 0 END) OVER (PARTITION BY `string_col`),
2
) = 1,
-1,
1
)
)
END AS `bfcol_2`
FROM `bfcte_0`
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
WITH `bfcte_0` AS (
SELECT
`bool_col`,
`int64_col`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
PERCENTILE_CONT(`int64_col`, 0.5) OVER () AS `bfcol_1`,
CAST(FLOOR(PERCENTILE_CONT(`int64_col`, 0.5) OVER ()) AS INT64) AS `bfcol_2`
PERCENTILE_CONT(`int64_col`, 0.5) OVER () AS `bfcol_4`,
PERCENTILE_CONT(CAST(`bool_col` AS INT64), 0.5) OVER () AS `bfcol_5`,
CAST(FLOOR(PERCENTILE_CONT(`int64_col`, 0.5) OVER ()) AS INT64) AS `bfcol_6`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `quantile`,
`bfcol_2` AS `quantile_floor`
`bfcol_4` AS `int64`,
`bfcol_5` AS `bool`,
`bfcol_6` AS `int64_w_floor`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -63,41 +63,47 @@ def _apply_unary_window_op(


def test_all(scalar_types_df: bpd.DataFrame, snapshot):
    bf_df = scalar_types_df[["bool_col", "int64_col"]]
    ops_map = {
        "bool_col": agg_ops.AllOp().as_expr("bool_col"),
        "int64_col": agg_ops.AllOp().as_expr("int64_col"),
    }
    sql = _apply_unary_agg_ops(bf_df, list(ops_map.values()), list(ops_map.keys()))

    snapshot.assert_match(sql, "out.sql")


def test_all_w_window(scalar_types_df: bpd.DataFrame, snapshot):
    col_name = "bool_col"
    bf_df = scalar_types_df[[col_name]]
    agg_expr = agg_ops.AllOp().as_expr(col_name)
    sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name])

    snapshot.assert_match(sql, "out.sql")

    # Window tests
    window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
    sql_window = _apply_unary_window_op(bf_df, agg_expr, window, "agg_bool")
    snapshot.assert_match(sql_window, "window_out.sql")

    bf_df_str = scalar_types_df[[col_name, "string_col"]]
    window_partition = window_spec.WindowSpec(
        grouping_keys=(expression.deref("string_col"),),
        ordering=(ordering.descending_over(col_name),),
    )
    sql_window_partition = _apply_unary_window_op(
        bf_df_str, agg_expr, window_partition, "agg_bool"
    )
    snapshot.assert_match(sql_window_partition, "window_partition_out.sql")
    snapshot.assert_match(sql_window, "out.sql")


def test_any(scalar_types_df: bpd.DataFrame, snapshot):
    bf_df = scalar_types_df[["bool_col", "int64_col"]]
    ops_map = {
        "bool_col": agg_ops.AnyOp().as_expr("bool_col"),
        "int64_col": agg_ops.AnyOp().as_expr("int64_col"),
    }
    sql = _apply_unary_agg_ops(bf_df, list(ops_map.values()), list(ops_map.keys()))

    snapshot.assert_match(sql, "out.sql")


def test_any_w_window(scalar_types_df: bpd.DataFrame, snapshot):
    col_name = "bool_col"
    bf_df = scalar_types_df[[col_name]]
    agg_expr = agg_ops.AnyOp().as_expr(col_name)
    sql = _apply_unary_agg_ops(bf_df, [agg_expr], [col_name])

    snapshot.assert_match(sql, "out.sql")

    # Window tests
    window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
    sql_window = _apply_unary_window_op(bf_df, agg_expr, window, "agg_bool")
    snapshot.assert_match(sql_window, "window_out.sql")
    snapshot.assert_match(sql_window, "out.sql")


def test_approx_quartiles(scalar_types_df: bpd.DataFrame, snapshot):
@@ -247,6 +253,17 @@ def test_diff_w_datetime(scalar_types_df: bpd.DataFrame, snapshot):
    snapshot.assert_match(sql, "out.sql")


def test_diff_w_date(scalar_types_df: bpd.DataFrame, snapshot):
    col_name = "date_col"
    bf_df_date = scalar_types_df[[col_name]]
    window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
    op = agg_exprs.UnaryAggregation(
        agg_ops.DiffOp(periods=1), expression.deref(col_name)
    )
    sql = _apply_unary_window_op(bf_df_date, op, window, "diff_date")
    snapshot.assert_match(sql, "out.sql")


def test_diff_w_timestamp(scalar_types_df: bpd.DataFrame, snapshot):
col_name = "timestamp_col"
bf_df_timestamp = scalar_types_df[[col_name]]
@@ -474,12 +491,12 @@ def test_qcut(scalar_types_df: bpd.DataFrame, snapshot):


def test_quantile(scalar_types_df: bpd.DataFrame, snapshot):
    col_name = "int64_col"
    bf_df = scalar_types_df[[col_name]]
    bf_df = scalar_types_df[["int64_col", "bool_col"]]
    agg_ops_map = {
        "quantile": agg_ops.QuantileOp(q=0.5).as_expr(col_name),
        "quantile_floor": agg_ops.QuantileOp(q=0.5, should_floor_result=True).as_expr(
            col_name
        "int64": agg_ops.QuantileOp(q=0.5).as_expr("int64_col"),
        "bool": agg_ops.QuantileOp(q=0.5).as_expr("bool_col"),
        "int64_w_floor": agg_ops.QuantileOp(q=0.5, should_floor_result=True).as_expr(
            "int64_col"
        ),
    }
    sql = _apply_unary_agg_ops(