From e808db896ad8ddcb81f9cba0402d715efcfd5e2c Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 24 Apr 2026 08:16:08 -0400 Subject: [PATCH 1/5] tpch examples: add reference SQL to each query, fix Q20 - Append the canonical TPC-H reference SQL (from benchmarks/tpch/queries/) to each q01..q22 module docstring so readers can compare the DataFrame translation against the SQL at a glance. - Fix Q20: the `df.filter(col("ps_availqty") > lit(0.5) * col("total_sold"))` call discarded its result because the `df = ` assignment was missing, so the filter was dropped from the pipeline. Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/tpch/q01_pricing_summary_report.py | 24 ++++++++++ examples/tpch/q02_minimum_cost_supplier.py | 46 +++++++++++++++++++ examples/tpch/q03_shipping_priority.py | 25 ++++++++++ examples/tpch/q04_order_priority_checking.py | 24 ++++++++++ examples/tpch/q05_local_supplier_volume.py | 27 +++++++++++ .../tpch/q06_forecasting_revenue_change.py | 12 +++++ examples/tpch/q07_volume_shipping.py | 42 +++++++++++++++++ examples/tpch/q08_market_share.py | 40 ++++++++++++++++ .../tpch/q09_product_type_profit_measure.py | 35 ++++++++++++++ examples/tpch/q10_returned_item_reporting.py | 34 ++++++++++++++ .../q11_important_stock_identification.py | 30 ++++++++++++ examples/tpch/q12_ship_mode_order_priority.py | 31 +++++++++++++ examples/tpch/q13_customer_distribution.py | 23 ++++++++++ examples/tpch/q14_promotion_effect.py | 16 +++++++ examples/tpch/q15_top_supplier.py | 34 ++++++++++++++ .../tpch/q16_part_supplier_relationship.py | 33 +++++++++++++ examples/tpch/q17_small_quantity_order.py | 20 ++++++++ examples/tpch/q18_large_volume_customer.py | 35 ++++++++++++++ examples/tpch/q19_discounted_revenue.py | 38 +++++++++++++++ examples/tpch/q20_potential_part_promotion.py | 42 ++++++++++++++++- .../tpch/q21_suppliers_kept_orders_waiting.py | 42 +++++++++++++++++ examples/tpch/q22_global_sales_opportunity.py | 40 ++++++++++++++++ 22 files changed, 692 insertions(+), 1 deletion(-) diff --git 
a/examples/tpch/q01_pricing_summary_report.py b/examples/tpch/q01_pricing_summary_report.py index 3f97f00dc..3f963dbe0 100644 --- a/examples/tpch/q01_pricing_summary_report.py +++ b/examples/tpch/q01_pricing_summary_report.py @@ -27,6 +27,30 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + l_returnflag, + l_linestatus, + sum(l_quantity) as sum_qty, + sum(l_extendedprice) as sum_base_price, + sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + avg(l_quantity) as avg_qty, + avg(l_extendedprice) as avg_price, + avg(l_discount) as avg_disc, + count(*) as count_order + from + lineitem + where + l_shipdate <= date '1998-12-01' - interval '68 days' + group by + l_returnflag, + l_linestatus + order by + l_returnflag, + l_linestatus; """ import pyarrow as pa diff --git a/examples/tpch/q02_minimum_cost_supplier.py b/examples/tpch/q02_minimum_cost_supplier.py index 47961d2ef..303a02a24 100644 --- a/examples/tpch/q02_minimum_cost_supplier.py +++ b/examples/tpch/q02_minimum_cost_supplier.py @@ -27,6 +27,52 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + s_acctbal, + s_name, + n_name, + p_partkey, + p_mfgr, + s_address, + s_phone, + s_comment + from + part, + supplier, + partsupp, + nation, + region + where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and p_size = 48 + and p_type like '%TIN' + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'ASIA' + and ps_supplycost = ( + select + min(ps_supplycost) + from + partsupp, + supplier, + nation, + region + where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'ASIA' + ) + order by + s_acctbal desc, + n_name, + s_name, + p_partkey limit 100; """ import datafusion diff --git a/examples/tpch/q03_shipping_priority.py b/examples/tpch/q03_shipping_priority.py index fc1231e0a..6dc1da42f 100644 --- a/examples/tpch/q03_shipping_priority.py +++ b/examples/tpch/q03_shipping_priority.py @@ -25,6 +25,31 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + l_orderkey, + sum(l_extendedprice * (1 - l_discount)) as revenue, + o_orderdate, + o_shippriority + from + customer, + orders, + lineitem + where + c_mktsegment = 'BUILDING' + and c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate < date '1995-03-15' + and l_shipdate > date '1995-03-15' + group by + l_orderkey, + o_orderdate, + o_shippriority + order by + revenue desc, + o_orderdate limit 10; """ from datafusion import SessionContext, col, lit diff --git a/examples/tpch/q04_order_priority_checking.py b/examples/tpch/q04_order_priority_checking.py index 426338aea..d40564565 100644 --- a/examples/tpch/q04_order_priority_checking.py +++ b/examples/tpch/q04_order_priority_checking.py @@ -24,6 +24,30 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + o_orderpriority, + count(*) as order_count + from + orders + where + o_orderdate >= date '1995-04-01' + and o_orderdate < date '1995-04-01' + interval '3' month + and exists ( + select + * + from + lineitem + where + l_orderkey = o_orderkey + and l_commitdate < l_receiptdate + ) + group by + o_orderpriority + order by + o_orderpriority; """ from datetime import datetime diff --git a/examples/tpch/q05_local_supplier_volume.py b/examples/tpch/q05_local_supplier_volume.py index fa2b01dea..227d5264d 100644 --- a/examples/tpch/q05_local_supplier_volume.py +++ b/examples/tpch/q05_local_supplier_volume.py @@ -27,6 +27,33 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + n_name, + sum(l_extendedprice * (1 - l_discount)) as revenue + from + customer, + orders, + lineitem, + supplier, + nation, + region + where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and l_suppkey = s_suppkey + and c_nationkey = s_nationkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'AFRICA' + and o_orderdate >= date '1994-01-01' + and o_orderdate < date '1994-01-01' + interval '1' year + group by + n_name + order by + revenue desc; """ from datetime import datetime diff --git a/examples/tpch/q06_forecasting_revenue_change.py b/examples/tpch/q06_forecasting_revenue_change.py index 1de5848b1..2ac2f3401 100644 --- a/examples/tpch/q06_forecasting_revenue_change.py +++ b/examples/tpch/q06_forecasting_revenue_change.py @@ -27,6 +27,18 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + sum(l_extendedprice * l_discount) as revenue + from + lineitem + where + l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + and l_discount between 0.04 - 0.01 and 0.04 + 0.01 + and l_quantity < 24; """ from datetime import datetime diff --git a/examples/tpch/q07_volume_shipping.py b/examples/tpch/q07_volume_shipping.py index ff2f891f1..e6b6a5b5e 100644 --- a/examples/tpch/q07_volume_shipping.py +++ b/examples/tpch/q07_volume_shipping.py @@ -26,6 +26,48 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + supp_nation, + cust_nation, + l_year, + sum(volume) as revenue + from + ( + select + n1.n_name as supp_nation, + n2.n_name as cust_nation, + extract(year from l_shipdate) as l_year, + l_extendedprice * (1 - l_discount) as volume + from + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2 + where + s_suppkey = l_suppkey + and o_orderkey = l_orderkey + and c_custkey = o_custkey + and s_nationkey = n1.n_nationkey + and c_nationkey = n2.n_nationkey + and ( + (n1.n_name = 'GERMANY' and n2.n_name = 'IRAQ') + or (n1.n_name = 'IRAQ' and n2.n_name = 'GERMANY') + ) + and l_shipdate between date '1995-01-01' and date '1996-12-31' + ) as shipping + group by + supp_nation, + cust_nation, + l_year + order by + supp_nation, + cust_nation, + l_year; """ from datetime import datetime diff --git a/examples/tpch/q08_market_share.py b/examples/tpch/q08_market_share.py index 4bf50efba..0869a8090 100644 --- a/examples/tpch/q08_market_share.py +++ b/examples/tpch/q08_market_share.py @@ -25,6 +25,46 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + o_year, + sum(case + when nation = 'IRAQ' then volume + else 0 + end) / sum(volume) as mkt_share + from + ( + select + extract(year from o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) as volume, + n2.n_name as nation + from + part, + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2, + region + where + p_partkey = l_partkey + and s_suppkey = l_suppkey + and l_orderkey = o_orderkey + and o_custkey = c_custkey + and c_nationkey = n1.n_nationkey + and n1.n_regionkey = r_regionkey + and r_name = 'MIDDLE EAST' + and s_nationkey = n2.n_nationkey + and o_orderdate between date '1995-01-01' and date '1996-12-31' + and p_type = 'LARGE PLATED STEEL' + ) as all_nations + group by + o_year + order by + o_year; """ from datetime import datetime diff --git a/examples/tpch/q09_product_type_profit_measure.py b/examples/tpch/q09_product_type_profit_measure.py index e2abbd095..7264f78ce 100644 --- a/examples/tpch/q09_product_type_profit_measure.py +++ b/examples/tpch/q09_product_type_profit_measure.py @@ -27,6 +27,41 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + nation, + o_year, + sum(amount) as sum_profit + from + ( + select + n_name as nation, + extract(year from o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount + from + part, + supplier, + lineitem, + partsupp, + orders, + nation + where + s_suppkey = l_suppkey + and ps_suppkey = l_suppkey + and ps_partkey = l_partkey + and p_partkey = l_partkey + and o_orderkey = l_orderkey + and s_nationkey = n_nationkey + and p_name like '%moccasin%' + ) as profit + group by + nation, + o_year + order by + nation, + o_year desc; """ import pyarrow as pa diff --git a/examples/tpch/q10_returned_item_reporting.py b/examples/tpch/q10_returned_item_reporting.py index ed822e264..ca25f6a88 100644 --- a/examples/tpch/q10_returned_item_reporting.py +++ b/examples/tpch/q10_returned_item_reporting.py @@ -27,6 +27,40 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + c_custkey, + c_name, + sum(l_extendedprice * (1 - l_discount)) as revenue, + c_acctbal, + n_name, + c_address, + c_phone, + c_comment + from + customer, + orders, + lineitem, + nation + where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate >= date '1993-07-01' + and o_orderdate < date '1993-07-01' + interval '3' month + and l_returnflag = 'R' + and c_nationkey = n_nationkey + group by + c_custkey, + c_name, + c_acctbal, + c_phone, + n_name, + c_address, + c_comment + order by + revenue desc limit 20; """ from datetime import datetime diff --git a/examples/tpch/q11_important_stock_identification.py b/examples/tpch/q11_important_stock_identification.py index de309fa64..8b67091b2 100644 --- a/examples/tpch/q11_important_stock_identification.py +++ b/examples/tpch/q11_important_stock_identification.py @@ -25,6 +25,36 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + ps_partkey, + sum(ps_supplycost * ps_availqty) as value + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'ALGERIA' + group by + ps_partkey having + sum(ps_supplycost * ps_availqty) > ( + select + sum(ps_supplycost * ps_availqty) * 0.0001000000 + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'ALGERIA' + ) + order by + value desc; """ from datafusion import SessionContext, WindowFrame, col, lit diff --git a/examples/tpch/q12_ship_mode_order_priority.py b/examples/tpch/q12_ship_mode_order_priority.py index 9071597f0..fc6ec8c20 100644 --- a/examples/tpch/q12_ship_mode_order_priority.py +++ b/examples/tpch/q12_ship_mode_order_priority.py @@ -27,6 +27,37 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + l_shipmode, + sum(case + when o_orderpriority = '1-URGENT' + or o_orderpriority = '2-HIGH' + then 1 + else 0 + end) as high_line_count, + sum(case + when o_orderpriority <> '1-URGENT' + and o_orderpriority <> '2-HIGH' + then 1 + else 0 + end) as low_line_count + from + orders, + lineitem + where + o_orderkey = l_orderkey + and l_shipmode in ('FOB', 'SHIP') + and l_commitdate < l_receiptdate + and l_shipdate < l_commitdate + and l_receiptdate >= date '1995-01-01' + and l_receiptdate < date '1995-01-01' + interval '1' year + group by + l_shipmode + order by + l_shipmode; """ from datetime import datetime diff --git a/examples/tpch/q13_customer_distribution.py b/examples/tpch/q13_customer_distribution.py index 93f082ea3..df1f0884f 100644 --- a/examples/tpch/q13_customer_distribution.py +++ b/examples/tpch/q13_customer_distribution.py @@ -26,6 +26,29 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + c_count, + count(*) as custdist + from + ( + select + c_custkey, + count(o_orderkey) + from + customer left outer join orders on + c_custkey = o_custkey + and o_comment not like '%express%requests%' + group by + c_custkey + ) as c_orders (c_custkey, c_count) + group by + c_count + order by + custdist desc, + c_count desc; """ from datafusion import SessionContext, col, lit diff --git a/examples/tpch/q14_promotion_effect.py b/examples/tpch/q14_promotion_effect.py index d62f76e3c..a3be0e9b8 100644 --- a/examples/tpch/q14_promotion_effect.py +++ b/examples/tpch/q14_promotion_effect.py @@ -24,6 +24,22 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + 100.00 * sum(case + when p_type like 'PROMO%' + then l_extendedprice * (1 - l_discount) + else 0 + end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue + from + lineitem, + part + where + l_partkey = p_partkey + and l_shipdate >= date '1995-02-01' + and l_shipdate < date '1995-02-01' + interval '1' month; """ from datetime import datetime diff --git a/examples/tpch/q15_top_supplier.py b/examples/tpch/q15_top_supplier.py index 5128937a7..285bee497 100644 --- a/examples/tpch/q15_top_supplier.py +++ b/examples/tpch/q15_top_supplier.py @@ -24,6 +24,40 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + create view revenue0 (supplier_no, total_revenue) as + select + l_suppkey, + sum(l_extendedprice * (1 - l_discount)) + from + lineitem + where + l_shipdate >= date '1996-08-01' + and l_shipdate < date '1996-08-01' + interval '3' month + group by + l_suppkey; + select + s_suppkey, + s_name, + s_address, + s_phone, + total_revenue + from + supplier, + revenue0 + where + s_suppkey = supplier_no + and total_revenue = ( + select + max(total_revenue) + from + revenue0 + ) + order by + s_suppkey; + drop view revenue0; """ from datetime import datetime diff --git a/examples/tpch/q16_part_supplier_relationship.py b/examples/tpch/q16_part_supplier_relationship.py index 65043ffda..1875242f5 100644 --- a/examples/tpch/q16_part_supplier_relationship.py +++ b/examples/tpch/q16_part_supplier_relationship.py @@ -26,6 +26,39 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + p_brand, + p_type, + p_size, + count(distinct ps_suppkey) as supplier_cnt + from + partsupp, + part + where + p_partkey = ps_partkey + and p_brand <> 'Brand#14' + and p_type not like 'SMALL PLATED%' + and p_size in (14, 6, 5, 31, 49, 15, 41, 47) + and ps_suppkey not in ( + select + s_suppkey + from + supplier + where + s_comment like '%Customer%Complaints%' + ) + group by + p_brand, + p_type, + p_size + order by + supplier_cnt desc, + p_brand, + p_type, + p_size; """ import pyarrow as pa diff --git a/examples/tpch/q17_small_quantity_order.py b/examples/tpch/q17_small_quantity_order.py index 5ccb38422..29a8dfbef 100644 --- a/examples/tpch/q17_small_quantity_order.py +++ b/examples/tpch/q17_small_quantity_order.py @@ -26,6 +26,26 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + sum(l_extendedprice) / 7.0 as avg_yearly + from + lineitem, + part + where + p_partkey = l_partkey + and p_brand = 'Brand#42' + and p_container = 'LG BAG' + and l_quantity < ( + select + 0.2 * avg(l_quantity) + from + lineitem + where + l_partkey = p_partkey + ); """ from datafusion import SessionContext, WindowFrame, col, lit diff --git a/examples/tpch/q18_large_volume_customer.py b/examples/tpch/q18_large_volume_customer.py index 834d181c9..0caf0ebd6 100644 --- a/examples/tpch/q18_large_volume_customer.py +++ b/examples/tpch/q18_large_volume_customer.py @@ -24,6 +24,41 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice, + sum(l_quantity) + from + customer, + orders, + lineitem + where + o_orderkey in ( + select + l_orderkey + from + lineitem + group by + l_orderkey having + sum(l_quantity) > 313 + ) + and c_custkey = o_custkey + and o_orderkey = l_orderkey + group by + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice + order by + o_totalprice desc, + o_orderdate limit 100; """ from datafusion import SessionContext, col, lit diff --git a/examples/tpch/q19_discounted_revenue.py b/examples/tpch/q19_discounted_revenue.py index bd492aac0..cd2349df3 100644 --- a/examples/tpch/q19_discounted_revenue.py +++ b/examples/tpch/q19_discounted_revenue.py @@ -24,6 +24,44 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + sum(l_extendedprice* (1 - l_discount)) as revenue + from + lineitem, + part + where + ( + p_partkey = l_partkey + and p_brand = 'Brand#21' + and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 8 and l_quantity <= 8 + 10 + and p_size between 1 and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#13' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 20 and l_quantity <= 20 + 10 + and p_size between 1 and 10 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#52' + and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + and l_quantity >= 30 and l_quantity <= 30 + 10 + and p_size between 1 and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 
'DELIVER IN PERSON' + ); """ import pyarrow as pa diff --git a/examples/tpch/q20_potential_part_promotion.py b/examples/tpch/q20_potential_part_promotion.py index a25188d31..51a2a2ba0 100644 --- a/examples/tpch/q20_potential_part_promotion.py +++ b/examples/tpch/q20_potential_part_promotion.py @@ -25,6 +25,46 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + s_name, + s_address + from + supplier, + nation + where + s_suppkey in ( + select + ps_suppkey + from + partsupp + where + ps_partkey in ( + select + p_partkey + from + part + where + p_name like 'blanched%' + ) + and ps_availqty > ( + select + 0.5 * sum(l_quantity) + from + lineitem + where + l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date '1993-01-01' + and l_shipdate < date '1993-01-01' + interval '1' year + ) + ) + and s_nationkey = n_nationkey + and n_name = 'KENYA' + order by + s_name; """ from datetime import datetime @@ -87,7 +127,7 @@ ) # Find cases of excess quantity -df.filter(col("ps_availqty") > lit(0.5) * col("total_sold")) +df = df.filter(col("ps_availqty") > lit(0.5) * col("total_sold")) # We could do these joins earlier, but now limit to the nation of interest suppliers df = df.join(df_supplier, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="inner") diff --git a/examples/tpch/q21_suppliers_kept_orders_waiting.py b/examples/tpch/q21_suppliers_kept_orders_waiting.py index 4ee9d3733..7953961a7 100644 --- a/examples/tpch/q21_suppliers_kept_orders_waiting.py +++ b/examples/tpch/q21_suppliers_kept_orders_waiting.py @@ -24,6 +24,48 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + s_name, + count(*) as numwait + from + supplier, + lineitem l1, + orders, + nation + where + s_suppkey = l1.l_suppkey + and o_orderkey = l1.l_orderkey + and o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists ( + select + * + from + lineitem l2 + where + l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey <> l1.l_suppkey + ) + and not exists ( + select + * + from + lineitem l3 + where + l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey <> l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ) + and s_nationkey = n_nationkey + and n_name = 'ARGENTINA' + group by + s_name + order by + numwait desc, + s_name limit 100; """ from datafusion import SessionContext, col, lit diff --git a/examples/tpch/q22_global_sales_opportunity.py b/examples/tpch/q22_global_sales_opportunity.py index a2d41b215..a294fb5d5 100644 --- a/examples/tpch/q22_global_sales_opportunity.py +++ b/examples/tpch/q22_global_sales_opportunity.py @@ -24,6 +24,46 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + cntrycode, + count(*) as numcust, + sum(c_acctbal) as totacctbal + from + ( + select + substring(c_phone from 1 for 2) as cntrycode, + c_acctbal + from + customer + where + substring(c_phone from 1 for 2) in + ('24', '34', '16', '30', '33', '14', '13') + and c_acctbal > ( + select + avg(c_acctbal) + from + customer + where + c_acctbal > 0.00 + and substring(c_phone from 1 for 2) in + ('24', '34', '16', '30', '33', '14', '13') + ) + and not exists ( + select + * + from + orders + where + o_custkey = c_custkey + ) + ) as custsale + group by + cntrycode + order by + cntrycode; """ from datafusion import SessionContext, WindowFrame, col, lit From 91f96cb689b8ba361d1a2b70d2399205ca4bde11 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 24 Apr 2026 08:36:47 -0400 Subject: [PATCH 2/5] tpch examples: rewrite non-idiomatic queries in idiomatic DataFrame form Rewrite the seven TPC-H example queries that did not demonstrate the idiomatic DataFrame pattern. The remaining queries (Q02/Q11/Q15/Q17/Q22, which use window functions in place of correlated subqueries) already are idiomatic and are left unchanged. - Q04: replace `.aggregate([col("l_orderkey")], [])` with `.select("l_orderkey").distinct()`, which is the natural way to express "reduce to one row per order" on a DataFrame. - Q07: remove the CASE-as-filter on `n_name` and use `F.in_list(col("n_name"), [nation_1, nation_2])` instead. Drops a comment block that admitted the filter form was simpler. - Q08: rewrite the switched CASE `F.case(...).when(lit(False), ...)` as a searched `F.when(col(...).is_not_null(), ...).otherwise(...)`. That mirrors the reference SQL's `case when ... then ... else 0 end` shape. - Q12: replace `array_position(make_array(...), col)` with `F.in_list(col("l_shipmode"), [...])`. Same semantics, without routing through array construction / array search. 
- Q19: remove the pyarrow UDF that re-implemented a disjunctive predicate in Python. Build the same predicate in DataFusion by OR-combining one `in_list` + range-filter expression per brand. Keeps the per-brand constants in the existing `items_of_interest` dict. - Q20: use `F.starts_with` instead of an explicit substring slice. Replace the inner-join + `select(...).distinct()` tail with a semi join against a precomputed set of excess-quantity suppliers so the supplier columns are preserved without deduplication after the fact. - Q21: replace the `array_agg` / `array_length` / `array_element` pipeline with two semi joins. One semi join keeps orders with more than one distinct supplier (stand-in for the reference SQL's `exists` subquery), the other keeps orders with exactly one late supplier (stand-in for the `not exists` subquery). All 22 answer-file comparisons and 22 plan-comparison diagnostics still pass (`pytest examples/tpch/_tests.py`: 44 passed). Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/tpch/q04_order_priority_checking.py | 14 +-- examples/tpch/q07_volume_shipping.py | 16 +-- examples/tpch/q08_market_share.py | 9 +- examples/tpch/q12_ship_mode_order_priority.py | 17 +-- examples/tpch/q19_discounted_revenue.py | 80 +++++--------- examples/tpch/q20_potential_part_promotion.py | 66 ++++++------ .../tpch/q21_suppliers_kept_orders_waiting.py | 101 +++++++++--------- 7 files changed, 131 insertions(+), 172 deletions(-) diff --git a/examples/tpch/q04_order_priority_checking.py b/examples/tpch/q04_order_priority_checking.py index d40564565..f2ea1f5c9 100644 --- a/examples/tpch/q04_order_priority_checking.py +++ b/examples/tpch/q04_order_priority_checking.py @@ -77,13 +77,13 @@ interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval()) -# Limit results to cases where commitment date before receipt date -# Aggregate the results so we only get one row to join with the order table. 
-# Alternately, and likely more idiomatic is instead of `.aggregate` you could -# do `.select("l_orderkey").distinct()`. The goal here is to show -# multiple examples of how to use Data Fusion. -df_lineitem = df_lineitem.filter(col("l_commitdate") < col("l_receiptdate")).aggregate( - [col("l_orderkey")], [] +# Limit results to cases where commitment date before receipt date, then +# reduce to a single row per order so the join with the orders table is a +# semantic EXISTS rather than a fan-out. +df_lineitem = ( + df_lineitem.filter(col("l_commitdate") < col("l_receiptdate")) + .select("l_orderkey") + .distinct() ) # Limit orders to date range of interest diff --git a/examples/tpch/q07_volume_shipping.py b/examples/tpch/q07_volume_shipping.py index e6b6a5b5e..0823369a8 100644 --- a/examples/tpch/q07_volume_shipping.py +++ b/examples/tpch/q07_volume_shipping.py @@ -116,20 +116,8 @@ ) -# A simpler way to do the following operation is to use a filter, but we also want to demonstrate -# how to use case statements. Here we are assigning `n_name` to be itself when it is either of -# the two nations of interest. Since there is no `otherwise()` statement, any values that do -# not match these will result in a null value and then get filtered out. -# -# To do the same using a simple filter would be: -# df_nation = df_nation.filter((F.col("n_name") == nation_1) | (F.col("n_name") == nation_2)) # noqa: ERA001 -df_nation = df_nation.with_column( - "n_name", - F.case(col("n_name")) - .when(nation_1, col("n_name")) - .when(nation_2, col("n_name")) - .end(), -).filter(~col("n_name").is_null()) +# Limit the nation table to the two nations of interest. 
+df_nation = df_nation.filter(F.in_list(col("n_name"), [nation_1, nation_2])) # Limit suppliers to either nation diff --git a/examples/tpch/q08_market_share.py b/examples/tpch/q08_market_share.py index 0869a8090..d2e034f41 100644 --- a/examples/tpch/q08_market_share.py +++ b/examples/tpch/q08_market_share.py @@ -186,12 +186,13 @@ df_national_suppliers, left_on=["l_suppkey"], right_on=["s_suppkey"], how="left" ) -# Use a case statement to compute the volume sold by suppliers in the nation of interest +# Use a searched CASE (``F.when(...).otherwise(...)``) to keep only the +# volume attributable to suppliers in the nation of interest. This mirrors +# the ``case when nation = '...' then volume else 0 end`` form of the +# reference SQL rather than dispatching on a boolean subject. df = df.with_column( "national_volume", - F.case(col("s_suppkey").is_null()) - .when(lit(value=False), col("volume")) - .otherwise(lit(0.0)), + F.when(col("s_suppkey").is_not_null(), col("volume")).otherwise(lit(0.0)), ) df = df.with_column( diff --git a/examples/tpch/q12_ship_mode_order_priority.py b/examples/tpch/q12_ship_mode_order_priority.py index fc6ec8c20..b3f4c7034 100644 --- a/examples/tpch/q12_ship_mode_order_priority.py +++ b/examples/tpch/q12_ship_mode_order_priority.py @@ -91,20 +91,9 @@ col("l_receiptdate") < lit(date) + lit(interval) ) -# Note: It is not recommended to use array_has because it treats the second argument as an argument -# so if you pass it col("l_shipmode") it will pass the entire array to process which is very slow. -# Instead check the position of the entry is not null. -df = df.filter( - ~F.array_position( - F.make_array(lit(SHIP_MODE_1), lit(SHIP_MODE_2)), col("l_shipmode") - ).is_null() -) - -# Since we have only two values, it's much easier to do this as a filter where the l_shipmode -# matches either of the two values, but we want to show doing some array operations in this -# example. 
If you want to see this done with filters, comment out the above line and uncomment -# this one. -# df = df.filter((col("l_shipmode") == lit(SHIP_MODE_1)) | (col("l_shipmode") == lit(SHIP_MODE_2))) # noqa: ERA001 +# Restrict to the two ship modes of interest. ``in_list`` maps directly to +# the ``l_shipmode in ('FOB', 'SHIP')`` clause of the reference SQL. +df = df.filter(F.in_list(col("l_shipmode"), [lit(SHIP_MODE_1), lit(SHIP_MODE_2)])) # We need order priority, so join order df to line item diff --git a/examples/tpch/q19_discounted_revenue.py b/examples/tpch/q19_discounted_revenue.py index cd2349df3..69e732c9c 100644 --- a/examples/tpch/q19_discounted_revenue.py +++ b/examples/tpch/q19_discounted_revenue.py @@ -64,8 +64,7 @@ ); """ -import pyarrow as pa -from datafusion import SessionContext, col, lit, udf +from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path @@ -114,59 +113,34 @@ df = df.join(df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="inner") -# Create the user defined function (UDF) definition that does the work -def is_of_interest( - brand_arr: pa.Array, - container_arr: pa.Array, - quantity_arr: pa.Array, - size_arr: pa.Array, -) -> pa.Array: - """ - The purpose of this function is to demonstrate how a UDF works, taking as input a pyarrow Array - and generating a resultant Array. The length of the inputs should match and there should be the - same number of rows in the output. 
- """ - result = [] - for idx, brand_val in enumerate(brand_arr): - brand = brand_val.as_py() - if brand in items_of_interest: - values_of_interest = items_of_interest[brand] - - container_matches = ( - container_arr[idx].as_py() in values_of_interest["containers"] - ) - - quantity = quantity_arr[idx].as_py() - quantity_matches = ( - values_of_interest["min_quantity"] - <= quantity - <= values_of_interest["min_quantity"] + 10 - ) - - size = size_arr[idx].as_py() - size_matches = 1 <= size <= values_of_interest["max_size"] - - result.append(container_matches and quantity_matches and size_matches) - else: - result.append(False) - - return pa.array(result) - - -# Turn the above function into a UDF that DataFusion can understand -is_of_interest_udf = udf( - is_of_interest, - [pa.utf8(), pa.utf8(), pa.decimal128(15, 2), pa.int32()], - pa.bool_(), - "stable", -) +# Build one OR-combined predicate per brand. Each disjunct encodes the +# brand-specific container list, quantity window, and size range from the +# reference SQL. This mirrors the SQL ``where (... brand A ...) or (... brand +# B ...) or (... brand C ...)`` form directly, without a UDF. 
+def _brand_predicate( + brand: str, min_quantity: int, containers: list[str], max_size: int +): + return ( + (col("p_brand") == lit(brand)) + & F.in_list(col("p_container"), [lit(c) for c in containers]) + & (col("l_quantity") >= lit(min_quantity)) + & (col("l_quantity") <= lit(min_quantity + 10)) + & (col("p_size") >= lit(1)) + & (col("p_size") <= lit(max_size)) + ) -# Filter results using the above UDF -df = df.filter( - is_of_interest_udf( - col("p_brand"), col("p_container"), col("l_quantity"), col("p_size") + +predicate = None +for brand, params in items_of_interest.items(): + part_predicate = _brand_predicate( + brand, + params["min_quantity"], + params["containers"], + params["max_size"], ) -) + predicate = part_predicate if predicate is None else predicate | part_predicate + +df = df.filter(predicate) df = df.aggregate( [], diff --git a/examples/tpch/q20_potential_part_promotion.py b/examples/tpch/q20_potential_part_promotion.py index 51a2a2ba0..625943b96 100644 --- a/examples/tpch/q20_potential_part_promotion.py +++ b/examples/tpch/q20_potential_part_promotion.py @@ -100,42 +100,46 @@ interval = pa.scalar((0, 365, 0), type=pa.month_day_nano_interval()) -# Filter down dataframes +# Filter down dataframes. ``starts_with`` reads more naturally than an +# explicit substring slice and maps directly to the reference SQL's +# ``p_name like 'forest%'`` clause. df_nation = df_nation.filter(col("n_name") == lit(NATION_OF_INTEREST)) -df_part = df_part.filter( - F.substring(col("p_name"), lit(0), lit(len(COLOR_OF_INTEREST) + 1)) - == lit(COLOR_OF_INTEREST) +df_part = df_part.filter(F.starts_with(col("p_name"), lit(COLOR_OF_INTEREST))) + +# Compute the total quantity of interesting parts shipped by each (part, +# supplier) pair within the year of interest. 
+totals = ( + df_lineitem.filter(col("l_shipdate") >= lit(date)) + .filter(col("l_shipdate") < lit(date) + lit(interval)) + .join(df_part, left_on="l_partkey", right_on="p_partkey", how="inner") + .aggregate( + [col("l_partkey"), col("l_suppkey")], + [F.sum(col("l_quantity")).alias("total_sold")], + ) ) -df = df_lineitem.filter(col("l_shipdate") >= lit(date)).filter( - col("l_shipdate") < lit(date) + lit(interval) +# Keep only (part, supplier) pairs whose available quantity exceeds 50% of +# the total shipped. The result already contains one row per supplier of +# interest, so we can semi-join the supplier table rather than inner-join +# and deduplicate afterwards. +excess_suppliers = ( + df_partsupp.join( + totals, + left_on=["ps_partkey", "ps_suppkey"], + right_on=["l_partkey", "l_suppkey"], + how="inner", + ) + .filter(col("ps_availqty") > lit(0.5) * col("total_sold")) + .select(col("ps_suppkey").alias("suppkey")) + .distinct() ) -# This will filter down the line items to the parts of interest -df = df.join(df_part, left_on="l_partkey", right_on="p_partkey", how="inner") +# Limit to suppliers in the nation of interest and pick out the two +# requested columns. 
+df = df_supplier.join( + df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner" +).join(excess_suppliers, left_on="s_suppkey", right_on="suppkey", how="semi") -# Compute the total sold and limit ourselves to individual supplier/part combinations -df = df.aggregate( - [col("l_partkey"), col("l_suppkey")], [F.sum(col("l_quantity")).alias("total_sold")] -) - -df = df.join( - df_partsupp, - left_on=["l_partkey", "l_suppkey"], - right_on=["ps_partkey", "ps_suppkey"], - how="inner", -) - -# Find cases of excess quantity -df = df.filter(col("ps_availqty") > lit(0.5) * col("total_sold")) - -# We could do these joins earlier, but now limit to the nation of interest suppliers -df = df.join(df_supplier, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="inner") -df = df.join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") - -# Restrict to the requested data per the problem statement -df = df.select("s_name", "s_address").distinct() - -df = df.sort(col("s_name").sort()) +df = df.select("s_name", "s_address").sort(col("s_name").sort()) df.show() diff --git a/examples/tpch/q21_suppliers_kept_orders_waiting.py b/examples/tpch/q21_suppliers_kept_orders_waiting.py index 7953961a7..b2113187b 100644 --- a/examples/tpch/q21_suppliers_kept_orders_waiting.py +++ b/examples/tpch/q21_suppliers_kept_orders_waiting.py @@ -92,65 +92,68 @@ ) # Limit to suppliers in the nation of interest -df_suppliers_of_interest = df_nation.filter(col("n_name") == lit(NATION_OF_INTEREST)) - -df_suppliers_of_interest = df_suppliers_of_interest.join( - df_supplier, left_on="n_nationkey", right_on="s_nationkey", how="inner" +df_suppliers_of_interest = df_nation.filter( + col("n_name") == lit(NATION_OF_INTEREST) +).join(df_supplier, left_on="n_nationkey", right_on="s_nationkey", how="inner") + +# Line items for orders that have status 'F'. This is the candidate set of +# (order, supplier) pairs we reason about below. 
+failed_order_lineitems = df_lineitem.join( + df_orders.filter(col("o_orderstatus") == lit("F")), + left_on="l_orderkey", + right_on="o_orderkey", + how="inner", ) -# Find the failed orders and all their line items -df = df_orders.filter(col("o_orderstatus") == lit("F")) - -df = df_lineitem.join(df, left_on="l_orderkey", right_on="o_orderkey", how="inner") - -# Identify the line items for which the order is failed due to. -df = df.with_column( - "failed_supp", - F.case(col("l_receiptdate") > col("l_commitdate")) - .when(lit(value=True), col("l_suppkey")) - .end(), +# Line items whose receipt was late. This corresponds to ``l1`` in the +# reference SQL. +late_lineitems = failed_order_lineitems.filter( + col("l_receiptdate") > col("l_commitdate") ) -# There are other ways we could do this but the purpose of this example is to work with rows where -# an element is an array of values. In this case, we will create two columns of arrays. One will be -# an array of all of the suppliers who made up this order. That way we can filter the dataframe for -# only orders where this array is larger than one for multiple supplier orders. The second column -# is all of the suppliers who failed to make their commitment. We can filter the second column for -# arrays with size one. That combination will give us orders that had multiple suppliers where only -# one failed. Use distinct=True in the blow aggregation so we don't get multiple line items from the -# same supplier reported in either array. -df = df.aggregate( - [col("o_orderkey")], - [ - F.array_agg(col("l_suppkey"), distinct=True).alias("all_suppliers"), - F.array_agg( - col("failed_supp"), filter=col("failed_supp").is_not_null(), distinct=True - ).alias("failed_suppliers"), - ], +# Orders that had more than one distinct supplier. Expressed as +# ``count(distinct l_suppkey) > 1``. Stands in for the reference SQL's +# ``exists (... l2.l_suppkey <> l1.l_suppkey ...)`` subquery. 
+multi_supplier_orders = ( + failed_order_lineitems.select("l_orderkey", "l_suppkey") + .distinct() + .aggregate([col("l_orderkey")], [F.count(col("l_suppkey")).alias("n_suppliers")]) + .filter(col("n_suppliers") > lit(1)) + .select("l_orderkey") ) -# This is the check described above which will identify single failed supplier in a multiple -# supplier order. -df = df.filter(F.array_length(col("failed_suppliers")) == lit(1)).filter( - F.array_length(col("all_suppliers")) > lit(1) +# Orders where exactly one distinct supplier was late. Stands in for the +# reference SQL's ``not exists (... l3.l_suppkey <> l1.l_suppkey and l3 is +# also late ...)`` subquery: if only one supplier on the order was late, +# nobody else on the same order was late. +single_late_supplier_orders = ( + late_lineitems.select("l_orderkey", "l_suppkey") + .distinct() + .aggregate( + [col("l_orderkey")], [F.count(col("l_suppkey")).alias("n_late_suppliers")] + ) + .filter(col("n_late_suppliers") == lit(1)) + .select("l_orderkey") ) -# Since we have an array we know is exactly one element long, we can extract that single value. -df = df.select( - col("o_orderkey"), F.array_element(col("failed_suppliers"), lit(1)).alias("suppkey") +# Keep late line items whose order qualifies on both counts. Semi joins +# preserve the left-side columns without fanning out on the right. +df = late_lineitems.join(multi_supplier_orders, on="l_orderkey", how="semi").join( + single_late_supplier_orders, on="l_orderkey", how="semi" ) -# Join to the supplier of interest list for the nation of interest -df = df.join( - df_suppliers_of_interest, left_on=["suppkey"], right_on=["s_suppkey"], how="inner" +# Attach the supplier name for suppliers in the nation of interest, count +# one row per qualifying order, and return the top 100. 
+df = ( + df.join( + df_suppliers_of_interest, + left_on="l_suppkey", + right_on="s_suppkey", + how="inner", + ) + .aggregate([col("s_name")], [F.count(col("l_orderkey")).alias("numwait")]) + .sort(col("numwait").sort(ascending=False), col("s_name").sort()) + .limit(100) ) -# Count how many orders that supplier is the only failed supplier for -df = df.aggregate([col("s_name")], [F.count(col("o_orderkey")).alias("numwait")]) - -# Return in descending order -df = df.sort(col("numwait").sort(ascending=False), col("s_name").sort()) - -df = df.limit(100) - df.show() From 1878a465acf9ab59b970f7ca0a2160805a4e69c2 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 24 Apr 2026 09:04:07 -0400 Subject: [PATCH 3/5] tpch examples: align reference SQL constants with DataFrame queries The reference SQL embedded in each q01..q22 module docstring was carried over verbatim from ``benchmarks/tpch/queries/`` and uses a different set of TPC-H substitution parameters than the DataFrame examples (answer-file-validated at scale factor 1). Update each reference SQL to use the substitution parameters the DataFrame uses, so both expressions describe the same query and would produce the same results against the same data. Constants aligned: - Q01: ``90 days`` cutoff (DataFrame ``DAYS_BEFORE_FINAL = 90``). - Q02: ``p_size = 15``, ``p_type like '%BRASS'``, ``r_name = 'EUROPE'``. - Q04: base date ``1993-07-01`` (``3 month`` interval preserved per the "quarter of a year" wording). - Q05: ``r_name = 'ASIA'``. - Q06: ``l_discount between 0.06 - 0.01 and 0.06 + 0.01``. - Q07: nations ``'FRANCE'`` / ``'GERMANY'``. - Q08: ``r_name = 'AMERICA'``, ``p_type = 'ECONOMY ANODIZED STEEL'``, inner-case ``nation = 'BRAZIL'``. - Q09: ``p_name like '%green%'``. - Q10: base date ``1993-10-01`` (``3 month`` interval preserved). - Q11: ``n_name = 'GERMANY'``. - Q12: ship modes ``('MAIL', 'SHIP')``, base date ``1994-01-01``. - Q13: ``o_comment not like '%special%requests%'``. - Q14: base date ``1995-09-01``. 
- Q15: base date ``1996-01-01``. - Q16: ``p_brand <> 'Brand#45'``, ``p_type not like 'MEDIUM POLISHED%'``, sizes ``(49, 14, 23, 45, 19, 3, 36, 9)``. - Q17: ``p_brand = 'Brand#23'``, ``p_container = 'MED BOX'``. - Q18: ``sum(l_quantity) > 300``. - Q19: brands ``Brand#12`` / ``Brand#23`` / ``Brand#34`` with the matching minimum quantities (1, 10, 20). - Q20: ``p_name like 'forest%'``, base date ``1994-01-01``, ``n_name = 'CANADA'``. - Q21: ``n_name = 'SAUDI ARABIA'``. - Q22: country codes ``('13', '31', '23', '29', '30', '18', '17')``. Interval units (month / year) are preserved where the problem-statement text reads "given quarter", "given year", "given month". Q01 keeps the literal "days" unit because the TPC-H problem statement itself describes the cutoff in days. Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/tpch/q01_pricing_summary_report.py | 2 +- examples/tpch/q02_minimum_cost_supplier.py | 8 ++++---- examples/tpch/q04_order_priority_checking.py | 4 ++-- examples/tpch/q05_local_supplier_volume.py | 2 +- examples/tpch/q06_forecasting_revenue_change.py | 2 +- examples/tpch/q07_volume_shipping.py | 4 ++-- examples/tpch/q08_market_share.py | 6 +++--- examples/tpch/q09_product_type_profit_measure.py | 2 +- examples/tpch/q10_returned_item_reporting.py | 4 ++-- examples/tpch/q11_important_stock_identification.py | 4 ++-- examples/tpch/q12_ship_mode_order_priority.py | 6 +++--- examples/tpch/q13_customer_distribution.py | 2 +- examples/tpch/q14_promotion_effect.py | 4 ++-- examples/tpch/q15_top_supplier.py | 4 ++-- examples/tpch/q16_part_supplier_relationship.py | 6 +++--- examples/tpch/q17_small_quantity_order.py | 4 ++-- examples/tpch/q18_large_volume_customer.py | 2 +- examples/tpch/q19_discounted_revenue.py | 12 ++++++------ examples/tpch/q20_potential_part_promotion.py | 8 ++++---- examples/tpch/q21_suppliers_kept_orders_waiting.py | 2 +- examples/tpch/q22_global_sales_opportunity.py | 4 ++-- 21 files changed, 46 insertions(+), 46 deletions(-) diff 
--git a/examples/tpch/q01_pricing_summary_report.py b/examples/tpch/q01_pricing_summary_report.py index 3f963dbe0..1a03b4e1c 100644 --- a/examples/tpch/q01_pricing_summary_report.py +++ b/examples/tpch/q01_pricing_summary_report.py @@ -44,7 +44,7 @@ from lineitem where - l_shipdate <= date '1998-12-01' - interval '68 days' + l_shipdate <= date '1998-12-01' - interval '90 days' group by l_returnflag, l_linestatus diff --git a/examples/tpch/q02_minimum_cost_supplier.py b/examples/tpch/q02_minimum_cost_supplier.py index 303a02a24..cee50e02e 100644 --- a/examples/tpch/q02_minimum_cost_supplier.py +++ b/examples/tpch/q02_minimum_cost_supplier.py @@ -48,11 +48,11 @@ where p_partkey = ps_partkey and s_suppkey = ps_suppkey - and p_size = 48 - and p_type like '%TIN' + and p_size = 15 + and p_type like '%BRASS' and s_nationkey = n_nationkey and n_regionkey = r_regionkey - and r_name = 'ASIA' + and r_name = 'EUROPE' and ps_supplycost = ( select min(ps_supplycost) @@ -66,7 +66,7 @@ and s_suppkey = ps_suppkey and s_nationkey = n_nationkey and n_regionkey = r_regionkey - and r_name = 'ASIA' + and r_name = 'EUROPE' ) order by s_acctbal desc, diff --git a/examples/tpch/q04_order_priority_checking.py b/examples/tpch/q04_order_priority_checking.py index f2ea1f5c9..7e357d054 100644 --- a/examples/tpch/q04_order_priority_checking.py +++ b/examples/tpch/q04_order_priority_checking.py @@ -33,8 +33,8 @@ from orders where - o_orderdate >= date '1995-04-01' - and o_orderdate < date '1995-04-01' + interval '3' month + o_orderdate >= date '1993-07-01' + and o_orderdate < date '1993-07-01' + interval '3' month and exists ( select * diff --git a/examples/tpch/q05_local_supplier_volume.py b/examples/tpch/q05_local_supplier_volume.py index 227d5264d..528291596 100644 --- a/examples/tpch/q05_local_supplier_volume.py +++ b/examples/tpch/q05_local_supplier_volume.py @@ -47,7 +47,7 @@ and c_nationkey = s_nationkey and s_nationkey = n_nationkey and n_regionkey = r_regionkey - and r_name = 'AFRICA' + 
and r_name = 'ASIA' and o_orderdate >= date '1994-01-01' and o_orderdate < date '1994-01-01' + interval '1' year group by diff --git a/examples/tpch/q06_forecasting_revenue_change.py b/examples/tpch/q06_forecasting_revenue_change.py index 2ac2f3401..0ba7d31bd 100644 --- a/examples/tpch/q06_forecasting_revenue_change.py +++ b/examples/tpch/q06_forecasting_revenue_change.py @@ -37,7 +37,7 @@ where l_shipdate >= date '1994-01-01' and l_shipdate < date '1994-01-01' + interval '1' year - and l_discount between 0.04 - 0.01 and 0.04 + 0.01 + and l_discount between 0.06 - 0.01 and 0.06 + 0.01 and l_quantity < 24; """ diff --git a/examples/tpch/q07_volume_shipping.py b/examples/tpch/q07_volume_shipping.py index 0823369a8..bbbd44e0c 100644 --- a/examples/tpch/q07_volume_shipping.py +++ b/examples/tpch/q07_volume_shipping.py @@ -55,8 +55,8 @@ and s_nationkey = n1.n_nationkey and c_nationkey = n2.n_nationkey and ( - (n1.n_name = 'GERMANY' and n2.n_name = 'IRAQ') - or (n1.n_name = 'IRAQ' and n2.n_name = 'GERMANY') + (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') + or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') ) and l_shipdate between date '1995-01-01' and date '1996-12-31' ) as shipping diff --git a/examples/tpch/q08_market_share.py b/examples/tpch/q08_market_share.py index d2e034f41..0a6b2cb00 100644 --- a/examples/tpch/q08_market_share.py +++ b/examples/tpch/q08_market_share.py @@ -31,7 +31,7 @@ select o_year, sum(case - when nation = 'IRAQ' then volume + when nation = 'BRAZIL' then volume else 0 end) / sum(volume) as mkt_share from @@ -56,10 +56,10 @@ and o_custkey = c_custkey and c_nationkey = n1.n_nationkey and n1.n_regionkey = r_regionkey - and r_name = 'MIDDLE EAST' + and r_name = 'AMERICA' and s_nationkey = n2.n_nationkey and o_orderdate between date '1995-01-01' and date '1996-12-31' - and p_type = 'LARGE PLATED STEEL' + and p_type = 'ECONOMY ANODIZED STEEL' ) as all_nations group by o_year diff --git a/examples/tpch/q09_product_type_profit_measure.py 
b/examples/tpch/q09_product_type_profit_measure.py index 7264f78ce..7951a7672 100644 --- a/examples/tpch/q09_product_type_profit_measure.py +++ b/examples/tpch/q09_product_type_profit_measure.py @@ -54,7 +54,7 @@ and p_partkey = l_partkey and o_orderkey = l_orderkey and s_nationkey = n_nationkey - and p_name like '%moccasin%' + and p_name like '%green%' ) as profit group by nation, diff --git a/examples/tpch/q10_returned_item_reporting.py b/examples/tpch/q10_returned_item_reporting.py index ca25f6a88..6888204ce 100644 --- a/examples/tpch/q10_returned_item_reporting.py +++ b/examples/tpch/q10_returned_item_reporting.py @@ -47,8 +47,8 @@ where c_custkey = o_custkey and l_orderkey = o_orderkey - and o_orderdate >= date '1993-07-01' - and o_orderdate < date '1993-07-01' + interval '3' month + and o_orderdate >= date '1993-10-01' + and o_orderdate < date '1993-10-01' + interval '3' month and l_returnflag = 'R' and c_nationkey = n_nationkey group by diff --git a/examples/tpch/q11_important_stock_identification.py b/examples/tpch/q11_important_stock_identification.py index 8b67091b2..3828a50ae 100644 --- a/examples/tpch/q11_important_stock_identification.py +++ b/examples/tpch/q11_important_stock_identification.py @@ -38,7 +38,7 @@ where ps_suppkey = s_suppkey and s_nationkey = n_nationkey - and n_name = 'ALGERIA' + and n_name = 'GERMANY' group by ps_partkey having sum(ps_supplycost * ps_availqty) > ( @@ -51,7 +51,7 @@ where ps_suppkey = s_suppkey and s_nationkey = n_nationkey - and n_name = 'ALGERIA' + and n_name = 'GERMANY' ) order by value desc; diff --git a/examples/tpch/q12_ship_mode_order_priority.py b/examples/tpch/q12_ship_mode_order_priority.py index b3f4c7034..159c41cfa 100644 --- a/examples/tpch/q12_ship_mode_order_priority.py +++ b/examples/tpch/q12_ship_mode_order_priority.py @@ -49,11 +49,11 @@ lineitem where o_orderkey = l_orderkey - and l_shipmode in ('FOB', 'SHIP') + and l_shipmode in ('MAIL', 'SHIP') and l_commitdate < l_receiptdate and l_shipdate < 
l_commitdate - and l_receiptdate >= date '1995-01-01' - and l_receiptdate < date '1995-01-01' + interval '1' year + and l_receiptdate >= date '1994-01-01' + and l_receiptdate < date '1994-01-01' + interval '1' year group by l_shipmode order by diff --git a/examples/tpch/q13_customer_distribution.py b/examples/tpch/q13_customer_distribution.py index df1f0884f..206927727 100644 --- a/examples/tpch/q13_customer_distribution.py +++ b/examples/tpch/q13_customer_distribution.py @@ -40,7 +40,7 @@ from customer left outer join orders on c_custkey = o_custkey - and o_comment not like '%express%requests%' + and o_comment not like '%special%requests%' group by c_custkey ) as c_orders (c_custkey, c_count) diff --git a/examples/tpch/q14_promotion_effect.py b/examples/tpch/q14_promotion_effect.py index a3be0e9b8..67aa5bd1c 100644 --- a/examples/tpch/q14_promotion_effect.py +++ b/examples/tpch/q14_promotion_effect.py @@ -38,8 +38,8 @@ part where l_partkey = p_partkey - and l_shipdate >= date '1995-02-01' - and l_shipdate < date '1995-02-01' + interval '1' month; + and l_shipdate >= date '1995-09-01' + and l_shipdate < date '1995-09-01' + interval '1' month; """ from datetime import datetime diff --git a/examples/tpch/q15_top_supplier.py b/examples/tpch/q15_top_supplier.py index 285bee497..4b1c4c8c8 100644 --- a/examples/tpch/q15_top_supplier.py +++ b/examples/tpch/q15_top_supplier.py @@ -34,8 +34,8 @@ from lineitem where - l_shipdate >= date '1996-08-01' - and l_shipdate < date '1996-08-01' + interval '3' month + l_shipdate >= date '1996-01-01' + and l_shipdate < date '1996-01-01' + interval '3' month group by l_suppkey; select diff --git a/examples/tpch/q16_part_supplier_relationship.py b/examples/tpch/q16_part_supplier_relationship.py index 1875242f5..af6e255fd 100644 --- a/examples/tpch/q16_part_supplier_relationship.py +++ b/examples/tpch/q16_part_supplier_relationship.py @@ -39,9 +39,9 @@ part where p_partkey = ps_partkey - and p_brand <> 'Brand#14' - and p_type not like 
'SMALL PLATED%' - and p_size in (14, 6, 5, 31, 49, 15, 41, 47) + and p_brand <> 'Brand#45' + and p_type not like 'MEDIUM POLISHED%' + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) and ps_suppkey not in ( select s_suppkey diff --git a/examples/tpch/q17_small_quantity_order.py b/examples/tpch/q17_small_quantity_order.py index 29a8dfbef..0f0f575b6 100644 --- a/examples/tpch/q17_small_quantity_order.py +++ b/examples/tpch/q17_small_quantity_order.py @@ -36,8 +36,8 @@ part where p_partkey = l_partkey - and p_brand = 'Brand#42' - and p_container = 'LG BAG' + and p_brand = 'Brand#23' + and p_container = 'MED BOX' and l_quantity < ( select 0.2 * avg(l_quantity) diff --git a/examples/tpch/q18_large_volume_customer.py b/examples/tpch/q18_large_volume_customer.py index 0caf0ebd6..55e3b71f9 100644 --- a/examples/tpch/q18_large_volume_customer.py +++ b/examples/tpch/q18_large_volume_customer.py @@ -46,7 +46,7 @@ lineitem group by l_orderkey having - sum(l_quantity) > 313 + sum(l_quantity) > 300 ) and c_custkey = o_custkey and o_orderkey = l_orderkey diff --git a/examples/tpch/q19_discounted_revenue.py b/examples/tpch/q19_discounted_revenue.py index 69e732c9c..5f87e6aa2 100644 --- a/examples/tpch/q19_discounted_revenue.py +++ b/examples/tpch/q19_discounted_revenue.py @@ -35,9 +35,9 @@ where ( p_partkey = l_partkey - and p_brand = 'Brand#21' + and p_brand = 'Brand#12' and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') - and l_quantity >= 8 and l_quantity <= 8 + 10 + and l_quantity >= 1 and l_quantity <= 1 + 10 and p_size between 1 and 5 and l_shipmode in ('AIR', 'AIR REG') and l_shipinstruct = 'DELIVER IN PERSON' @@ -45,9 +45,9 @@ or ( p_partkey = l_partkey - and p_brand = 'Brand#13' + and p_brand = 'Brand#23' and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') - and l_quantity >= 20 and l_quantity <= 20 + 10 + and l_quantity >= 10 and l_quantity <= 10 + 10 and p_size between 1 and 10 and l_shipmode in ('AIR', 'AIR REG') and l_shipinstruct = 'DELIVER IN 
PERSON' @@ -55,9 +55,9 @@ or ( p_partkey = l_partkey - and p_brand = 'Brand#52' + and p_brand = 'Brand#34' and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') - and l_quantity >= 30 and l_quantity <= 30 + 10 + and l_quantity >= 20 and l_quantity <= 20 + 10 and p_size between 1 and 15 and l_shipmode in ('AIR', 'AIR REG') and l_shipinstruct = 'DELIVER IN PERSON' diff --git a/examples/tpch/q20_potential_part_promotion.py b/examples/tpch/q20_potential_part_promotion.py index 625943b96..47a60fe79 100644 --- a/examples/tpch/q20_potential_part_promotion.py +++ b/examples/tpch/q20_potential_part_promotion.py @@ -47,7 +47,7 @@ from part where - p_name like 'blanched%' + p_name like 'forest%' ) and ps_availqty > ( select @@ -57,12 +57,12 @@ where l_partkey = ps_partkey and l_suppkey = ps_suppkey - and l_shipdate >= date '1993-01-01' - and l_shipdate < date '1993-01-01' + interval '1' year + and l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year ) ) and s_nationkey = n_nationkey - and n_name = 'KENYA' + and n_name = 'CANADA' order by s_name; """ diff --git a/examples/tpch/q21_suppliers_kept_orders_waiting.py b/examples/tpch/q21_suppliers_kept_orders_waiting.py index b2113187b..7ec5c3069 100644 --- a/examples/tpch/q21_suppliers_kept_orders_waiting.py +++ b/examples/tpch/q21_suppliers_kept_orders_waiting.py @@ -60,7 +60,7 @@ and l3.l_receiptdate > l3.l_commitdate ) and s_nationkey = n_nationkey - and n_name = 'ARGENTINA' + and n_name = 'SAUDI ARABIA' group by s_name order by diff --git a/examples/tpch/q22_global_sales_opportunity.py b/examples/tpch/q22_global_sales_opportunity.py index a294fb5d5..5f463ab45 100644 --- a/examples/tpch/q22_global_sales_opportunity.py +++ b/examples/tpch/q22_global_sales_opportunity.py @@ -40,7 +40,7 @@ customer where substring(c_phone from 1 for 2) in - ('24', '34', '16', '30', '33', '14', '13') + ('13', '31', '23', '29', '30', '18', '17') and c_acctbal > ( select avg(c_acctbal) @@ -49,7 +49,7 @@ 
where c_acctbal > 0.00 and substring(c_phone from 1 for 2) in - ('24', '34', '16', '30', '33', '14', '13') + ('13', '31', '23', '29', '30', '18', '17') ) and not exists ( select From a0c0fb9ecf07c50984d1c1db1a5247d6df6dd1e4 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 24 Apr 2026 09:21:39 -0400 Subject: [PATCH 4/5] tpch examples: apply SKILL.md idioms across all 22 queries MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Sweep every q01..q22 example for idiomatic DataFrame style as described in the repo-root SKILL.md: - ``col("x") == "s"`` in place of ``col("x") == lit("s")`` on comparison right-hand sides (auto-wrap applies). - Plain-name strings in ``select``/``aggregate``/``sort`` group/sort key lists when the key is a bare column. - Drop redundant ``how="inner"`` and single-element ``left_on``/``right_on`` list wrapping on equi-joins. - Collapse chained ``.filter(a).filter(b)`` runs into ``.filter(a, b)`` and chained ``.with_column`` runs into ``.with_columns(a=..., b=...)``. - ``df.sort_by(...)`` or plain-name ``df.sort(...)`` when no null-placement override is needed. - ``F.count_star()`` in place of ``F.count(col("x"))`` whenever the SQL reads ``count(*)``. - ``F.starts_with(col, lit(prefix))`` and ``~F.starts_with(...)`` in place of substring-prefix equality/inequality tricks. - ``F.in_list(col, [lit(...)])`` in place of ``~F.array_position(...).is_null()`` and in place of disjunctions of equality comparisons. - Searched ``F.when(cond, x).otherwise(y)`` in place of switched ``F.case(bool_expr).when(lit(True/False), x).end()`` forms. - Semi-joins as the DataFrame form of ``EXISTS`` (Q04); anti-joins as ``NOT EXISTS`` (Q22 was already using this idiom). - Whole-frame window aggregates as the DataFrame stand-in for a SQL scalar subquery (Q11/Q15/Q17/Q22).
Individual query fixes of note: - Q16 — add the secondary sort keys (``p_brand``, ``p_type``, ``p_size``) that the TPC-H spec requires but the original DataFrame omitted. - Q22 — drop a stray ``df.show()`` mid-pipeline; replace the 0-based substring slice with ``F.left(col("c_phone"), lit(2))``. - Q14 — rewrite the promo/non-promo factor split as a searched CASE inside ``F.sum(...)`` so the DataFrame expression matches the reference SQL shape exactly. All 22 answer-file comparisons still pass at scale factor 1. Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/tpch/q01_pricing_summary_report.py | 20 ++-- examples/tpch/q02_minimum_cost_supplier.py | 35 +++--- examples/tpch/q03_shipping_priority.py | 26 ++--- examples/tpch/q04_order_priority_checking.py | 30 ++--- examples/tpch/q05_local_supplier_volume.py | 28 ++--- .../tpch/q06_forecasting_revenue_change.py | 11 +- examples/tpch/q07_volume_shipping.py | 36 +++--- examples/tpch/q08_market_share.py | 105 ++++++------------ .../tpch/q09_product_type_profit_measure.py | 39 +++---- examples/tpch/q10_returned_item_reporting.py | 58 +++++----- .../q11_important_stock_identification.py | 53 ++++----- examples/tpch/q12_ship_mode_order_priority.py | 60 ++++------ examples/tpch/q13_customer_distribution.py | 24 ++-- examples/tpch/q14_promotion_effect.py | 41 +++---- examples/tpch/q15_top_supplier.py | 45 ++++---- .../tpch/q16_part_supplier_relationship.py | 48 ++++---- examples/tpch/q17_small_quantity_order.py | 40 +++---- examples/tpch/q18_large_volume_customer.py | 36 +++--- examples/tpch/q19_discounted_revenue.py | 28 ++--- examples/tpch/q20_potential_part_promotion.py | 24 ++-- .../tpch/q21_suppliers_kept_orders_waiting.py | 45 +++----- examples/tpch/q22_global_sales_opportunity.py | 64 +++++------ 22 files changed, 374 insertions(+), 522 deletions(-) diff --git a/examples/tpch/q01_pricing_summary_report.py b/examples/tpch/q01_pricing_summary_report.py index 1a03b4e1c..105f1632d 100644 --- 
a/examples/tpch/q01_pricing_summary_report.py +++ b/examples/tpch/q01_pricing_summary_report.py @@ -82,31 +82,25 @@ # Aggregate the results +disc_price = col("l_extendedprice") * (lit(1) - col("l_discount")) + df = df.aggregate( - [col("l_returnflag"), col("l_linestatus")], + ["l_returnflag", "l_linestatus"], [ F.sum(col("l_quantity")).alias("sum_qty"), F.sum(col("l_extendedprice")).alias("sum_base_price"), - F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias( - "sum_disc_price" - ), - F.sum( - col("l_extendedprice") - * (lit(1) - col("l_discount")) - * (lit(1) + col("l_tax")) - ).alias("sum_charge"), + F.sum(disc_price).alias("sum_disc_price"), + F.sum(disc_price * (lit(1) + col("l_tax"))).alias("sum_charge"), F.avg(col("l_quantity")).alias("avg_qty"), F.avg(col("l_extendedprice")).alias("avg_price"), F.avg(col("l_discount")).alias("avg_disc"), - F.count(col("l_returnflag")).alias( - "count_order" - ), # Counting any column should return same result + F.count_star().alias("count_order"), ], ) # Sort per the expected result -df = df.sort(col("l_returnflag").sort(), col("l_linestatus").sort()) +df = df.sort_by("l_returnflag", "l_linestatus") # Note: There appears to be a discrepancy between what is returned here and what is in the generated # answers file for the case of return flag N and line status O, but I did not investigate further. diff --git a/examples/tpch/q02_minimum_cost_supplier.py b/examples/tpch/q02_minimum_cost_supplier.py index cee50e02e..6c26d262e 100644 --- a/examples/tpch/q02_minimum_cost_supplier.py +++ b/examples/tpch/q02_minimum_cost_supplier.py @@ -118,30 +118,25 @@ # in the string where it is located. 
df_part = df_part.filter( - F.strpos(col("p_type"), lit(TYPE_OF_INTEREST)) > lit(0) -).filter(col("p_size") == lit(SIZE_OF_INTEREST)) + F.strpos(col("p_type"), lit(TYPE_OF_INTEREST)) > 0, + col("p_size") == SIZE_OF_INTEREST, +) # Filter regions down to the one of interest -df_region = df_region.filter(col("r_name") == lit(REGION_OF_INTEREST)) +df_region = df_region.filter(col("r_name") == REGION_OF_INTEREST) # Now that we have the region, find suppliers in that region. Suppliers are tied to their nation # and nations are tied to the region. -df_nation = df_nation.join( - df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner" -) -df_supplier = df_supplier.join( - df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner" -) +df_nation = df_nation.join(df_region, left_on="n_regionkey", right_on="r_regionkey") +df_supplier = df_supplier.join(df_nation, left_on="s_nationkey", right_on="n_nationkey") # Now that we know who the potential suppliers are for the part, we can limit out part # supplies table down. We can further join down to the specific parts we've identified # as matching the request -df = df_partsupp.join( - df_supplier, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="inner" -) +df = df_partsupp.join(df_supplier, left_on="ps_suppkey", right_on="s_suppkey") # Locate the minimum cost across all suppliers. 
There are multiple ways you could do this, # but one way is to create a window function across all suppliers, find the minimum, and @@ -158,9 +153,9 @@ ), ) -df = df.filter(col("min_cost") == col("ps_supplycost")) - -df = df.join(df_part, left_on=["ps_partkey"], right_on=["p_partkey"], how="inner") +df = df.filter(col("min_cost") == col("ps_supplycost")).join( + df_part, left_on="ps_partkey", right_on="p_partkey" +) # From the problem statement, these are the values we wish to output @@ -178,12 +173,10 @@ # Sort and display 100 entries df = df.sort( col("s_acctbal").sort(ascending=False), - col("n_name").sort(), - col("s_name").sort(), - col("p_partkey").sort(), -) - -df = df.limit(100) + "n_name", + "s_name", + "p_partkey", +).limit(100) # Show results diff --git a/examples/tpch/q03_shipping_priority.py b/examples/tpch/q03_shipping_priority.py index 6dc1da42f..880c7435f 100644 --- a/examples/tpch/q03_shipping_priority.py +++ b/examples/tpch/q03_shipping_priority.py @@ -75,20 +75,20 @@ # Limit dataframes to the rows of interest -df_customer = df_customer.filter(col("c_mktsegment") == lit(SEGMENT_OF_INTEREST)) +df_customer = df_customer.filter(col("c_mktsegment") == SEGMENT_OF_INTEREST) df_orders = df_orders.filter(col("o_orderdate") < lit(DATE_OF_INTEREST)) df_lineitem = df_lineitem.filter(col("l_shipdate") > lit(DATE_OF_INTEREST)) # Join all 3 dataframes -df = df_customer.join( - df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" -).join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") +df = df_customer.join(df_orders, left_on="c_custkey", right_on="o_custkey").join( + df_lineitem, left_on="o_orderkey", right_on="l_orderkey" +) # Compute the revenue df = df.aggregate( - [col("l_orderkey")], + ["l_orderkey"], [ F.first_value(col("o_orderdate")).alias("o_orderdate"), F.first_value(col("o_shippriority")).alias("o_shippriority"), @@ -96,17 +96,13 @@ ], ) -# Sort by priority - -df = 
df.sort(col("revenue").sort(ascending=False), col("o_orderdate").sort()) - -# Only return 10 results +# Sort by priority, take 10, and project in the order expected by the spec. -df = df.limit(10) - -# Change the order that the columns are reported in just to match the spec - -df = df.select("l_orderkey", "revenue", "o_orderdate", "o_shippriority") +df = ( + df.sort(col("revenue").sort(ascending=False), "o_orderdate") + .limit(10) + .select("l_orderkey", "revenue", "o_orderdate", "o_shippriority") +) # Show result diff --git a/examples/tpch/q04_order_priority_checking.py b/examples/tpch/q04_order_priority_checking.py index 7e357d054..18cbb2054 100644 --- a/examples/tpch/q04_order_priority_checking.py +++ b/examples/tpch/q04_order_priority_checking.py @@ -77,31 +77,23 @@ interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval()) -# Limit results to cases where commitment date before receipt date, then -# reduce to a single row per order so the join with the orders table is a -# semantic EXISTS rather than a fan-out. -df_lineitem = ( - df_lineitem.filter(col("l_commitdate") < col("l_receiptdate")) - .select("l_orderkey") - .distinct() +# Keep only orders in the quarter of interest, then restrict to those that +# have at least one late lineitem via a semi join (the DataFrame form of +# ``EXISTS`` from the reference SQL). 
+df_orders = df_orders.filter( + col("o_orderdate") >= lit(date), + col("o_orderdate") < lit(date) + lit(interval), ) -# Limit orders to date range of interest -df_orders = df_orders.filter(col("o_orderdate") >= lit(date)).filter( - col("o_orderdate") < lit(date) + lit(interval) -) +late_lineitems = df_lineitem.filter(col("l_commitdate") < col("l_receiptdate")) -# Perform the join to find only orders for which there are lineitems outside of expected range df = df_orders.join( - df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner" + late_lineitems, left_on="o_orderkey", right_on="l_orderkey", how="semi" ) -# Based on priority, find the number of entries -df = df.aggregate( - [col("o_orderpriority")], [F.count(col("o_orderpriority")).alias("order_count")] +# Count the number of orders in each priority group and sort. +df = df.aggregate(["o_orderpriority"], [F.count_star().alias("order_count")]).sort_by( + "o_orderpriority" ) -# Sort the results -df = df.sort(col("o_orderpriority").sort()) - df.show() diff --git a/examples/tpch/q05_local_supplier_volume.py b/examples/tpch/q05_local_supplier_volume.py index 528291596..5e648f272 100644 --- a/examples/tpch/q05_local_supplier_volume.py +++ b/examples/tpch/q05_local_supplier_volume.py @@ -95,38 +95,32 @@ ) # Restrict dataframes to cases of interest -df_orders = df_orders.filter(col("o_orderdate") >= lit(date)).filter( - col("o_orderdate") < lit(date) + lit(interval) +df_orders = df_orders.filter( + col("o_orderdate") >= lit(date), + col("o_orderdate") < lit(date) + lit(interval), ) -df_region = df_region.filter(col("r_name") == lit(REGION_OF_INTEREST)) +df_region = df_region.filter(col("r_name") == REGION_OF_INTEREST) # Join all the dataframes df = ( - df_customer.join( - df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" - ) - .join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") + df_customer.join(df_orders, left_on="c_custkey", right_on="o_custkey") + 
.join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey") .join( df_supplier, left_on=["l_suppkey", "c_nationkey"], right_on=["s_suppkey", "s_nationkey"], - how="inner", ) - .join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") - .join(df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner") + .join(df_nation, left_on="s_nationkey", right_on="n_nationkey") + .join(df_region, left_on="n_regionkey", right_on="r_regionkey") ) -# Compute the final result +# Compute the final result, then sort in descending order. df = df.aggregate( - [col("n_name")], + ["n_name"], [F.sum(col("l_extendedprice") * (lit(1.0) - col("l_discount"))).alias("revenue")], -) - -# Sort in descending order - -df = df.sort(col("revenue").sort(ascending=False)) +).sort(col("revenue").sort(ascending=False)) df.show() diff --git a/examples/tpch/q06_forecasting_revenue_change.py b/examples/tpch/q06_forecasting_revenue_change.py index 0ba7d31bd..79697f833 100644 --- a/examples/tpch/q06_forecasting_revenue_change.py +++ b/examples/tpch/q06_forecasting_revenue_change.py @@ -71,12 +71,11 @@ # Filter down to lineitems of interest -df = ( - df_lineitem.filter(col("l_shipdate") >= lit(date)) - .filter(col("l_shipdate") < lit(date) + lit(interval)) - .filter(col("l_discount") >= lit(DISCOUT) - lit(DELTA)) - .filter(col("l_discount") <= lit(DISCOUT) + lit(DELTA)) - .filter(col("l_quantity") < lit(QUANTITY)) +df = df_lineitem.filter( + col("l_shipdate") >= lit(date), + col("l_shipdate") < lit(date) + lit(interval), + col("l_discount").between(lit(DISCOUT - DELTA), lit(DISCOUT + DELTA)), + col("l_quantity") < QUANTITY, ) # Add up all the "lost" revenue diff --git a/examples/tpch/q07_volume_shipping.py b/examples/tpch/q07_volume_shipping.py index bbbd44e0c..6584509cc 100644 --- a/examples/tpch/q07_volume_shipping.py +++ b/examples/tpch/q07_volume_shipping.py @@ -111,8 +111,8 @@ # Filter to time of interest -df_lineitem = df_lineitem.filter(col("l_shipdate") >= 
start_date).filter( - col("l_shipdate") <= end_date +df_lineitem = df_lineitem.filter( + col("l_shipdate") >= start_date, col("l_shipdate") <= end_date ) @@ -122,37 +122,33 @@ # Limit suppliers to either nation df_supplier = df_supplier.join( - df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner" -).select(col("s_suppkey"), col("n_name").alias("supp_nation")) + df_nation, left_on="s_nationkey", right_on="n_nationkey" +).select("s_suppkey", col("n_name").alias("supp_nation")) # Limit customers to either nation df_customer = df_customer.join( - df_nation, left_on=["c_nationkey"], right_on=["n_nationkey"], how="inner" -).select(col("c_custkey"), col("n_name").alias("cust_nation")) + df_nation, left_on="c_nationkey", right_on="n_nationkey" +).select("c_custkey", col("n_name").alias("cust_nation")) # Join up all the data frames from line items, and make sure the supplier and customer are in # different nations. df = ( - df_lineitem.join( - df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner" - ) - .join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") - .join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") + df_lineitem.join(df_orders, left_on="l_orderkey", right_on="o_orderkey") + .join(df_customer, left_on="o_custkey", right_on="c_custkey") + .join(df_supplier, left_on="l_suppkey", right_on="s_suppkey") .filter(col("cust_nation") != col("supp_nation")) ) # Extract out two values for every line item -df = df.with_column( - "l_year", F.datepart(lit("year"), col("l_shipdate")).cast(pa.int32()) -).with_column("volume", col("l_extendedprice") * (lit(1.0) - col("l_discount"))) +df = df.with_columns( + l_year=F.datepart(lit("year"), col("l_shipdate")).cast(pa.int32()), + volume=col("l_extendedprice") * (lit(1.0) - col("l_discount")), +) -# Aggregate the results +# Aggregate and sort per the spec. 
df = df.aggregate( - [col("supp_nation"), col("cust_nation"), col("l_year")], + ["supp_nation", "cust_nation", "l_year"], [F.sum(col("volume")).alias("revenue")], -) - -# Sort based on problem statement requirements -df = df.sort(col("supp_nation").sort(), col("cust_nation").sort(), col("l_year").sort()) +).sort_by("supp_nation", "cust_nation", "l_year") df.show() diff --git a/examples/tpch/q08_market_share.py b/examples/tpch/q08_market_share.py index 0a6b2cb00..fbf0fc8e5 100644 --- a/examples/tpch/q08_market_share.py +++ b/examples/tpch/q08_market_share.py @@ -114,8 +114,8 @@ # Limit orders to those in the specified range -df_orders = df_orders.filter(col("o_orderdate") >= start_date).filter( - col("o_orderdate") <= end_date +df_orders = df_orders.filter( + col("o_orderdate") >= start_date, col("o_orderdate") <= end_date ) # Part 1: Find customers in the region @@ -127,36 +127,14 @@ # First we find all the sales that make up the basis. -df_regional_customers = df_region.filter(col("r_name") == customer_region) - -# After this join we have all of the possible sales nations -df_regional_customers = df_regional_customers.join( - df_nation, left_on=["r_regionkey"], right_on=["n_regionkey"], how="inner" -) - -# Now find the possible customers -df_regional_customers = df_regional_customers.join( - df_customer, left_on=["n_nationkey"], right_on=["c_nationkey"], how="inner" -) - -# Next find orders for these customers -df_regional_customers = df_regional_customers.join( - df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" -) - -# Find all line items from these orders -df_regional_customers = df_regional_customers.join( - df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner" -) - -# Limit to the part of interest -df_regional_customers = df_regional_customers.join( - df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="inner" -) - -# Compute the volume for each line item -df_regional_customers = 
df_regional_customers.with_column( - "volume", col("l_extendedprice") * (lit(1.0) - col("l_discount")) +df_regional_customers = ( + df_region.filter(col("r_name") == customer_region) + .join(df_nation, left_on="r_regionkey", right_on="n_regionkey") + .join(df_customer, left_on="n_nationkey", right_on="c_nationkey") + .join(df_orders, left_on="c_custkey", right_on="o_custkey") + .join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey") + .join(df_part, left_on="l_partkey", right_on="p_partkey") + .with_column("volume", col("l_extendedprice") * (lit(1.0) - col("l_discount"))) ) # Part 2: Find suppliers from the nation @@ -164,56 +142,41 @@ # Now that we have all of the sales of that part in the specified region, we need # to determine which of those came from suppliers in the nation we are interested in. -df_national_suppliers = df_nation.filter(col("n_name") == supplier_nation) - -# Determine the suppliers by the limited nation key we have in our single row df above -df_national_suppliers = df_national_suppliers.join( - df_supplier, left_on=["n_nationkey"], right_on=["s_nationkey"], how="inner" +df_national_suppliers = ( + df_nation.filter(col("n_name") == supplier_nation) + .join(df_supplier, left_on="n_nationkey", right_on="s_nationkey") + .select("s_suppkey") ) -# When we join to the customer dataframe, we don't want to confuse other columns, so only -# select the supplier key that we need -df_national_suppliers = df_national_suppliers.select("s_suppkey") - # Part 3: Combine suppliers and customers and compute the market share -# Now we can do a left outer join on the suppkey. Those line items from other suppliers -# will get a null value. We can check for the existence of this null to compute a volume -# column only from suppliers in the nation we are evaluating. +# Left-outer join the national suppliers onto the regional sales. Rows from +# other suppliers get a NULL ``s_suppkey``, which the CASE expression uses +# to zero out the non-national volume. 
df = df_regional_customers.join( - df_national_suppliers, left_on=["l_suppkey"], right_on=["s_suppkey"], how="left" -) - -# Use a searched CASE (``F.when(...).otherwise(...)``) to keep only the -# volume attributable to suppliers in the nation of interest. This mirrors -# the ``case when nation = '...' then volume else 0 end`` form of the -# reference SQL rather than dispatching on a boolean subject. -df = df.with_column( - "national_volume", - F.when(col("s_suppkey").is_not_null(), col("volume")).otherwise(lit(0.0)), -) - -df = df.with_column( - "o_year", F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()) + df_national_suppliers, left_on="l_suppkey", right_on="s_suppkey", how="left" +).with_columns( + national_volume=F.when(col("s_suppkey").is_not_null(), col("volume")).otherwise( + lit(0.0) + ), + o_year=F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()), ) -# Lastly, sum up the results +# Aggregate, compute the share, and sort. -df = df.aggregate( - [col("o_year")], - [ - F.sum(col("volume")).alias("volume"), - F.sum(col("national_volume")).alias("national_volume"), - ], +df = ( + df.aggregate( + ["o_year"], + [ + F.sum(col("volume")).alias("volume"), + F.sum(col("national_volume")).alias("national_volume"), + ], + ) + .select("o_year", (col("national_volume") / col("volume")).alias("mkt_share")) + .sort_by("o_year") ) -df = df.select( - col("o_year"), (F.col("national_volume") / F.col("volume")).alias("mkt_share") -) - -df = df.sort(col("o_year").sort()) - df.show() diff --git a/examples/tpch/q09_product_type_profit_measure.py b/examples/tpch/q09_product_type_profit_measure.py index 7951a7672..139c0483f 100644 --- a/examples/tpch/q09_product_type_profit_measure.py +++ b/examples/tpch/q09_product_type_profit_measure.py @@ -97,37 +97,34 @@ "n_nationkey", "n_name", "n_regionkey" ) -# Limit possible parts to the color specified -df = df_part.filter(F.strpos(col("p_name"), part_color) > lit(0)) - -# We have a series of joins that get us to limit 
down to the line items we need -df = df.join(df_lineitem, left_on=["p_partkey"], right_on=["l_partkey"], how="inner") -df = df.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") -df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") -df = df.join( - df_partsupp, - left_on=["l_suppkey", "l_partkey"], - right_on=["ps_suppkey", "ps_partkey"], - how="inner", +# Limit possible parts to the color specified, then walk the joins down to the +# line-item rows we need and attach the supplier's nation. +df = ( + df_part.filter(F.strpos(col("p_name"), part_color) > 0) + .join(df_lineitem, left_on="p_partkey", right_on="l_partkey") + .join(df_supplier, left_on="l_suppkey", right_on="s_suppkey") + .join(df_orders, left_on="l_orderkey", right_on="o_orderkey") + .join( + df_partsupp, + left_on=["l_suppkey", "l_partkey"], + right_on=["ps_suppkey", "ps_partkey"], + ) + .join(df_nation, left_on="s_nationkey", right_on="n_nationkey") ) -df = df.join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") # Compute the intermediate values and limit down to the expressions we need df = df.select( col("n_name").alias("nation"), F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()).alias("o_year"), ( - (col("l_extendedprice") * (lit(1) - col("l_discount"))) - - (col("ps_supplycost") * col("l_quantity")) + col("l_extendedprice") * (lit(1) - col("l_discount")) + - col("ps_supplycost") * col("l_quantity") ).alias("amount"), ) -# Sum up the values by nation and year -df = df.aggregate( - [col("nation"), col("o_year")], [F.sum(col("amount")).alias("profit")] +# Sum up the values by nation and year, then sort per the spec. 
+df = df.aggregate(["nation", "o_year"], [F.sum(col("amount")).alias("profit")]).sort( + "nation", col("o_year").sort(ascending=False) ) -# Sort according to the problem specification -df = df.sort(col("nation").sort(), col("o_year").sort(ascending=False)) - df.show() diff --git a/examples/tpch/q10_returned_item_reporting.py b/examples/tpch/q10_returned_item_reporting.py index 6888204ce..3e4e8a4d9 100644 --- a/examples/tpch/q10_returned_item_reporting.py +++ b/examples/tpch/q10_returned_item_reporting.py @@ -100,44 +100,40 @@ ) # limit to returns -df_lineitem = df_lineitem.filter(col("l_returnflag") == lit("R")) +df_lineitem = df_lineitem.filter(col("l_returnflag") == "R") # Rather than aggregate by all of the customer fields as you might do looking at the specification, # we can aggregate by o_custkey and then join in the customer data at the end. -df = df_orders.filter(col("o_orderdate") >= date_start_of_quarter).filter( - col("o_orderdate") < date_start_of_quarter + interval_one_quarter +df = ( + df_orders.filter( + col("o_orderdate") >= date_start_of_quarter, + col("o_orderdate") < date_start_of_quarter + interval_one_quarter, + ) + .join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey") + .aggregate( + ["o_custkey"], + [F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias("revenue")], + ) ) -df = df.join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") - -# Compute the revenue -df = df.aggregate( - [col("o_custkey")], - [F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias("revenue")], -) - -# Now join in the customer data -df = df.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") -df = df.join(df_nation, left_on=["c_nationkey"], right_on=["n_nationkey"], how="inner") - -# These are the columns the problem statement requires -df = df.select( - "c_custkey", - "c_name", - "revenue", - "c_acctbal", - "n_name", - "c_address", - "c_phone", - "c_comment", +# Now join in the 
customer data, project the spec's output columns, and take the top 20. +df = ( + df.join(df_customer, left_on="o_custkey", right_on="c_custkey") + .join(df_nation, left_on="c_nationkey", right_on="n_nationkey") + .select( + "c_custkey", + "c_name", + "revenue", + "c_acctbal", + "n_name", + "c_address", + "c_phone", + "c_comment", + ) + .sort(col("revenue").sort(ascending=False)) + .limit(20) ) -# Sort the results in descending order -df = df.sort(col("revenue").sort(ascending=False)) - -# Only return the top 20 results -df = df.limit(20) - df.show() diff --git a/examples/tpch/q11_important_stock_identification.py b/examples/tpch/q11_important_stock_identification.py index 3828a50ae..1f40bbdad 100644 --- a/examples/tpch/q11_important_stock_identification.py +++ b/examples/tpch/q11_important_stock_identification.py @@ -79,39 +79,30 @@ "n_nationkey", "n_name" ) -# limit to returns -df_nation = df_nation.filter(col("n_name") == lit(NATION)) - -# Find part supplies of within this target nation - -df = df_nation.join( - df_supplier, left_on=["n_nationkey"], right_on=["s_nationkey"], how="inner" +# Restrict to the target nation, then walk to partsupp rows via the supplier +# join. Aggregate the per-part inventory value. 
+df = ( + df_nation.filter(col("n_name") == NATION) + .join(df_supplier, left_on="n_nationkey", right_on="s_nationkey") + .join(df_partsupp, left_on="s_suppkey", right_on="ps_suppkey") + .with_column("value", col("ps_supplycost") * col("ps_availqty")) + .aggregate(["ps_partkey"], [F.sum(col("value")).alias("value")]) ) -df = df.join(df_partsupp, left_on=["s_suppkey"], right_on=["ps_suppkey"], how="inner") - - -# Compute the value of individual parts -df = df.with_column("value", col("ps_supplycost") * col("ps_availqty")) - -# Compute total value of specific parts -df = df.aggregate([col("ps_partkey")], [F.sum(col("value")).alias("value")]) - -# By default window functions go from unbounded preceding to current row, but we want -# to compute this sum across all rows -window_frame = WindowFrame("rows", None, None) - -df = df.with_column( - "total_value", F.sum(col("value")).over(Window(window_frame=window_frame)) +# A window function evaluated over the entire output produces a scalar grand +# total that can be referenced row-by-row in the filter — a DataFrame-native +# stand-in for the SQL HAVING ... > (SELECT SUM(...) * FRACTION ...) pattern. +# The default frame is "UNBOUNDED PRECEDING to CURRENT ROW"; override to the +# full partition for the grand total. 
+whole_frame = WindowFrame("rows", None, None) + +df = ( + df.with_column( + "total_value", F.sum(col("value")).over(Window(window_frame=whole_frame)) + ) + .filter(col("value") / col("total_value") >= lit(FRACTION)) + .select("ps_partkey", "value") + .sort(col("value").sort(ascending=False)) ) -# Limit to the parts for which there is a significant value based on the fraction of the total -df = df.filter(col("value") / col("total_value") >= lit(FRACTION)) - -# We only need to report on these two columns -df = df.select("ps_partkey", "value") - -# Sort in descending order of value -df = df.sort(col("value").sort(ascending=False)) - df.show() diff --git a/examples/tpch/q12_ship_mode_order_priority.py b/examples/tpch/q12_ship_mode_order_priority.py index 159c41cfa..c684e1ba5 100644 --- a/examples/tpch/q12_ship_mode_order_priority.py +++ b/examples/tpch/q12_ship_mode_order_priority.py @@ -87,47 +87,29 @@ interval = pa.scalar((0, 365, 0), type=pa.month_day_nano_interval()) -df = df_lineitem.filter(col("l_receiptdate") >= lit(date)).filter( - col("l_receiptdate") < lit(date) + lit(interval) -) - -# Restrict to the two ship modes of interest. ``in_list`` maps directly to -# the ``l_shipmode in ('FOB', 'SHIP')`` clause of the reference SQL. -df = df.filter(F.in_list(col("l_shipmode"), [lit(SHIP_MODE_1), lit(SHIP_MODE_2)])) - - -# We need order priority, so join order df to line item -df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") - -# Restrict to line items we care about based on the problem statement. 
-df = df.filter(col("l_commitdate") < col("l_receiptdate")) - -df = df.filter(col("l_shipdate") < col("l_commitdate")) - -df = df.with_column( - "high_line_value", - F.case(col("o_orderpriority")) - .when(lit("1-URGENT"), lit(1)) - .when(lit("2-HIGH"), lit(1)) - .otherwise(lit(0)), -) - -# Aggregate the results +df = df_lineitem.filter( + col("l_receiptdate") >= lit(date), + col("l_receiptdate") < lit(date) + lit(interval), + # ``in_list`` maps directly to ``l_shipmode in (...)`` from the SQL. + F.in_list(col("l_shipmode"), [lit(SHIP_MODE_1), lit(SHIP_MODE_2)]), + col("l_shipdate") < col("l_commitdate"), + col("l_commitdate") < col("l_receiptdate"), +).join(df_orders, left_on="l_orderkey", right_on="o_orderkey") + +# Flag each line item as belonging to a high-priority order or not. +is_high_priority = F.in_list(col("o_orderpriority"), [lit("1-URGENT"), lit("2-HIGH")]) + +# Count the high-priority and low-priority lineitems per ship mode. df = df.aggregate( - [col("l_shipmode")], + ["l_shipmode"], [ - F.sum(col("high_line_value")).alias("high_line_count"), - F.count(col("high_line_value")).alias("all_lines_count"), + F.sum(F.when(is_high_priority, lit(1)).otherwise(lit(0))).alias( + "high_line_count" + ), + F.sum(F.when(~is_high_priority, lit(1)).otherwise(lit(0))).alias( + "low_line_count" + ), ], -) - -# Compute the final output -df = df.select( - col("l_shipmode"), - col("high_line_count"), - (col("all_lines_count") - col("high_line_count")).alias("low_line_count"), -) - -df = df.sort(col("l_shipmode").sort()) +).sort_by("l_shipmode") df.show() diff --git a/examples/tpch/q13_customer_distribution.py b/examples/tpch/q13_customer_distribution.py index 206927727..37c0b93f6 100644 --- a/examples/tpch/q13_customer_distribution.py +++ b/examples/tpch/q13_customer_distribution.py @@ -72,20 +72,16 @@ F.regexp_match(col("o_comment"), lit(f"{WORD_1}.?*{WORD_2}")).is_null() ) -# Since we may have customers with no orders we must do a left join -df = df_customer.join( - 
df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="left" -) - -# Find the number of orders for each customer -df = df.aggregate([col("c_custkey")], [F.count(col("o_custkey")).alias("c_count")]) - -# Ultimately we want to know the number of customers that have that customer count -df = df.aggregate([col("c_count")], [F.count(col("c_count")).alias("custdist")]) - -# We want to order the results by the highest number of customers per count -df = df.sort( - col("custdist").sort(ascending=False), col("c_count").sort(ascending=False) +# Customers with no orders still participate, so this is a left join. Count the +# orders per customer, then count customers per order-count value. +df = ( + df_customer.join(df_orders, left_on="c_custkey", right_on="o_custkey", how="left") + .aggregate(["c_custkey"], [F.count(col("o_custkey")).alias("c_count")]) + .aggregate(["c_count"], [F.count_star().alias("custdist")]) + .sort( + col("custdist").sort(ascending=False), + col("c_count").sort(ascending=False), + ) ) df.show() diff --git a/examples/tpch/q14_promotion_effect.py b/examples/tpch/q14_promotion_effect.py index 67aa5bd1c..2c020d368 100644 --- a/examples/tpch/q14_promotion_effect.py +++ b/examples/tpch/q14_promotion_effect.py @@ -65,36 +65,31 @@ df_part = ctx.read_parquet(get_data_path("part.parquet")).select("p_partkey", "p_type") -# Check part type begins with PROMO -df_part = df_part.filter( - F.substring(col("p_type"), lit(0), lit(6)) == lit("PROMO") -).with_column("promo_factor", lit(1.0)) - -df_lineitem = df_lineitem.filter(col("l_shipdate") >= date_of_interest).filter( - col("l_shipdate") < date_of_interest + interval_one_month -) - -# Left join so we can sum up the promo parts different from other parts -df = df_lineitem.join( - df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="left" +# Restrict the line items to the month of interest, join the matching part +# rows, and compute revenue per line item. 
+df = ( + df_lineitem.filter( + col("l_shipdate") >= date_of_interest, + col("l_shipdate") < date_of_interest + interval_one_month, + ) + .join(df_part, left_on="l_partkey", right_on="p_partkey") + .with_column("revenue", col("l_extendedprice") * (lit(1.0) - col("l_discount"))) ) -# Make a factor of 1.0 if it is a promotion, 0.0 otherwise -df = df.with_column("promo_factor", F.coalesce(col("promo_factor"), lit(0.0))) -df = df.with_column("revenue", col("l_extendedprice") * (lit(1.0) - col("l_discount"))) - - -# Sum up the promo and total revenue +# Sum promotional and total revenue, then compute the percentage. The +# ``F.when(...)`` form mirrors the ``case when p_type like 'PROMO%' ... else 0`` +# in the reference SQL. df = df.aggregate( [], [ - F.sum(col("promo_factor") * col("revenue")).alias("promo_revenue"), + F.sum( + F.when( + F.starts_with(col("p_type"), lit("PROMO")), col("revenue") + ).otherwise(lit(0.0)) + ).alias("promo_revenue"), F.sum(col("revenue")).alias("total_revenue"), ], -) - -# Return the percentage of revenue from promotions -df = df.select( +).select( (lit(100.0) * col("promo_revenue") / col("total_revenue")).alias("promo_revenue") ) diff --git a/examples/tpch/q15_top_supplier.py b/examples/tpch/q15_top_supplier.py index 4b1c4c8c8..e32ea65d1 100644 --- a/examples/tpch/q15_top_supplier.py +++ b/examples/tpch/q15_top_supplier.py @@ -88,13 +88,12 @@ "s_phone", ) -# Limit line items to the quarter of interest -df_lineitem = df_lineitem.filter(col("l_shipdate") >= date_of_interest).filter( - col("l_shipdate") < date_of_interest + interval_3_months -) - -df = df_lineitem.aggregate( - [col("l_suppkey")], +# Per-supplier revenue over the quarter of interest. 
+per_supplier_revenue = df_lineitem.filter( + col("l_shipdate") >= date_of_interest, + col("l_shipdate") < date_of_interest + interval_3_months, +).aggregate( + ["l_suppkey"], [ F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias( "total_revenue" @@ -102,24 +101,20 @@ ], ) -# Use a window function to find the maximum revenue across the entire dataframe -window_frame = WindowFrame("rows", None, None) -df = df.with_column( - "max_revenue", - F.max(col("total_revenue")).over(Window(window_frame=window_frame)), +# A window ``max`` over the whole frame acts as a grand maximum that can be +# compared row-by-row — the DataFrame stand-in for the reference SQL's +# ``total_revenue = (select max(total_revenue) from revenue0)`` subquery. +whole_frame = WindowFrame("rows", None, None) + +df = ( + per_supplier_revenue.with_column( + "max_revenue", + F.max(col("total_revenue")).over(Window(window_frame=whole_frame)), + ) + .filter(col("total_revenue") == col("max_revenue")) + .join(df_supplier, left_on="l_suppkey", right_on="s_suppkey") + .select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue") + .sort_by("s_suppkey") ) -# Find all suppliers whose total revenue is the same as the maximum -df = df.filter(col("total_revenue") == col("max_revenue")) - -# Now that we know the supplier(s) with maximum revenue, get the rest of their information -# from the supplier table -df = df.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") - -# Return only the columns requested -df = df.select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue") - -# If we have more than one, sort by supplier number (suppkey) -df = df.sort(col("s_suppkey").sort()) - df.show() diff --git a/examples/tpch/q16_part_supplier_relationship.py b/examples/tpch/q16_part_supplier_relationship.py index af6e255fd..755ee91bb 100644 --- a/examples/tpch/q16_part_supplier_relationship.py +++ b/examples/tpch/q16_part_supplier_relationship.py @@ -61,7 +61,6 @@ 
p_size; """ -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path @@ -85,39 +84,36 @@ ) df_unwanted_suppliers = df_supplier.filter( - ~F.regexp_match(col("s_comment"), lit("Customer.?*Complaints")).is_null() + F.regexp_match(col("s_comment"), lit("Customer.?*Complaints")).is_not_null() ) -# Remove unwanted suppliers +# Remove unwanted suppliers via an anti join (DataFrame form of NOT IN). df_partsupp = df_partsupp.join( - df_unwanted_suppliers, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="anti" + df_unwanted_suppliers, left_on="ps_suppkey", right_on="s_suppkey", how="anti" ) -# Select the parts we are interested in -df_part = df_part.filter(col("p_brand") != lit(BRAND)) +# Select the parts we are interested in. df_part = df_part.filter( - F.substring(col("p_type"), lit(0), lit(len(TYPE_TO_IGNORE) + 1)) - != lit(TYPE_TO_IGNORE) + col("p_brand") != BRAND, + ~F.starts_with(col("p_type"), lit(TYPE_TO_IGNORE)), + F.in_list(col("p_size"), [lit(s) for s in SIZES_OF_INTEREST]), ) -# Python conversion of integer to literal casts it to int64 but the data for -# part size is stored as an int32, so perform a cast. Then check to find if the part -# size is within the array of possible sizes by checking the position of it is not -# null. -p_sizes = F.make_array(*[lit(s).cast(pa.int32()) for s in SIZES_OF_INTEREST]) -df_part = df_part.filter(~F.array_position(p_sizes, col("p_size")).is_null()) - -df = df_part.join( - df_partsupp, left_on=["p_partkey"], right_on=["ps_partkey"], how="inner" +# For each (brand, type, size), count the distinct suppliers remaining. 
+df = ( + df_part.join(df_partsupp, left_on="p_partkey", right_on="ps_partkey") + .select("p_brand", "p_type", "p_size", "ps_suppkey") + .distinct() + .aggregate( + ["p_brand", "p_type", "p_size"], + [F.count(col("ps_suppkey")).alias("supplier_cnt")], + ) + .sort( + col("supplier_cnt").sort(ascending=False), + "p_brand", + "p_type", + "p_size", + ) ) -df = df.select("p_brand", "p_type", "p_size", "ps_suppkey").distinct() - -df = df.aggregate( - [col("p_brand"), col("p_type"), col("p_size")], - [F.count(col("ps_suppkey")).alias("supplier_cnt")], -) - -df = df.sort(col("supplier_cnt").sort(ascending=False)) - df.show() diff --git a/examples/tpch/q17_small_quantity_order.py b/examples/tpch/q17_small_quantity_order.py index 0f0f575b6..f2229171f 100644 --- a/examples/tpch/q17_small_quantity_order.py +++ b/examples/tpch/q17_small_quantity_order.py @@ -67,29 +67,23 @@ "l_partkey", "l_quantity", "l_extendedprice" ) -# Limit to the problem statement's brand and container types -df = df_part.filter(col("p_brand") == lit(BRAND)).filter( - col("p_container") == lit(CONTAINER) +# Limit to parts of the target brand/container, join their line items, and +# attach the per-part average quantity via a partitioned window function — +# the DataFrame form of the SQL's correlated ``avg(l_quantity)`` subquery. 
+whole_frame = WindowFrame("rows", None, None) + +df = ( + df_part.filter(col("p_brand") == BRAND, col("p_container") == CONTAINER) + .join(df_lineitem, left_on="p_partkey", right_on="l_partkey") + .with_column( + "avg_quantity", + F.avg(col("l_quantity")).over( + Window(partition_by=[col("l_partkey")], window_frame=whole_frame) + ), + ) + .filter(col("l_quantity") < lit(0.2) * col("avg_quantity")) + .aggregate([], [F.sum(col("l_extendedprice")).alias("total")]) + .select((col("total") / lit(7.0)).alias("avg_yearly")) ) -# Combine data -df = df.join(df_lineitem, left_on=["p_partkey"], right_on=["l_partkey"], how="inner") - -# Find the average quantity -window_frame = WindowFrame("rows", None, None) -df = df.with_column( - "avg_quantity", - F.avg(col("l_quantity")).over( - Window(partition_by=[col("l_partkey")], window_frame=window_frame) - ), -) - -df = df.filter(col("l_quantity") < lit(0.2) * col("avg_quantity")) - -# Compute the total -df = df.aggregate([], [F.sum(col("l_extendedprice")).alias("total")]) - -# Divide by number of years in the problem statement to get average -df = df.select((col("total") / lit(7)).alias("avg_yearly")) - df.show() diff --git a/examples/tpch/q18_large_volume_customer.py b/examples/tpch/q18_large_volume_customer.py index 55e3b71f9..23132d60d 100644 --- a/examples/tpch/q18_large_volume_customer.py +++ b/examples/tpch/q18_large_volume_customer.py @@ -61,7 +61,7 @@ o_orderdate limit 100; """ -from datafusion import SessionContext, col, lit +from datafusion import SessionContext, col from datafusion import functions as F from util import get_data_path @@ -81,22 +81,24 @@ "l_orderkey", "l_quantity", "l_extendedprice" ) -df = df_lineitem.aggregate( - [col("l_orderkey")], [F.sum(col("l_quantity")).alias("total_quantity")] +# Find orders whose total quantity exceeds the threshold, then join in the +# order + customer details the problem statement requires and sort. 
+df = ( + df_lineitem.aggregate( + ["l_orderkey"], [F.sum(col("l_quantity")).alias("total_quantity")] + ) + .filter(col("total_quantity") > QUANTITY) + .join(df_orders, left_on="l_orderkey", right_on="o_orderkey") + .join(df_customer, left_on="o_custkey", right_on="c_custkey") + .select( + "c_name", + "c_custkey", + "o_orderkey", + "o_orderdate", + "o_totalprice", + "total_quantity", + ) + .sort(col("o_totalprice").sort(ascending=False), "o_orderdate") ) -# Limit to orders in which the total quantity is above a threshold -df = df.filter(col("total_quantity") > lit(QUANTITY)) - -# We've identified the orders of interest, now join the additional data -# we are required to report on -df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") -df = df.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") - -df = df.select( - "c_name", "c_custkey", "o_orderkey", "o_orderdate", "o_totalprice", "total_quantity" -) - -df = df.sort(col("o_totalprice").sort(ascending=False), col("o_orderdate").sort()) - df.show() diff --git a/examples/tpch/q19_discounted_revenue.py b/examples/tpch/q19_discounted_revenue.py index 5f87e6aa2..a2be1c1b7 100644 --- a/examples/tpch/q19_discounted_revenue.py +++ b/examples/tpch/q19_discounted_revenue.py @@ -102,15 +102,13 @@ "l_discount", ) -# These limitations apply to all line items, so go ahead and do them first - -df = df_lineitem.filter(col("l_shipinstruct") == lit("DELIVER IN PERSON")) - -df = df.filter( - (col("l_shipmode") == lit("AIR")) | (col("l_shipmode") == lit("AIR REG")) -) - -df = df.join(df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="inner") +# Filter conditions that apply to every disjunct of the reference SQL's WHERE +# clause — pull them out up front so the per-brand predicate stays focused on +# the brand-specific parts. 
+df = df_lineitem.filter( + col("l_shipinstruct") == "DELIVER IN PERSON", + F.in_list(col("l_shipmode"), [lit("AIR"), lit("AIR REG")]), +).join(df_part, left_on="l_partkey", right_on="p_partkey") # Build one OR-combined predicate per brand. Each disjunct encodes the @@ -121,12 +119,10 @@ def _brand_predicate( brand: str, min_quantity: int, containers: list[str], max_size: int ): return ( - (col("p_brand") == lit(brand)) + (col("p_brand") == brand) & F.in_list(col("p_container"), [lit(c) for c in containers]) - & (col("l_quantity") >= lit(min_quantity)) - & (col("l_quantity") <= lit(min_quantity + 10)) - & (col("p_size") >= lit(1)) - & (col("p_size") <= lit(max_size)) + & col("l_quantity").between(lit(min_quantity), lit(min_quantity + 10)) + & col("p_size").between(lit(1), lit(max_size)) ) @@ -140,9 +136,7 @@ def _brand_predicate( ) predicate = part_predicate if predicate is None else predicate | part_predicate -df = df.filter(predicate) - -df = df.aggregate( +df = df.filter(predicate).aggregate( [], [F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias("revenue")], ) diff --git a/examples/tpch/q20_potential_part_promotion.py b/examples/tpch/q20_potential_part_promotion.py index 47a60fe79..78b65523d 100644 --- a/examples/tpch/q20_potential_part_promotion.py +++ b/examples/tpch/q20_potential_part_promotion.py @@ -103,17 +103,19 @@ # Filter down dataframes. ``starts_with`` reads more naturally than an # explicit substring slice and maps directly to the reference SQL's # ``p_name like 'forest%'`` clause. -df_nation = df_nation.filter(col("n_name") == lit(NATION_OF_INTEREST)) +df_nation = df_nation.filter(col("n_name") == NATION_OF_INTEREST) df_part = df_part.filter(F.starts_with(col("p_name"), lit(COLOR_OF_INTEREST))) # Compute the total quantity of interesting parts shipped by each (part, # supplier) pair within the year of interest. 
totals = ( - df_lineitem.filter(col("l_shipdate") >= lit(date)) - .filter(col("l_shipdate") < lit(date) + lit(interval)) - .join(df_part, left_on="l_partkey", right_on="p_partkey", how="inner") + df_lineitem.filter( + col("l_shipdate") >= lit(date), + col("l_shipdate") < lit(date) + lit(interval), + ) + .join(df_part, left_on="l_partkey", right_on="p_partkey") .aggregate( - [col("l_partkey"), col("l_suppkey")], + ["l_partkey", "l_suppkey"], [F.sum(col("l_quantity")).alias("total_sold")], ) ) @@ -127,7 +129,6 @@ totals, left_on=["ps_partkey", "ps_suppkey"], right_on=["l_partkey", "l_suppkey"], - how="inner", ) .filter(col("ps_availqty") > lit(0.5) * col("total_sold")) .select(col("ps_suppkey").alias("suppkey")) @@ -136,10 +137,11 @@ # Limit to suppliers in the nation of interest and pick out the two # requested columns. -df = df_supplier.join( - df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner" -).join(excess_suppliers, left_on="s_suppkey", right_on="suppkey", how="semi") - -df = df.select("s_name", "s_address").sort(col("s_name").sort()) +df = ( + df_supplier.join(df_nation, left_on="s_nationkey", right_on="n_nationkey") + .join(excess_suppliers, left_on="s_suppkey", right_on="suppkey", how="semi") + .select("s_name", "s_address") + .sort_by("s_name") +) df.show() diff --git a/examples/tpch/q21_suppliers_kept_orders_waiting.py b/examples/tpch/q21_suppliers_kept_orders_waiting.py index 7ec5c3069..d98f76ce7 100644 --- a/examples/tpch/q21_suppliers_kept_orders_waiting.py +++ b/examples/tpch/q21_suppliers_kept_orders_waiting.py @@ -68,7 +68,7 @@ s_name limit 100; """ -from datafusion import SessionContext, col, lit +from datafusion import SessionContext, col from datafusion import functions as F from util import get_data_path @@ -92,17 +92,16 @@ ) # Limit to suppliers in the nation of interest -df_suppliers_of_interest = df_nation.filter( - col("n_name") == lit(NATION_OF_INTEREST) -).join(df_supplier, left_on="n_nationkey", 
right_on="s_nationkey", how="inner") +df_suppliers_of_interest = df_nation.filter(col("n_name") == NATION_OF_INTEREST).join( + df_supplier, left_on="n_nationkey", right_on="s_nationkey" +) # Line items for orders that have status 'F'. This is the candidate set of # (order, supplier) pairs we reason about below. failed_order_lineitems = df_lineitem.join( - df_orders.filter(col("o_orderstatus") == lit("F")), + df_orders.filter(col("o_orderstatus") == "F"), left_on="l_orderkey", right_on="o_orderkey", - how="inner", ) # Line items whose receipt was late. This corresponds to ``l1`` in the @@ -117,8 +116,8 @@ multi_supplier_orders = ( failed_order_lineitems.select("l_orderkey", "l_suppkey") .distinct() - .aggregate([col("l_orderkey")], [F.count(col("l_suppkey")).alias("n_suppliers")]) - .filter(col("n_suppliers") > lit(1)) + .aggregate(["l_orderkey"], [F.count_star().alias("n_suppliers")]) + .filter(col("n_suppliers") > 1) .select("l_orderkey") ) @@ -129,30 +128,20 @@ single_late_supplier_orders = ( late_lineitems.select("l_orderkey", "l_suppkey") .distinct() - .aggregate( - [col("l_orderkey")], [F.count(col("l_suppkey")).alias("n_late_suppliers")] - ) - .filter(col("n_late_suppliers") == lit(1)) + .aggregate(["l_orderkey"], [F.count_star().alias("n_late_suppliers")]) + .filter(col("n_late_suppliers") == 1) .select("l_orderkey") ) -# Keep late line items whose order qualifies on both counts. Semi joins -# preserve the left-side columns without fanning out on the right. -df = late_lineitems.join(multi_supplier_orders, on="l_orderkey", how="semi").join( - single_late_supplier_orders, on="l_orderkey", how="semi" -) - -# Attach the supplier name for suppliers in the nation of interest, count -# one row per qualifying order, and return the top 100. +# Keep late line items whose order qualifies on both counts, attach the +# supplier name for suppliers in the nation of interest, count one row per +# qualifying order, and return the top 100. 
df = ( - df.join( - df_suppliers_of_interest, - left_on="l_suppkey", - right_on="s_suppkey", - how="inner", - ) - .aggregate([col("s_name")], [F.count(col("l_orderkey")).alias("numwait")]) - .sort(col("numwait").sort(ascending=False), col("s_name").sort()) + late_lineitems.join(multi_supplier_orders, on="l_orderkey", how="semi") + .join(single_late_supplier_orders, on="l_orderkey", how="semi") + .join(df_suppliers_of_interest, left_on="l_suppkey", right_on="s_suppkey") + .aggregate(["s_name"], [F.count_star().alias("numwait")]) + .sort(col("numwait").sort(ascending=False), "s_name") .limit(100) ) diff --git a/examples/tpch/q22_global_sales_opportunity.py b/examples/tpch/q22_global_sales_opportunity.py index 5f463ab45..5043eeb51 100644 --- a/examples/tpch/q22_global_sales_opportunity.py +++ b/examples/tpch/q22_global_sales_opportunity.py @@ -82,40 +82,36 @@ ) df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select("o_custkey") -# The nation code is a two digit number, but we need to convert it to a string literal -nation_codes = F.make_array(*[lit(str(n)) for n in NATION_CODES]) - -# Use the substring operation to extract the first two characters of the phone number -df = df_customer.with_column("cntrycode", F.substring(col("c_phone"), lit(0), lit(3))) - -# Limit our search to customers with some balance and in the country code above -df = df.filter(col("c_acctbal") > lit(0.0)) -df = df.filter(~F.array_position(nation_codes, col("cntrycode")).is_null()) - -# Compute the average balance. By default, the window frame is from unbounded preceding to the -# current row. We want our frame to cover the entire data frame. -window_frame = WindowFrame("rows", None, None) -df = df.with_column( - "avg_balance", - F.avg(col("c_acctbal")).over(Window(window_frame=window_frame)), +# Country code is the two-digit prefix of the phone number. 
+nation_codes = [lit(str(n)) for n in NATION_CODES] + +# Start from customers with a positive balance in one of the target country +# codes, then attach the grand-mean balance via a whole-frame window so we +# can filter per row — DataFrame stand-in for the SQL's scalar ``(select +# avg(c_acctbal) ... )`` subquery. +whole_frame = WindowFrame("rows", None, None) + +df = ( + df_customer.with_column("cntrycode", F.left(col("c_phone"), lit(2))) + .filter( + col("c_acctbal") > 0.0, + F.in_list(col("cntrycode"), nation_codes), + ) + .with_column( + "avg_balance", + F.avg(col("c_acctbal")).over(Window(window_frame=whole_frame)), + ) + .filter(col("c_acctbal") > col("avg_balance")) + # Keep only customers with no orders (anti join = NOT EXISTS). + .join(df_orders, left_on="c_custkey", right_on="o_custkey", how="anti") + .aggregate( + ["cntrycode"], + [ + F.count_star().alias("numcust"), + F.sum(col("c_acctbal")).alias("totacctbal"), + ], + ) + .sort_by("cntrycode") ) df.show() -# Limit results to customers with above average balance -df = df.filter(col("c_acctbal") > col("avg_balance")) - -# Limit results to customers with no orders -df = df.join(df_orders, left_on="c_custkey", right_on="o_custkey", how="anti") - -# Count up the customers and the balances -df = df.aggregate( - [col("cntrycode")], - [ - F.count(col("c_custkey")).alias("numcust"), - F.sum(col("c_acctbal")).alias("totacctbal"), - ], -) - -df = df.sort(col("cntrycode").sort()) - -df.show() From 0d16d177fa87837f3041bfe4a9d7c0e27bf699da Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 24 Apr 2026 10:21:01 -0400 Subject: [PATCH 5/5] tpch examples: more idiomatic aggregate FILTER, string funcs, date handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Additional sweep of the TPC-H DataFrame examples informed by comparing against a fresh set of SKILL.md-only generations under ``examples/tpch/agentic_queries/``: - Q02: ``F.ends_with(col("p_type"), 
lit(TYPE_OF_INTEREST))`` in place of ``F.strpos(col, lit) > 0``. The reference SQL is ``p_type like '%BRASS'``, which is an ends_with check, not contains. ``F.strpos > 0`` returned the correct rows on TPC-H data by coincidence but is semantically wrong. - Q09: ``F.contains(col("p_name"), lit(part_color))`` in place of ``F.strpos(col, lit) > 0``. The SQL is ``p_name like '%green%'``. - Q08, Q12, Q14: use the ``filter`` keyword on ``F.sum`` / ``F.count`` — the DataFrame form of SQL ``sum(...) FILTER (WHERE ...)`` — instead of wrapping the aggregate input in ``F.when(cond, x).otherwise(0)``. Q08 also reorganises to inner-join the supplier's nation onto the regional sales, which removes the previous left-join + ``F.when(is_not_null, ...)`` dance. - Q15: compute the grand maximum revenue as a separate scalar aggregate and ``join_on(...)`` on equality, instead of the whole-frame window ``F.max`` + filter shape. Simpler plan, same result. - Q16: ``F.regexp_like(col, pattern)`` in place of ``F.regexp_match(col, pattern).is_not_null()``. - Q04, Q05, Q06, Q07, Q08, Q10, Q12, Q14, Q15, Q20: store both the start and the end of the date window as plain ``datetime.date`` objects and compare with ``lit(end_date)``, instead of carrying the start date + ``pa.month_day_nano_interval`` and adding them at query-build time. Drops unused ``pyarrow`` imports from the files that no longer need Arrow scalars. All 22 answer-file comparisons still pass at scale factor 1. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/tpch/q02_minimum_cost_supplier.py | 8 +- examples/tpch/q04_order_priority_checking.py | 17 ++-- examples/tpch/q05_local_supplier_volume.py | 15 ++-- .../tpch/q06_forecasting_revenue_change.py | 16 ++-- examples/tpch/q07_volume_shipping.py | 11 +-- examples/tpch/q08_market_share.py | 87 ++++++++----------- .../tpch/q09_product_type_profit_measure.py | 7 +- examples/tpch/q10_returned_item_reporting.py | 14 ++- examples/tpch/q12_ship_mode_order_priority.py | 30 +++---- examples/tpch/q14_promotion_effect.py | 50 +++++------ examples/tpch/q15_top_supplier.py | 47 ++++------ .../tpch/q16_part_supplier_relationship.py | 2 +- examples/tpch/q20_potential_part_promotion.py | 14 ++- 13 files changed, 128 insertions(+), 190 deletions(-) diff --git a/examples/tpch/q02_minimum_cost_supplier.py b/examples/tpch/q02_minimum_cost_supplier.py index 6c26d262e..c5c6b9c0b 100644 --- a/examples/tpch/q02_minimum_cost_supplier.py +++ b/examples/tpch/q02_minimum_cost_supplier.py @@ -113,12 +113,12 @@ "r_regionkey", "r_name" ) -# Filter down parts. Part names contain the type of interest, so we can use strpos to find where -# in the p_type column the word is. `strpos` will return 0 if not found, otherwise the position -# in the string where it is located. +# Filter down parts. The reference SQL uses ``p_type like '%BRASS'`` which +# is an ``ends_with`` check; use the dedicated string function rather than +# a manual substring match. 
df_part = df_part.filter( - F.strpos(col("p_type"), lit(TYPE_OF_INTEREST)) > 0, + F.ends_with(col("p_type"), lit(TYPE_OF_INTEREST)), col("p_size") == SIZE_OF_INTEREST, ) diff --git a/examples/tpch/q04_order_priority_checking.py b/examples/tpch/q04_order_priority_checking.py index 18cbb2054..6f11c1383 100644 --- a/examples/tpch/q04_order_priority_checking.py +++ b/examples/tpch/q04_order_priority_checking.py @@ -50,16 +50,14 @@ o_orderpriority; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -# Ideally we could put 3 months into the interval. See note below. -INTERVAL_DAYS = 92 -DATE_OF_INTEREST = "1993-07-01" +QUARTER_START = date(1993, 7, 1) +QUARTER_END = date(1993, 10, 1) # Load the dataframes we need @@ -72,17 +70,12 @@ "l_orderkey", "l_commitdate", "l_receiptdate" ) -# Create a date object from the string -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval()) - # Keep only orders in the quarter of interest, then restrict to those that # have at least one late lineitem via a semi join (the DataFrame form of # ``EXISTS`` from the reference SQL). 
df_orders = df_orders.filter( - col("o_orderdate") >= lit(date), - col("o_orderdate") < lit(date) + lit(interval), + col("o_orderdate") >= lit(QUARTER_START), + col("o_orderdate") < lit(QUARTER_END), ) late_lineitems = df_lineitem.filter(col("l_commitdate") < col("l_receiptdate")) diff --git a/examples/tpch/q05_local_supplier_volume.py b/examples/tpch/q05_local_supplier_volume.py index 5e648f272..bfdba5d4c 100644 --- a/examples/tpch/q05_local_supplier_volume.py +++ b/examples/tpch/q05_local_supplier_volume.py @@ -56,21 +56,16 @@ revenue desc; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -DATE_OF_INTEREST = "1994-01-01" -INTERVAL_DAYS = 365 +YEAR_START = date(1994, 1, 1) +YEAR_END = date(1995, 1, 1) REGION_OF_INTEREST = "ASIA" -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval()) - # Load the dataframes we need ctx = SessionContext() @@ -96,8 +91,8 @@ # Restrict dataframes to cases of interest df_orders = df_orders.filter( - col("o_orderdate") >= lit(date), - col("o_orderdate") < lit(date) + lit(interval), + col("o_orderdate") >= lit(YEAR_START), + col("o_orderdate") < lit(YEAR_END), ) df_region = df_region.filter(col("r_name") == REGION_OF_INTEREST) diff --git a/examples/tpch/q06_forecasting_revenue_change.py b/examples/tpch/q06_forecasting_revenue_change.py index 79697f833..ed54d22a4 100644 --- a/examples/tpch/q06_forecasting_revenue_change.py +++ b/examples/tpch/q06_forecasting_revenue_change.py @@ -41,26 +41,20 @@ and l_quantity < 24; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path # Variables from the example query -DATE_OF_INTEREST = "1994-01-01" +YEAR_START = 
date(1994, 1, 1) +YEAR_END = date(1995, 1, 1) DISCOUT = 0.06 DELTA = 0.01 QUANTITY = 24 -INTERVAL_DAYS = 365 - -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval()) - # Load the dataframes we need ctx = SessionContext() @@ -72,8 +66,8 @@ # Filter down to lineitems of interest df = df_lineitem.filter( - col("l_shipdate") >= lit(date), - col("l_shipdate") < lit(date) + lit(interval), + col("l_shipdate") >= lit(YEAR_START), + col("l_shipdate") < lit(YEAR_END), col("l_discount").between(lit(DISCOUT - DELTA), lit(DISCOUT + DELTA)), col("l_quantity") < QUANTITY, ) diff --git a/examples/tpch/q07_volume_shipping.py b/examples/tpch/q07_volume_shipping.py index 6584509cc..df1c2ae0d 100644 --- a/examples/tpch/q07_volume_shipping.py +++ b/examples/tpch/q07_volume_shipping.py @@ -70,7 +70,7 @@ l_year; """ -from datetime import datetime +from datetime import date import pyarrow as pa from datafusion import SessionContext, col, lit @@ -82,11 +82,8 @@ nation_1 = lit("FRANCE") nation_2 = lit("GERMANY") -START_DATE = "1995-01-01" -END_DATE = "1996-12-31" - -start_date = lit(datetime.strptime(START_DATE, "%Y-%m-%d").date()) -end_date = lit(datetime.strptime(END_DATE, "%Y-%m-%d").date()) +START_DATE = date(1995, 1, 1) +END_DATE = date(1996, 12, 31) # Load the dataframes we need @@ -112,7 +109,7 @@ # Filter to time of interest df_lineitem = df_lineitem.filter( - col("l_shipdate") >= start_date, col("l_shipdate") <= end_date + col("l_shipdate") >= lit(START_DATE), col("l_shipdate") <= lit(END_DATE) ) diff --git a/examples/tpch/q08_market_share.py b/examples/tpch/q08_market_share.py index fbf0fc8e5..dd7bacedb 100644 --- a/examples/tpch/q08_market_share.py +++ b/examples/tpch/q08_market_share.py @@ -67,22 +67,19 @@ o_year; """ -from datetime import datetime +from datetime import date import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util 
import get_data_path -supplier_nation = lit("BRAZIL") -customer_region = lit("AMERICA") -part_of_interest = lit("ECONOMY ANODIZED STEEL") +supplier_nation = "BRAZIL" +customer_region = "AMERICA" +part_of_interest = "ECONOMY ANODIZED STEEL" -START_DATE = "1995-01-01" -END_DATE = "1996-12-31" - -start_date = lit(datetime.strptime(START_DATE, "%Y-%m-%d").date()) -end_date = lit(datetime.strptime(END_DATE, "%Y-%m-%d").date()) +START_DATE = date(1995, 1, 1) +END_DATE = date(1996, 12, 31) # Load the dataframes we need @@ -115,67 +112,55 @@ # Limit orders to those in the specified range df_orders = df_orders.filter( - col("o_orderdate") >= start_date, col("o_orderdate") <= end_date + col("o_orderdate") >= lit(START_DATE), col("o_orderdate") <= lit(END_DATE) ) -# Part 1: Find customers in the region +# Pair each supplier with its nation name so every regional-customer row +# below carries the supplier's nation and can be filtered inside the +# aggregate with ``F.sum(..., filter=...)``. -# We want customers in region specified by region_of_interest. This will be used to compute -# the total sales of the part of interest. We want to know of those sales what fraction -# was supplied by the nation of interest. There is no guarantee that the nation of -# interest is within the region of interest. +df_supplier_with_nation = df_supplier.join( + df_nation, left_on="s_nationkey", right_on="n_nationkey" +).select("s_suppkey", col("n_name").alias("supp_nation")) -# First we find all the sales that make up the basis. +# Build every (part, lineitem, order, customer) row for customers in the +# target region ordering the target part. Each row carries the supplier's +# nation so we can aggregate on it below. 
-df_regional_customers = ( +df = ( df_region.filter(col("r_name") == customer_region) .join(df_nation, left_on="r_regionkey", right_on="n_regionkey") .join(df_customer, left_on="n_nationkey", right_on="c_nationkey") .join(df_orders, left_on="c_custkey", right_on="o_custkey") .join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey") .join(df_part, left_on="l_partkey", right_on="p_partkey") - .with_column("volume", col("l_extendedprice") * (lit(1.0) - col("l_discount"))) -) - -# Part 2: Find suppliers from the nation - -# Now that we have all of the sales of that part in the specified region, we need -# to determine which of those came from suppliers in the nation we are interested in. - -df_national_suppliers = ( - df_nation.filter(col("n_name") == supplier_nation) - .join(df_supplier, left_on="n_nationkey", right_on="s_nationkey") - .select("s_suppkey") -) - - -# Part 3: Combine suppliers and customers and compute the market share - -# Left-outer join the national suppliers onto the regional sales. Rows from -# other suppliers get a NULL ``s_suppkey``, which the CASE expression uses -# to zero out the non-national volume. - -df = df_regional_customers.join( - df_national_suppliers, left_on="l_suppkey", right_on="s_suppkey", how="left" -).with_columns( - national_volume=F.when(col("s_suppkey").is_not_null(), col("volume")).otherwise( - lit(0.0) - ), - o_year=F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()), + .join(df_supplier_with_nation, left_on="l_suppkey", right_on="s_suppkey") + .with_columns( + volume=col("l_extendedprice") * (lit(1.0) - col("l_discount")), + o_year=F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()), + ) ) - -# Aggregate, compute the share, and sort. - +# Aggregate the total and national volumes per year via the ``filter`` +# kwarg on ``F.sum`` (DataFrame form of SQL ``sum(...) FILTER (WHERE ...)``). +# ``coalesce`` handles the case where no sale came from the target nation +# for a given year.
df = ( df.aggregate( ["o_year"], [ - F.sum(col("volume")).alias("volume"), - F.sum(col("national_volume")).alias("national_volume"), + F.sum(col("volume"), filter=col("supp_nation") == supplier_nation).alias( + "national_volume" + ), + F.sum(col("volume")).alias("total_volume"), ], ) - .select("o_year", (col("national_volume") / col("volume")).alias("mkt_share")) + .select( + "o_year", + (F.coalesce(col("national_volume"), lit(0.0)) / col("total_volume")).alias( + "mkt_share" + ), + ) .sort_by("o_year") ) diff --git a/examples/tpch/q09_product_type_profit_measure.py b/examples/tpch/q09_product_type_profit_measure.py index 139c0483f..ec68a2ab7 100644 --- a/examples/tpch/q09_product_type_profit_measure.py +++ b/examples/tpch/q09_product_type_profit_measure.py @@ -69,7 +69,7 @@ from datafusion import functions as F from util import get_data_path -part_color = lit("green") +part_color = "green" # Load the dataframes we need @@ -98,9 +98,10 @@ ) # Limit possible parts to the color specified, then walk the joins down to the -# line-item rows we need and attach the supplier's nation. +# line-item rows we need and attach the supplier's nation. ``F.contains`` +# maps directly to the reference SQL's ``p_name like '%green%'``. 
df = ( - df_part.filter(F.strpos(col("p_name"), part_color) > 0) + df_part.filter(F.contains(col("p_name"), lit(part_color))) .join(df_lineitem, left_on="p_partkey", right_on="l_partkey") .join(df_supplier, left_on="l_suppkey", right_on="s_suppkey") .join(df_orders, left_on="l_orderkey", right_on="o_orderkey") diff --git a/examples/tpch/q10_returned_item_reporting.py b/examples/tpch/q10_returned_item_reporting.py index 3e4e8a4d9..e6532517e 100644 --- a/examples/tpch/q10_returned_item_reporting.py +++ b/examples/tpch/q10_returned_item_reporting.py @@ -63,18 +63,14 @@ revenue desc limit 20; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -DATE_START_OF_QUARTER = "1993-10-01" - -date_start_of_quarter = lit(datetime.strptime(DATE_START_OF_QUARTER, "%Y-%m-%d").date()) - -interval_one_quarter = lit(pa.scalar((0, 92, 0), type=pa.month_day_nano_interval())) +QUARTER_START = date(1993, 10, 1) +QUARTER_END = date(1994, 1, 1) # Load the dataframes we need @@ -108,8 +104,8 @@ df = ( df_orders.filter( - col("o_orderdate") >= date_start_of_quarter, - col("o_orderdate") < date_start_of_quarter + interval_one_quarter, + col("o_orderdate") >= lit(QUARTER_START), + col("o_orderdate") < lit(QUARTER_END), ) .join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey") .aggregate( diff --git a/examples/tpch/q12_ship_mode_order_priority.py b/examples/tpch/q12_ship_mode_order_priority.py index c684e1ba5..fb78fe3c2 100644 --- a/examples/tpch/q12_ship_mode_order_priority.py +++ b/examples/tpch/q12_ship_mode_order_priority.py @@ -60,16 +60,16 @@ l_shipmode; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path SHIP_MODE_1 = "MAIL" SHIP_MODE_2 = "SHIP" -DATE_OF_INTEREST = "1994-01-01" 
+YEAR_START = date(1994, 1, 1) +YEAR_END = date(1995, 1, 1) # Load the dataframes we need @@ -82,14 +82,10 @@ "l_orderkey", "l_shipmode", "l_commitdate", "l_shipdate", "l_receiptdate" ) -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, 365, 0), type=pa.month_day_nano_interval()) - df = df_lineitem.filter( - col("l_receiptdate") >= lit(date), - col("l_receiptdate") < lit(date) + lit(interval), + col("l_receiptdate") >= lit(YEAR_START), + col("l_receiptdate") < lit(YEAR_END), # ``in_list`` maps directly to ``l_shipmode in (...)`` from the SQL. F.in_list(col("l_shipmode"), [lit(SHIP_MODE_1), lit(SHIP_MODE_2)]), col("l_shipdate") < col("l_commitdate"), @@ -97,18 +93,18 @@ ).join(df_orders, left_on="l_orderkey", right_on="o_orderkey") # Flag each line item as belonging to a high-priority order or not. -is_high_priority = F.in_list(col("o_orderpriority"), [lit("1-URGENT"), lit("2-HIGH")]) +high_priorities = [lit("1-URGENT"), lit("2-HIGH")] +is_high = F.in_list(col("o_orderpriority"), high_priorities) +is_low = F.in_list(col("o_orderpriority"), high_priorities, negated=True) -# Count the high-priority and low-priority lineitems per ship mode. +# Count the high-priority and low-priority lineitems per ship mode via the +# ``filter`` kwarg on ``F.count`` (DataFrame form of SQL's ``count(*) +# FILTER (WHERE ...)``). 
df = df.aggregate( ["l_shipmode"], [ - F.sum(F.when(is_high_priority, lit(1)).otherwise(lit(0))).alias( - "high_line_count" - ), - F.sum(F.when(~is_high_priority, lit(1)).otherwise(lit(0))).alias( - "low_line_count" - ), + F.count(col("o_orderkey"), filter=is_high).alias("high_line_count"), + F.count(col("o_orderkey"), filter=is_low).alias("low_line_count"), ], ).sort_by("l_shipmode") diff --git a/examples/tpch/q14_promotion_effect.py b/examples/tpch/q14_promotion_effect.py index 2c020d368..08f4f054d 100644 --- a/examples/tpch/q14_promotion_effect.py +++ b/examples/tpch/q14_promotion_effect.py @@ -42,18 +42,14 @@ and l_shipdate < date '1995-09-01' + interval '1' month; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -DATE = "1995-09-01" - -date_of_interest = lit(datetime.strptime(DATE, "%Y-%m-%d").date()) - -interval_one_month = lit(pa.scalar((0, 30, 0), type=pa.month_day_nano_interval())) +MONTH_START = date(1995, 9, 1) +MONTH_END = date(1995, 10, 1) # Load the dataframes we need @@ -66,31 +62,29 @@ # Restrict the line items to the month of interest, join the matching part -# rows, and compute revenue per line item. +# rows, and aggregate revenue totals with a ``filter`` clause on the promo +# sum — the DataFrame form of SQL ``sum(...) FILTER (WHERE ...)``. +revenue = col("l_extendedprice") * (lit(1.0) - col("l_discount")) +is_promo = F.starts_with(col("p_type"), lit("PROMO")) + df = ( df_lineitem.filter( - col("l_shipdate") >= date_of_interest, - col("l_shipdate") < date_of_interest + interval_one_month, + col("l_shipdate") >= lit(MONTH_START), + col("l_shipdate") < lit(MONTH_END), ) .join(df_part, left_on="l_partkey", right_on="p_partkey") - .with_column("revenue", col("l_extendedprice") * (lit(1.0) - col("l_discount"))) -) - -# Sum promotional and total revenue, then compute the percentage.
The -# ``F.when(...)`` form mirrors the ``case when p_type like 'PROMO%' ... else 0`` -# in the reference SQL. -df = df.aggregate( - [], - [ - F.sum( - F.when( - F.starts_with(col("p_type"), lit("PROMO")), col("revenue") - ).otherwise(lit(0.0)) - ).alias("promo_revenue"), - F.sum(col("revenue")).alias("total_revenue"), - ], -).select( - (lit(100.0) * col("promo_revenue") / col("total_revenue")).alias("promo_revenue") + .aggregate( + [], + [ + F.sum(revenue, filter=is_promo).alias("promo_revenue"), + F.sum(revenue).alias("total_revenue"), + ], + ) + .select( + (lit(100.0) * col("promo_revenue") / col("total_revenue")).alias( + "promo_revenue" + ) + ) ) df.show() diff --git a/examples/tpch/q15_top_supplier.py b/examples/tpch/q15_top_supplier.py index e32ea65d1..01c38b9f8 100644 --- a/examples/tpch/q15_top_supplier.py +++ b/examples/tpch/q15_top_supplier.py @@ -60,19 +60,14 @@ drop view revenue0; """ -from datetime import datetime +from datetime import date -import pyarrow as pa -from datafusion import SessionContext, WindowFrame, col, lit +from datafusion import SessionContext, col, lit from datafusion import functions as F -from datafusion.expr import Window from util import get_data_path -DATE = "1996-01-01" - -date_of_interest = lit(datetime.strptime(DATE, "%Y-%m-%d").date()) - -interval_3_months = lit(pa.scalar((0, 91, 0), type=pa.month_day_nano_interval())) +QUARTER_START = date(1996, 1, 1) +QUARTER_END = date(1996, 4, 1) # Load the dataframes we need @@ -89,30 +84,26 @@ ) # Per-supplier revenue over the quarter of interest. 
+revenue = col("l_extendedprice") * (lit(1) - col("l_discount")) + per_supplier_revenue = df_lineitem.filter( - col("l_shipdate") >= date_of_interest, - col("l_shipdate") < date_of_interest + interval_3_months, -).aggregate( - ["l_suppkey"], - [ - F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias( - "total_revenue" - ) - ], -) + col("l_shipdate") >= lit(QUARTER_START), + col("l_shipdate") < lit(QUARTER_END), +).aggregate(["l_suppkey"], [F.sum(revenue).alias("total_revenue")]) -# A window ``max`` over the whole frame acts as a grand maximum that can be -# compared row-by-row — the DataFrame stand-in for the reference SQL's +# Compute the grand maximum revenue separately and join on equality — the +# DataFrame stand-in for the reference SQL's # ``total_revenue = (select max(total_revenue) from revenue0)`` subquery. -whole_frame = WindowFrame("rows", None, None) +max_revenue = per_supplier_revenue.aggregate( + [], [F.max(col("total_revenue")).alias("max_rev")] +) + +top_suppliers = per_supplier_revenue.join_on( + max_revenue, col("total_revenue") == col("max_rev") +).select("l_suppkey", "total_revenue") df = ( - per_supplier_revenue.with_column( - "max_revenue", - F.max(col("total_revenue")).over(Window(window_frame=whole_frame)), - ) - .filter(col("total_revenue") == col("max_revenue")) - .join(df_supplier, left_on="l_suppkey", right_on="s_suppkey") + df_supplier.join(top_suppliers, left_on="s_suppkey", right_on="l_suppkey") .select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue") .sort_by("s_suppkey") ) diff --git a/examples/tpch/q16_part_supplier_relationship.py b/examples/tpch/q16_part_supplier_relationship.py index 755ee91bb..ddeadff5f 100644 --- a/examples/tpch/q16_part_supplier_relationship.py +++ b/examples/tpch/q16_part_supplier_relationship.py @@ -84,7 +84,7 @@ ) df_unwanted_suppliers = df_supplier.filter( - F.regexp_match(col("s_comment"), lit("Customer.?*Complaints")).is_not_null() + F.regexp_like(col("s_comment"), 
lit("Customer.*Complaints")) ) # Remove unwanted suppliers via an anti join (DataFrame form of NOT IN). diff --git a/examples/tpch/q20_potential_part_promotion.py b/examples/tpch/q20_potential_part_promotion.py index 78b65523d..18f96da97 100644 --- a/examples/tpch/q20_potential_part_promotion.py +++ b/examples/tpch/q20_potential_part_promotion.py @@ -67,15 +67,15 @@ s_name; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path COLOR_OF_INTEREST = "forest" -DATE_OF_INTEREST = "1994-01-01" +YEAR_START = date(1994, 1, 1) +YEAR_END = date(1995, 1, 1) NATION_OF_INTEREST = "CANADA" # Load the dataframes we need @@ -96,10 +96,6 @@ "n_nationkey", "n_name" ) -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, 365, 0), type=pa.month_day_nano_interval()) - # Filter down dataframes. ``starts_with`` reads more naturally than an # explicit substring slice and maps directly to the reference SQL's # ``p_name like 'forest%'`` clause. @@ -110,8 +106,8 @@ # supplier) pair within the year of interest. totals = ( df_lineitem.filter( - col("l_shipdate") >= lit(date), - col("l_shipdate") < lit(date) + lit(interval), + col("l_shipdate") >= lit(YEAR_START), + col("l_shipdate") < lit(YEAR_END), ) .join(df_part, left_on="l_partkey", right_on="p_partkey") .aggregate(