diff --git a/src/Backend/opti-sql-go/Expr/expr.go b/src/Backend/opti-sql-go/Expr/expr.go index b3eed34..f9d88de 100644 --- a/src/Backend/opti-sql-go/Expr/expr.go +++ b/src/Backend/opti-sql-go/Expr/expr.go @@ -495,7 +495,6 @@ func EvalBinary(b *BinaryExpr, batch *operators.RecordBatch) (arrow.Array, error return unpackDatum(datum) case Like: if leftArr.DataType() != arrow.BinaryTypes.String || rightArr.DataType() != arrow.BinaryTypes.String { - // regEx runs only on strings return nil, errors.New("binary operator Like only works on arrays of strings") } var compiledRegEx = compileSqlRegEx(rightArr.ValueStr(0)) @@ -503,7 +502,6 @@ func EvalBinary(b *BinaryExpr, batch *operators.RecordBatch) (arrow.Array, error leftStrArray := leftArr.(*array.String) for i := 0; i < leftStrArray.Len(); i++ { valid := validRegEx(leftStrArray.Value(i), compiledRegEx) - fmt.Printf("does %s match %s: %v\n", leftStrArray.Value(i), compiledRegEx, valid) filterBuilder.Append(valid) } return filterBuilder.NewArray(), nil @@ -602,7 +600,6 @@ func EvalCast(c *CastExpr, batch *operators.RecordBatch) (arrow.Array, error) { castOpts := compute.SafeCastOptions(c.TargetType) out, err := compute.CastArray(context.TODO(), arr, castOpts) if err != nil { - // This is a runtime cast error return nil, fmt.Errorf("cast error: cannot cast %s to %s: %w", arr.DataType(), c.TargetType, err) } diff --git a/src/Backend/opti-sql-go/go.mod b/src/Backend/opti-sql-go/go.mod index c9ee239..5b872b6 100644 --- a/src/Backend/opti-sql-go/go.mod +++ b/src/Backend/opti-sql-go/go.mod @@ -1,6 +1,6 @@ module opti-sql-go -go 1.23 +go 1.24.0 require ( github.com/apache/arrow/go/v15 v15.0.2 @@ -28,6 +28,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.13 // indirect github.com/aws/smithy-go v1.23.2 // indirect github.com/go-ini/ini v1.67.0 // indirect + github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/goccy/go-json v0.10.3 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/flatbuffers v24.3.25+incompatible // indirect diff --git a/src/Backend/opti-sql-go/go.sum b/src/Backend/opti-sql-go/go.sum index 9c4220d..7c4ee5c 100644 --- a/src/Backend/opti-sql-go/go.sum +++ b/src/Backend/opti-sql-go/go.sum @@ -37,6 +37,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= diff --git a/src/Backend/opti-sql-go/operators/aggr/avgExec.go b/src/Backend/opti-sql-go/operators/aggr/avgExec.go deleted file mode 100644 index abd1ad5..0000000 --- a/src/Backend/opti-sql-go/operators/aggr/avgExec.go +++ /dev/null @@ -1 +0,0 @@ -package aggr diff --git a/src/Backend/opti-sql-go/operators/aggr/avgExec_test.go b/src/Backend/opti-sql-go/operators/aggr/avgExec_test.go deleted file mode 100644 index 67671d0..0000000 --- a/src/Backend/opti-sql-go/operators/aggr/avgExec_test.go +++ /dev/null @@ -1,7 +0,0 @@ -package aggr - -import "testing" - -func TestAvgExec(t *testing.T) { - // Simple passing test -} diff --git a/src/Backend/opti-sql-go/operators/aggr/basicAggr.go b/src/Backend/opti-sql-go/operators/aggr/basicAggr.go deleted file mode 100644 index 0ffa1f3..0000000 --- a/src/Backend/opti-sql-go/operators/aggr/basicAggr.go +++ /dev/null @@ -1,5 +0,0 @@ -package aggr - -// Min -//Max -//Count diff --git a/src/Backend/opti-sql-go/operators/aggr/basicAggr_test.go b/src/Backend/opti-sql-go/operators/aggr/basicAggr_test.go deleted file mode 100644 index 7a59206..0000000 --- a/src/Backend/opti-sql-go/operators/aggr/basicAggr_test.go +++ /dev/null @@ -1,7 +0,0 @@ -package aggr - -import "testing" - -func TestBasicAggr(t *testing.T) { - // Simple passing test -} diff --git a/src/Backend/opti-sql-go/operators/aggr/groupBy.go b/src/Backend/opti-sql-go/operators/aggr/groupBy.go index abd1ad5..962a450 100644 --- a/src/Backend/opti-sql-go/operators/aggr/groupBy.go +++ b/src/Backend/opti-sql-go/operators/aggr/groupBy.go @@ -1 +1,442 @@ package aggr + +import ( + "errors" + "fmt" + "io" + "opti-sql-go/Expr" + "opti-sql-go/operators" + "strings" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" +) + +/* +rules for group by: +1.Every non-aggregated column in SELECT must be in GROUP BY +2.You can group by multiple columns - creates groups for each unique combination +3.Use HAVING to filter groups (WHERE filters before grouping, HAVING filters after) +*/ +var ( + _ = (operators.Operator)(&GroupByExec{}) +) + +// place all unique elements of the group by column into a hash table, each element gets their own Accumulator instance +type GroupByExec struct { + input operators.Operator + schema *arrow.Schema + groupExpr []AggregateFunctions + groupByExpr []Expr.Expression // column names + + groups map[string][]accumulator // maps group by key to its accumulator + keys map[string][]string // key → original values for output + done bool +} + +func NewGroupByExec(child operators.Operator, groupExpr []AggregateFunctions, groupBy []Expr.Expression) (*GroupByExec, error) { + s, err := buildGroupBySchema(child.Schema(), groupBy, groupExpr) + if err != nil { + return nil, err + } + + return &GroupByExec{ + input: child, + schema: s, + groupExpr: groupExpr, + groupByExpr: groupBy, + keys: make(map[string][]string), + groups: make(map[string][]accumulator), + }, nil +} + +/* +grab child rows +*/ +func (g *GroupByExec) Next(batchSize uint16) (*operators.RecordBatch, error) { + if g.done { + return nil, io.EOF + } + + for { + childBatch, err := g.input.Next(batchSize) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return nil, err + } + + rowCount := int(childBatch.RowCount) + + // 1. evaluate all group-by expressions into arrays + groupArrays := make([]arrow.Array, len(g.groupByExpr)) + for i, expr := range g.groupByExpr { + arr, err := Expr.EvalExpression(expr, childBatch) + if err != nil { + operators.ReleaseArrays(groupArrays) + operators.ReleaseArrays(childBatch.Columns) + return nil, err + } + groupArrays[i] = arr + } + + // 2. evaluate all aggregation child expressions + aggrArrays := make([]arrow.Array, len(g.groupExpr)) + for i, agg := range g.groupExpr { + arr, err := Expr.EvalExpression(agg.Child, childBatch) + if err != nil { + operators.ReleaseArrays(aggrArrays) + operators.ReleaseArrays(groupArrays) + operators.ReleaseArrays(childBatch.Columns) + return nil, err + } + arr, err = castArrayToFloat64(arr) + if err != nil { + operators.ReleaseArrays(aggrArrays) + operators.ReleaseArrays(groupArrays) + operators.ReleaseArrays(childBatch.Columns) + return nil, err + } + aggrArrays[i] = arr + } + + // 3. process rows + for row := 0; row < rowCount; row++ { + + // Build group key + keyParts := make([]string, len(groupArrays)) + for j, arr := range groupArrays { + if arr.IsNull(row) { + keyParts[j] = "NULL" + } else { + keyParts[j] = fmt.Sprintf("%v", getValue(arr, row)) + } + } + key := strings.Join(keyParts, "|") + // Allocate accumulator list if new group + if _, exists := g.groups[key]; !exists { + g.groups[key] = make([]accumulator, len(g.groupExpr)) + for i, agg := range g.groupExpr { + g.groups[key][i] = createAccumulator(agg.AggrFunc) + } + g.keys[key] = keyParts // store original values + } + + // UPDATE accumulators + for i, arr := range aggrArrays { + if arr.IsNull(row) { + continue + } + val := arr.(*array.Float64).Value(row) + g.groups[key][i].Update(val) + } + } + // 4. release temp arrays + operators.ReleaseArrays(aggrArrays) + operators.ReleaseArrays(groupArrays) + operators.ReleaseArrays(childBatch.Columns) + } + + // 4. Build output RecordBatch + batch := buildGroupByOutput(g) + + g.done = true + return batch, nil +} + +func (g *GroupByExec) Schema() *arrow.Schema { + return g.schema +} +func (g *GroupByExec) Close() error { + return g.input.Close() +} + +// handles validation and building of schema for group by +func buildGroupBySchema(childSchema *arrow.Schema, groupByExpr []Expr.Expression, aggrExprs []AggregateFunctions) (*arrow.Schema, error) { + + fields := make([]arrow.Field, 0, len(groupByExpr)+len(aggrExprs)) + + // 1. Add group-by columns + for _, expr := range groupByExpr { + dt, err := Expr.ExprDataType(expr, childSchema) + if err != nil { + return nil, fmt.Errorf("group-by expr %s has invalid type: %w", expr.String(), err) + } + + fields = append(fields, arrow.Field{ + Name: fmt.Sprintf("group_%s", expr.String()), + Type: dt, + Nullable: true, + }) + } + + // 2. Add aggregate columns + for _, agg := range aggrExprs { + dt, err := Expr.ExprDataType(agg.Child, childSchema) + if err != nil || !validAggrType(dt) { + return nil, ErrInvalidAggrColumnType(dt) + } + // All aggregates produce float64 + fieldName := fmt.Sprintf("%s_%s", + strings.ToLower(aggrToString(int(agg.AggrFunc))), + agg.Child.String(), + ) + + fields = append(fields, arrow.Field{ + Name: fieldName, + Type: arrow.PrimitiveTypes.Float64, + Nullable: false, + }) + } + + return arrow.NewSchema(fields, nil), nil +} + +func getValue(arr arrow.Array, row int) any { + switch col := arr.(type) { + case *array.Int32: + return col.Value(row) + case *array.Int64: + return col.Value(row) + case *array.Float32: + return col.Value(row) + case *array.Float64: + return col.Value(row) + case *array.String: + return col.Value(row) + case *array.Boolean: + return col.Value(row) + default: + // fallback – debug only + return fmt.Sprintf("%v", col) + } +} +func createAccumulator(fn AggrFunc) accumulator { + switch fn { + case Min: + return newMinAggr() + case Max: + return newMaxAggr() + case Sum: + return newSumAggr() + case Count: + return newCountAggr() + case Avg: + return newAvgAggr() + default: + panic(fmt.Sprintf("unsupported aggregate function: %v", fn)) + } +} + +func buildGroupByOutput(g *GroupByExec) *operators.RecordBatch { + alloc := memory.NewGoAllocator() + + rowCount := len(g.groups) + if rowCount == 0 { + return &operators.RecordBatch{ + Schema: g.schema, + Columns: []arrow.Array{}, + RowCount: 0, + } + } + + // Prepare column builders + colBuilders := make([]arrow.Array, len(g.schema.Fields())) + + // Temporary storage for columns + groupCols := make([][]any, len(g.groupByExpr)) // group columns + aggrCols := make([][]float64, len(g.groupExpr)) // aggregate columns + + for i := range groupCols { + groupCols[i] = make([]any, 0, rowCount) + } + for i := range aggrCols { + aggrCols[i] = make([]float64, 0, rowCount) + } + + for key, accs := range g.groups { + // Add group-by (dimension) values + dims := g.keys[key] + for j, v := range dims { + groupCols[j] = append(groupCols[j], v) + } + + // Add aggregated values + for j, acc := range accs { + aggrCols[j] = append(aggrCols[j], acc.Finalize()) + } + + } + + // Now build Arrow arrays in correct schema order + fieldIndex := 0 + + // Build group-by columns first + for j := range g.groupByExpr { + colBuilders[fieldIndex] = buildDynamicArray(alloc, g.schema.Field(fieldIndex).Type, groupCols[j]) + fieldIndex++ + } + + // Build aggregate columns + for j := range g.groupExpr { + colBuilders[fieldIndex] = buildFloatArray(alloc, aggrCols[j]) + fieldIndex++ + } + + return &operators.RecordBatch{ + Schema: g.schema, + Columns: colBuilders, + RowCount: uint64(rowCount), + } +} +func buildDynamicArray(mem memory.Allocator, dt arrow.DataType, values []any) arrow.Array { + switch dt.ID() { + + // =========================== + // STRING (UTF8) + // =========================== + case arrow.STRING: + sb := array.NewStringBuilder(mem) + for _, v := range values { + if v == nil { + sb.AppendNull() + } else { + sb.Append(fmt.Sprintf("%v", v)) + } + } + return sb.NewArray() + + // =========================== + // SIGNED INTEGERS + // =========================== + case arrow.INT8: + b := array.NewInt8Builder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(v.(int8)) + } + } + return b.NewArray() + + case arrow.INT16: + b := array.NewInt16Builder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(v.(int16)) + } + } + return b.NewArray() + + case arrow.INT32: + b := array.NewInt32Builder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(v.(int32)) + } + } + return b.NewArray() + + case arrow.INT64: + b := array.NewInt64Builder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(v.(int64)) + } + } + return b.NewArray() + + // =========================== + // UNSIGNED INTEGERS + // =========================== + case arrow.UINT8: + b := array.NewUint8Builder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(v.(uint8)) + } + } + return b.NewArray() + + case arrow.UINT16: + b := array.NewUint16Builder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(v.(uint16)) + } + } + return b.NewArray() + + case arrow.UINT32: + b := array.NewUint32Builder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(v.(uint32)) + } + } + return b.NewArray() + + case arrow.UINT64: + b := array.NewUint64Builder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(v.(uint64)) + } + } + return b.NewArray() + + // =========================== + // FLOATS + // =========================== + case arrow.FLOAT32: + b := array.NewFloat32Builder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(v.(float32)) + } + } + return b.NewArray() + + case arrow.FLOAT64: + b := array.NewFloat64Builder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(v.(float64)) + } + } + return b.NewArray() + + // =========================== + // UNSUPPORTED TYPE + // =========================== + default: + panic(fmt.Sprintf("unsupported dynamic array type: %v", dt)) + } +} + +func buildFloatArray(mem memory.Allocator, values []float64) arrow.Array { + b := array.NewFloat64Builder(mem) + b.AppendValues(values, nil) + return b.NewArray() +} diff --git a/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go b/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go index 3313b3e..10756f0 100644 --- a/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go +++ b/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go @@ -1,7 +1,680 @@ package aggr -import "testing" +import ( + "errors" + "fmt" + "io" + "opti-sql-go/Expr" + "opti-sql-go/operators/project" + "strings" + "testing" -func TestGroupBy(t *testing.T) { - // Simple passing test + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" +) + +func generateGroupByTestColumns() ([]string, []any) { + names := []string{ + "id", + "name", + "department", + "region", + "seniority", + "salary", + "age", + } + + // 40 IDs + ids := make([]int32, 40) + for i := range ids { + ids[i] = int32(i + 1) + } + + // Names – 40 names + namesArr := []string{ + "Alice", "Bob", "Charlie", "David", "Eve", + "Frank", "Grace", "Hannah", "Ivy", "Jake", + "Karen", "Leo", "Mona", "Nate", "Olive", + "Paul", "Quinn", "Rita", "Sam", "Tina", + "Uma", "Victor", "Wendy", "Xavier", "Yara", + "Zane", "Becky", "Carlos", "Dora", "Elias", + "Fiona", "Gabe", "Helena", "Isaac", "Julia", + "Kevin", "Lara", "Miles", "Nora", "Owen", + } + + // Randomized but balanced departments (5 groups) + departments := []string{ + "Engineering", "HR", "Sales", "Engineering", "Finance", + "Support", "Sales", "Engineering", "Support", "Finance", + "HR", "Engineering", "Sales", "Support", "Finance", + "Engineering", "Sales", "HR", "Support", "Engineering", + "Finance", "Sales", "Engineering", "Support", "HR", + "Support", "Engineering", "Finance", "Sales", "HR", + "Engineering", "Support", "Finance", "Sales", "Engineering", + "HR", "Finance", "Support", "Engineering", "Sales", + } + + // Randomized but balanced regions (4 groups) + regions := []string{ + "North", "East", "South", "West", "South", + "North", "West", "East", "North", "South", + "West", "East", "North", "South", "West", + "North", "East", "West", "South", "North", + "East", "West", "South", "North", "East", + "South", "North", "West", "East", "South", + "West", "North", "East", "South", "West", + "North", "South", "East", "West", "North", + } + + // Randomized seniority (3 groups) + seniority := []string{ + "Junior", "Senior", "Mid", "Junior", "Mid", + "Senior", "Junior", "Mid", "Senior", "Junior", + "Mid", "Senior", "Junior", "Mid", "Senior", + "Junior", "Mid", "Senior", "Junior", "Mid", + "Senior", "Junior", "Mid", "Senior", "Junior", + "Mid", "Senior", "Junior", "Mid", "Senior", + "Junior", "Mid", "Senior", "Junior", "Mid", + "Senior", "Junior", "Mid", "Senior", "Junior", + } + + // Salaries (same as before) + salaries := []float64{ + 70000, 82000, 54000, 91000, 60000, + 75000, 66000, 88000, 45000, 99000, + 72000, 81000, 53000, 86000, 64000, + 93000, 68000, 76000, 89000, 71000, + 83000, 94000, 55000, 87000, 91500, + 72000, 69000, 58000, 84000, 79000, + 81000, 78000, 62000, 97000, 82000, + 95000, 76000, 88000, 91000, 64000, + } + + // Ages with some repetition + ages := []int32{ + 28, 34, 45, 22, 31, + 29, 40, 36, 50, 26, + 33, 41, 27, 38, 24, + 46, 30, 35, 43, 32, + 39, 48, 29, 37, 42, + 28, 34, 45, 22, 31, + 29, 40, 36, 50, 26, + 39, 48, 29, 37, 42, + } + + columns := []any{ + ids, + namesArr, + departments, + regions, + seniority, + salaries, + ages, + } + + return names, columns +} + +func groupByProject() *project.InMemorySource { + names, cols := generateGroupByTestColumns() + p, _ := project.NewInMemoryProjectExec(names, cols) + return p +} + +func TestGroupByInit(t *testing.T) { + p := groupByProject() + _, _ = p.Next(12) +} + +func TestNewGroupByExecAndSchema(t *testing.T) { + // convenience builder + col := func(name string) Expr.Expression { + return Expr.NewColumnResolve(name) + } + + t.Run("single group-by single aggregate", func(t *testing.T) { + child := groupByProject() + + groupBy := []Expr.Expression{col("department")} + aggs := []AggregateFunctions{ + {AggrFunc: Sum, Child: col("salary")}, + } + + gb, err := NewGroupByExec(child, aggs, groupBy) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + schema := gb.Schema() + if schema == nil { + t.Fatalf("schema should not be nil") + } + fmt.Println(schema) + + // group-by + 1 agg = 2 fields + if got, want := schema.NumFields(), 2; got != want { + t.Fatalf("expected %d fields, got %d", want, got) + } + + // group field + f0 := schema.Field(0) + expName := "group_" + groupBy[0].String() + if f0.Name != expName { + t.Fatalf("expected group field name %q, got %q", expName, f0.Name) + } + + // aggregate field + f1 := schema.Field(1) + properAggName := fmt.Sprintf("%s_%s", + strings.ToLower(aggrToString(int(aggs[0].AggrFunc))), + aggs[0].Child.String(), + ) + if f1.Name != properAggName { + t.Fatalf("expected agg field %q, got %q", properAggName, f1.Name) + } + + if gb.groups == nil { + t.Fatalf("groups map not initialized") + } + if gb.keys == nil { + t.Fatalf("keys map not initialized") + } + }) + + t.Run("multiple group-by and multiple aggregates", func(t *testing.T) { + child := groupByProject() + + groupBy := []Expr.Expression{col("region"), col("seniority")} + aggs := []AggregateFunctions{ + {AggrFunc: Min, Child: col("age")}, + {AggrFunc: Max, Child: col("salary")}, + {AggrFunc: Count, Child: col("id")}, + } + + gb, err := NewGroupByExec(child, aggs, groupBy) + if err != nil { + t.Fatalf("unexpected: %v", err) + } + + schema := gb.Schema() + fmt.Printf("schema: %v\n", schema) + wantFields := len(groupBy) + len(aggs) + if schema.NumFields() != wantFields { + t.Fatalf("expected %d fields, got %d", wantFields, schema.NumFields()) + } + + // group fields first + for i, gexpr := range groupBy { + f := schema.Field(i) + exp := "group_" + gexpr.String() + if f.Name != exp { + t.Fatalf("group field[%d] mismatch: want %q got %q", i, exp, f.Name) + } + } + + // aggregate fields next + offset := len(groupBy) + for j, agg := range aggs { + f := schema.Field(offset + j) + expAggName := fmt.Sprintf("%s_%s", + strings.ToLower(aggrToString(int(agg.AggrFunc))), + agg.Child.String(), + ) + if f.Name != expAggName { + t.Fatalf("agg field name mismatch: want %q got %q", expAggName, f.Name) + } + } + }) + + t.Run("invalid group-by column triggers error", func(t *testing.T) { + child := groupByProject() + + invalidGB := []Expr.Expression{col("not_a_col")} + aggs := []AggregateFunctions{ + {AggrFunc: Sum, Child: col("salary")}, + } + + // direct schema builder test + _, err := buildGroupBySchema(child.Schema(), invalidGB, aggs) + if err == nil { + t.Fatalf("expected error for invalid group-by expr") + } + + // NewGroupByExec should also fail + if _, err := NewGroupByExec(child, aggs, invalidGB); err == nil { + t.Fatalf("expected NewGroupByExec error for invalid group-by") + } + }) + + t.Run("no aggregates - schema should only contain group-by columns", func(t *testing.T) { + child := groupByProject() + + groupBy := []Expr.Expression{col("region")} + var aggs []AggregateFunctions + + gb, err := NewGroupByExec(child, aggs, groupBy) + if err != nil { + t.Fatalf("unexpected: %v", err) + } + + schema := gb.Schema() + + if schema.NumFields() != 1 { + t.Fatalf("expected 1 field, got %d", schema.NumFields()) + } + + f := schema.Field(0) + exp := "group_" + groupBy[0].String() + if f.Name != exp { + t.Fatalf("wrong group field name: want %q got %q", exp, f.Name) + } + }) + + t.Run("multiple aggregates produce float64 regardless of source type", func(t *testing.T) { + child := groupByProject() + + groupBy := []Expr.Expression{col("department")} + aggs := []AggregateFunctions{ + {AggrFunc: Avg, Child: col("age")}, // int32 → float64 + {AggrFunc: Sum, Child: col("salary")}, // float64 → float64 + } + + gb, err := NewGroupByExec(child, aggs, groupBy) + if err != nil { + t.Fatalf("unexpected: %v", err) + } + + schema := gb.Schema() + + // group-by + 2 aggregates = 3 + if schema.NumFields() != 3 { + t.Fatalf("expected 3 fields, got %d", schema.NumFields()) + } + + for idx := 1; idx < 3; idx++ { + f := schema.Field(idx) + if f.Type.ID() != arrow.FLOAT64 { + t.Fatalf("expected field[%d] to be float64, got %v", idx, f.Type) + } + } + }) + + t.Run("schema names must match exact string() output of expressions", func(t *testing.T) { + child := groupByProject() + + gbExpr := []Expr.Expression{ + Expr.NewColumnResolve("seniority"), + Expr.NewColumnResolve("region"), + } + aggs := []AggregateFunctions{ + {AggrFunc: Count, Child: Expr.NewColumnResolve("id")}, + } + + gb, err := NewGroupByExec(child, aggs, gbExpr) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + + schema := gb.Schema() + + expected0 := "group_" + gbExpr[0].String() // group_Column(seniority) + expected1 := "group_" + gbExpr[1].String() // group_Column(region) + + if schema.Field(0).Name != expected0 { + t.Fatalf("wrong field[0] name: want %q got %q", expected0, schema.Field(0).Name) + } + if schema.Field(1).Name != expected1 { + t.Fatalf("wrong field[1] name: want %q got %q", expected1, schema.Field(1).Name) + } + + // count column + expectedAgg := "count_" + aggs[0].Child.String() + if schema.Field(2).Name != expectedAgg { + t.Fatalf("wrong agg field name: want %q got %q", expectedAgg, schema.Field(2).Name) + } + }) + t.Run("basic close check", func(t *testing.T) { + child := groupByProject() + + gbExpr := []Expr.Expression{ + Expr.NewColumnResolve("seniority"), + Expr.NewColumnResolve("region"), + } + aggs := []AggregateFunctions{ + {AggrFunc: Count, Child: Expr.NewColumnResolve("id")}, + } + + gb, err := NewGroupByExec(child, aggs, gbExpr) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if gb.Close() != nil { + t.Fatalf("unexpected error on close") + } + + }) +} +func TestBasicOperatorCasesGroupBy(t *testing.T) { + + t.Run("basic close check", func(t *testing.T) { + child := groupByProject() + + gbExpr := []Expr.Expression{ + Expr.NewColumnResolve("seniority"), + Expr.NewColumnResolve("region"), + } + aggs := []AggregateFunctions{ + {AggrFunc: Count, Child: Expr.NewColumnResolve("id")}, + } + + gb, err := NewGroupByExec(child, aggs, gbExpr) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if gb.Close() != nil { + t.Fatalf("unexpected error on close") + } + + }) + t.Run("done case", func(t *testing.T) { + child := groupByProject() + + gbExpr := []Expr.Expression{ + Expr.NewColumnResolve("seniority"), + Expr.NewColumnResolve("region"), + } + aggs := []AggregateFunctions{ + {AggrFunc: Count, Child: Expr.NewColumnResolve("id")}, + } + + gb, err := NewGroupByExec(child, aggs, gbExpr) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + gb.done = true + _, err = gb.Next(100) + if err == nil || !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF but received %v", err) + } + + }) +} +func TestGroupByNext_SingleColumnCount(t *testing.T) { + col := func(n string) Expr.Expression { return Expr.NewColumnResolve(n) } + + child := groupByProject() + + gbExpr := []Expr.Expression{col("region")} + aggs := []AggregateFunctions{ + {AggrFunc: Count, Child: col("id")}, + } + + gb, err := NewGroupByExec(child, aggs, gbExpr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + batch, _ := gb.Next(1000) + + if batch == nil || batch.RowCount == 0 { + t.Fatalf("expected non-empty grouped result") + } + + // Validate schema + if batch.Schema.NumFields() != 2 { + t.Fatalf("expected 2 fields, got %d", batch.Schema.NumFields()) + } + + // Validate that group keys exist and aggregates exist + if batch.Columns[0].Len() == 0 { + t.Fatalf("expected region groups") + } + + if batch.Columns[1].Len() == 0 { + t.Fatalf("expected aggregated counts") + } +} + +func TestGroupByNext_MultipleGroupBy_MultipleAggs(t *testing.T) { + col := func(n string) Expr.Expression { return Expr.NewColumnResolve(n) } + + child := groupByProject() + + gbExpr := []Expr.Expression{ + col("seniority"), + col("region"), + } + + aggs := []AggregateFunctions{ + {AggrFunc: Min, Child: col("age")}, + {AggrFunc: Max, Child: col("salary")}, + {AggrFunc: Count, Child: col("id")}, + } + + gb, err := NewGroupByExec(child, aggs, gbExpr) + if err != nil { + t.Fatal(err) + } + + batch, _ := gb.Next(50) + + if batch.RowCount == 0 { + t.Fatalf("expected non-zero grouped rows") + } + + if batch.Schema.NumFields() != 5 { + t.Fatalf("expected 5 fields (2 group-by + 3 aggr), got %d", batch.Schema.NumFields()) + } +} + +func TestGroupByNext_MultipleNextCalls(t *testing.T) { + col := func(n string) Expr.Expression { return Expr.NewColumnResolve(n) } + + child := groupByProject() + + gbExpr := []Expr.Expression{col("region")} + aggs := []AggregateFunctions{ + {AggrFunc: Sum, Child: col("salary")}, + } + + gb, err := NewGroupByExec(child, aggs, gbExpr) + if err != nil { + t.Fatal(err) + } + + // First call returns batch + EOF + _, _ = gb.Next(100) + _, err = gb.Next(100) + if !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF on second return, got %v", err) + } + +} + +func TestBuildGroupBySchema_AllBranches(t *testing.T) { + col := func(n string) Expr.Expression { return Expr.NewColumnResolve(n) } + + child := groupByProject() + + groupBy := []Expr.Expression{col("region"), col("seniority")} + aggs := []AggregateFunctions{ + {AggrFunc: Sum, Child: col("salary")}, + {AggrFunc: Count, Child: col("id")}, + } + + schema, err := buildGroupBySchema(child.Schema(), groupBy, aggs) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if schema.NumFields() != 4 { + t.Fatalf("expected 4 fields got %d", schema.NumFields()) + } + + // test group-by fields + if schema.Field(0).Type.ID() != arrow.STRING { + t.Fatalf("expected STRING for region") + } + + // aggregated fields always float64 + if schema.Field(2).Type.ID() != arrow.FLOAT64 { + t.Fatalf("expected FLOAT64 for aggregate field") + } +} + +func TestBuildGroupBySchema_InvalidColumn(t *testing.T) { + col := func(n string) Expr.Expression { return Expr.NewColumnResolve(n) } + child := groupByProject() + + _, err := buildGroupBySchema(child.Schema(), []Expr.Expression{col("doesnotexist")}, nil) + if err == nil { + t.Fatalf("expected error but got none") + } +} + +func TestBuildGroupBySchema_InvalidAggType(t *testing.T) { + col := func(n string) Expr.Expression { return Expr.NewColumnResolve(n) } + child := groupByProject() + + aggs := []AggregateFunctions{ + // Boolean type or unsupported type + {AggrFunc: Sum, Child: col("name")}, + } + + _, err := buildGroupBySchema(child.Schema(), nil, aggs) + if err == nil { + t.Fatalf("expected invalid agg type error") + } +} +func TestGetValue_AllTypes(t *testing.T) { + mem := memory.NewGoAllocator() + + // int32 + i32 := array.NewInt32Builder(mem) + i32.Append(42) + arr32 := i32.NewArray() + if getValue(arr32, 0).(int32) != 42 { + t.Fatal("failed int32 case") + } + + // int64 + i64 := array.NewInt64Builder(mem) + i64.Append(99) + arr64 := i64.NewArray() + if getValue(arr64, 0).(int64) != 99 { + t.Fatal("failed int64 case") + } + + // float32 + f32 := array.NewFloat32Builder(mem) + f32.Append(3.5) + arrf32 := f32.NewArray() + if getValue(arrf32, 0).(float32) != 3.5 { + t.Fatal("failed float32 case") + } + + // float64 + f64 := array.NewFloat64Builder(mem) + f64.Append(9.1) + arrf64 := f64.NewArray() + if getValue(arrf64, 0).(float64) != 9.1 { + t.Fatal("failed float64 case") + } + + // string + sb := array.NewStringBuilder(mem) + sb.Append("hello") + sarr := sb.NewArray() + if getValue(sarr, 0).(string) != "hello" { + t.Fatal("failed string case") + } + + // boolean + bb := array.NewBooleanBuilder(mem) + bb.Append(true) + barr := bb.NewArray() + if getValue(barr, 0).(bool) != true { + t.Fatal("failed boolean case") + } +} + +func TestBuildDynamicArray_AllPrimitiveTypes(t *testing.T) { + mem := memory.NewGoAllocator() + + tests := []struct { + dt arrow.DataType + val []any + }{ + {arrow.PrimitiveTypes.Int8, []any{int8(1), nil, int8(3)}}, + {arrow.PrimitiveTypes.Int16, []any{int16(2), int16(5)}}, + {arrow.PrimitiveTypes.Int32, []any{int32(10), nil, int32(12)}}, + {arrow.PrimitiveTypes.Int64, []any{int64(20), int64(40)}}, + + {arrow.PrimitiveTypes.Uint8, []any{uint8(7), nil}}, + {arrow.PrimitiveTypes.Uint16, []any{uint16(100)}}, + {arrow.PrimitiveTypes.Uint32, []any{uint32(2000)}}, + {arrow.PrimitiveTypes.Uint64, []any{uint64(99999)}}, + + {arrow.PrimitiveTypes.Float32, []any{float32(2.2), nil}}, + {arrow.PrimitiveTypes.Float64, []any{float64(9.9)}}, + + {arrow.BinaryTypes.String, []any{"a", "b", nil}}, + } + + for _, tc := range tests { + arr := buildDynamicArray(mem, tc.dt, tc.val) + if arr.Len() != len(tc.val) { + t.Fatalf("wrong length for type %v", tc.dt) + } + } +} + +func TestCreateAccumulator_AllCases(t *testing.T) { + funcs := []AggrFunc{Min, Max, Sum, Count, Avg} + + for _, fn := range funcs { + acc := createAccumulator(fn) + if acc == nil { + t.Fatalf("expected accumulator for fn=%v", fn) + } + } +} + +func TestCreateAccumulator_PanicOnInvalid(t *testing.T) { + defer func() { + if recover() == nil { + t.Fatalf("expected panic for invalid function") + } + }() + + createAccumulator(AggrFunc(9999)) // invalid +} + +func TestBuildGroupByOutput_Basic(t *testing.T) { + col := func(n string) Expr.Expression { return Expr.NewColumnResolve(n) } + child := groupByProject() + + gbExpr := []Expr.Expression{col("region")} + aggs := []AggregateFunctions{ + {AggrFunc: Count, Child: col("id")}, + } + + gb, err := NewGroupByExec(child, aggs, gbExpr) + if err != nil { + t.Fatal(err) + } + + // invoke Next (fills accumulators) + _, _ = gb.Next(100) + + batch := buildGroupByOutput(gb) + + if batch.RowCount == 0 { + t.Fatalf("expected grouped rows") + } + + if len(batch.Columns) != 2 { + t.Fatalf("expected 2 columns, got %d", len(batch.Columns)) + } } diff --git a/src/Backend/opti-sql-go/operators/aggr/having.go b/src/Backend/opti-sql-go/operators/aggr/having.go new file mode 100644 index 0000000..a2a559f --- /dev/null +++ b/src/Backend/opti-sql-go/operators/aggr/having.go @@ -0,0 +1,77 @@ +package aggr + +import ( + "errors" + "io" + "opti-sql-go/Expr" + "opti-sql-go/operators" + "opti-sql-go/operators/filter" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" +) + +// carbon copy of filter.go with minor changes to fit having semantics +var ( + _ = (operators.Operator)(&HavingExec{}) +) + +type HavingExec struct { + input operators.Operator + schema *arrow.Schema + + havingExpr Expr.Expression + done bool +} + +func NewHavingExec(input operators.Operator, havingFilter Expr.Expression) (*HavingExec, error) { + + return &HavingExec{ + input: input, + schema: input.Schema(), + havingExpr: havingFilter, + }, nil +} + +func (h *HavingExec) Next(n uint16) (*operators.RecordBatch, error) { + if h.done { + return nil, io.EOF + } + childBatch, err := h.input.Next(n) + if err != nil { + if errors.Is(err, io.EOF) { + h.done = true + } + return nil, err + } + booleanMask, err := Expr.EvalExpression(h.havingExpr, childBatch) + if err != nil { + return nil, err + } + boolArr, ok := booleanMask.(*array.Boolean) // impossible for this to not be a boolean array,assuming validPredicates works as it should + if !ok { + return nil, errors.New("having predicate did not evaluate to boolean array") + } + filteredCol := make([]arrow.Array, len(childBatch.Columns)) + for i, col := range childBatch.Columns { + filteredCol[i], err = filter.ApplyBooleanMask(col, boolArr) + if err != nil { + return nil, err + } + } + // release old columns + operators.ReleaseArrays(childBatch.Columns) + size := uint64(filteredCol[0].Len()) + + return &operators.RecordBatch{ + Schema: childBatch.Schema, + Columns: filteredCol, + RowCount: size, + }, nil +} +func (h *HavingExec) Schema() *arrow.Schema { + return h.schema +} +func (h *HavingExec) Close() error { + return h.input.Close() +} diff --git a/src/Backend/opti-sql-go/operators/aggr/having_test.go b/src/Backend/opti-sql-go/operators/aggr/having_test.go new file mode 100644 index 0000000..9321639 --- /dev/null +++ b/src/Backend/opti-sql-go/operators/aggr/having_test.go @@ -0,0 +1,213 @@ +package aggr + +import ( + "errors" + "io" + "strings" + "testing" + + "opti-sql-go/Expr" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" +) + +func TestHavingExec_OnGroupBy(t *testing.T) { + + // ============================================================= + // 1) HAVING SUM(salary) > 600000 + // ============================================================= + t.Run("having_sum_salary_gt_600k", func(t *testing.T) { + + child := groupByProject() + + groupBy := []Expr.Expression{col("department")} + aggs := []AggregateFunctions{ + {AggrFunc: Sum, Child: col("salary")}, + } + + gb, err := NewGroupByExec(child, aggs, groupBy) + if err != nil { + t.Fatalf("unexpected GroupBy error: %v", err) + } + + sumCol := "sum_Column(salary)" + + // SUM(salary) > 600000 + havingExpr := Expr.NewBinaryExpr( + Expr.NewColumnResolve(sumCol), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, float64(600000)), + ) + + having, err := NewHavingExec(gb, havingExpr) + if err != nil { + t.Fatalf("unexpected HavingExec init error: %v", err) + } + + batch, err := having.Next(1024) + if err != nil { + t.Fatalf("unexpected error running Next: %v", err) + } + t.Logf("batch : %v\n", batch.PrettyPrint()) + sumValues := batch.Columns[1].(*array.Float64) + for i := 0; i < sumValues.Len(); i++ { + if sumValues.Value(i) <= 600000 { + t.Fatalf("expected sum(salary) > 600000, got %f", sumValues.Value(i)) + } + } + + }) + + // ============================================================= + // 2) HAVING COUNT(id) >= 10 + // ============================================================= + t.Run("having_count_id_ge_10", func(t *testing.T) { + + child := groupByProject() + + groupBy := []Expr.Expression{col("region")} + aggs := []AggregateFunctions{ + {AggrFunc: Count, Child: col("id")}, + } + + gb, err := NewGroupByExec(child, aggs, groupBy) + if err != nil { + t.Fatalf("unexpected GroupBy err: %v", err) + } + + countCol := "count_Column(id)" + + havingExpr := Expr.NewBinaryExpr( + Expr.NewColumnResolve(countCol), + Expr.GreaterThanOrEqual, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, float64(10)), + ) + + having, err := NewHavingExec(gb, havingExpr) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + + batch, err := having.Next(200) + if err != nil { + t.Fatalf("unexpected Next error: %v", err) + } + + if batch.RowCount != 3 { // North, South, West ≥ 10 + t.Fatalf("expected 3 regions with >=10 rows, got %d", batch.RowCount) + } + }) + + // ============================================================= + // 3) HAVING filters all groups out + // ============================================================= + t.Run("having_filters_all", func(t *testing.T) { + + child := groupByProject() + + groupBy := []Expr.Expression{col("department")} + aggs := []AggregateFunctions{ + {AggrFunc: Sum, Child: col("salary")}, + } + + gb, _ := NewGroupByExec(child, aggs, groupBy) + + sumCol := "sum_Column(salary)" + + // Impossible condition + havingExpr := Expr.NewBinaryExpr( + Expr.NewColumnResolve(sumCol), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, float64(1_000_000_000)), + ) + + having, _ := NewHavingExec(gb, havingExpr) + + batch, err := having.Next(1024) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + + if batch.RowCount != 0 { + t.Fatalf("expected all rows to be filtered out, got %d", batch.RowCount) + } + }) + + // ============================================================= + // 4) Non-boolean predicate → error + // ============================================================= + t.Run("having_non_boolean_predicate", func(t *testing.T) { + + child := groupByProject() + groupBy := []Expr.Expression{col("department")} + aggs := []AggregateFunctions{ + {AggrFunc: Sum, Child: col("salary")}, + } + + gb, _ := NewGroupByExec(child, aggs, groupBy) + + // invalid: resolves to float, not boolean + invalidExpr := Expr.NewColumnResolve("sum_Column(salary)") + + having, _ := NewHavingExec(gb, invalidExpr) + + _, err := having.Next(100) + if err == nil { + t.Fatalf("expected non-boolean error, got nil") + } + if !strings.Contains(err.Error(), "boolean") { + t.Fatalf("expected boolean error, got: %v", err) + } + }) + + // ============================================================= + // 5) done = true returns EOF + // ============================================================= + t.Run("done_returns_eof", func(t *testing.T) { + + child := groupByProject() + + groupBy := []Expr.Expression{col("region")} + aggs := []AggregateFunctions{ + {AggrFunc: Count, Child: col("id")}, + } + + gb, _ := NewGroupByExec(child, aggs, groupBy) + + countCol := "count_Column(id)" + + havingExpr := Expr.NewBinaryExpr( + Expr.NewColumnResolve(countCol), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, float64(0)), + ) + + h, _ := NewHavingExec(gb, havingExpr) + h.done = true + + _, err := h.Next(10) + if !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF, got: %v", err) + } + }) + + // ============================================================= + // 6) Close forwards to child.Close() + // ============================================================= + t.Run("close_propagates", func(t *testing.T) { + + child := groupByProject() + + gb, _ := NewGroupByExec(child, []AggregateFunctions{ + {AggrFunc: Count, Child: col("id")}, + }, []Expr.Expression{col("region")}) + + h, _ := NewHavingExec(gb, Expr.NewLiteralResolve(arrow.FixedWidthTypes.Boolean, true)) + + if err := h.Close(); err != nil { + t.Fatalf("Close returned error: %v", err) + } + t.Log(h.Schema()) + }) +} diff --git a/src/Backend/opti-sql-go/operators/aggr/singleAggr.go b/src/Backend/opti-sql-go/operators/aggr/singleAggr.go new file mode 100644 index 0000000..1fcccdd --- /dev/null +++ b/src/Backend/opti-sql-go/operators/aggr/singleAggr.go @@ -0,0 +1,291 @@ +package aggr + +import ( + "context" + "errors" + "fmt" + "io" + "opti-sql-go/Expr" + "opti-sql-go/operators" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/compute" +) + +var ( + ErrUnsupportedAggrFunc = func(aggr int) error { + return fmt.Errorf("%d is an unsupported aggregate function", aggr) + } + ErrInvalidAggrColumnType = func(value any) error { + return fmt.Errorf("%v of type %T cannot be cast to float64 so it is not a valid column type to aggregate on", value, value) + } +) + +// AggrFunc represents the type of aggregation function to be performed. +type AggrFunc int + +const ( + Min AggrFunc = iota + Max + Count + Sum + Avg +) + +var ( + _ = (accumulator)(&minAggrAccumulator{}) + _ = (accumulator)(&maxAggrAccumulator{}) + _ = (accumulator)(&countAggrAccumulator{}) + _ = (accumulator)(&sumAggrAccumulator{}) + _ = (accumulator)(&avgAggrAccumulator{}) + _ = (operators.Operator)(&AggrExec{}) +) + +func NewAggregateFunctions(aggrFunc AggrFunc, child Expr.Expression) AggregateFunctions { + return AggregateFunctions{ + AggrFunc: aggrFunc, + Child: child, + } +} + +type AggregateFunctions struct { + AggrFunc AggrFunc // switch to deal with separate aggregate functions + Child Expr.Expression // resolves to a column generally +} +type accumulator interface { + Update(value float64) + Finalize() float64 +} + +func newMinAggr() accumulator { + return &minAggrAccumulator{} +} + +type minAggrAccumulator struct { + minV float64 + firstValue bool +} + +func (m *minAggrAccumulator) Update(value float64) { + if !m.firstValue { + m.minV = value + m.firstValue = true + return + } + m.minV = min(m.minV, value) + +} +func (m *minAggrAccumulator) Finalize() float64 { return m.minV } +func newMaxAggr() accumulator { + return &maxAggrAccumulator{} +} + +type maxAggrAccumulator struct { + maxV float64 + firstValue bool +} + +func (m *maxAggrAccumulator) Update(value float64) { + if !m.firstValue { + m.maxV = value + m.firstValue = true + return + } + m.maxV = max(m.maxV, value) +} +func (m *maxAggrAccumulator) Finalize() float64 { return m.maxV } + +func newCountAggr() accumulator { + return &countAggrAccumulator{} +} + +type countAggrAccumulator struct { + count float64 +} + +func (c *countAggrAccumulator) Update(_ float64) { + c.count++ +} +func (c *countAggrAccumulator) Finalize() float64 { return c.count } + +func newSumAggr() accumulator { + return &sumAggrAccumulator{} +} + +type sumAggrAccumulator struct { + summation float64 +} + +func (s *sumAggrAccumulator) Update(value float64) { + s.summation += value +} +func (s *sumAggrAccumulator) Finalize() float64 { return s.summation } +func newAvgAggr() accumulator { + return &avgAggrAccumulator{} +} + +type avgAggrAccumulator struct { + used bool + values float64 + count float64 +} + +func (a *avgAggrAccumulator) Update(value float64) { + a.used = true + a.values += value + a.count++ +} +func (a *avgAggrAccumulator) Finalize() float64 { + // handles divide by zero + if !a.used { + return 0.0 + } + return a.values / a.count +} + +// =================== +// Aggregator Operator +// =================== +// handles global aggregations without group by +type AggrExec struct { + input operators.Operator // child operator + schema *arrow.Schema // output schema + aggExpressions []AggregateFunctions // list of wanted aggregate expressions + accumulators []accumulator // list of accumulators corresponding to aggExpressions, these will actually work to compute the aggregation + done bool // know when to return io.EOF +} + +func NewGlobalAggrExec(child operators.Operator, aggExprs []AggregateFunctions) (*AggrExec, error) { + accs := make([]accumulator, len(aggExprs)) + fields := make([]arrow.Field, len(aggExprs)) + for i, agg := range aggExprs { + dt, err := Expr.ExprDataType(agg.Child, child.Schema()) + if err != nil || !validAggrType(dt) { + return nil, ErrInvalidAggrColumnType(dt) + } + var fieldName string + switch agg.AggrFunc { + case Min: + fieldName = fmt.Sprintf("min_%s", agg.Child.String()) + accs[i] = newMinAggr() + case Max: + fieldName = fmt.Sprintf("max_%s", agg.Child.String()) + accs[i] = newMaxAggr() + case Count: + fieldName = fmt.Sprintf("count_%s", agg.Child.String()) + accs[i] = newCountAggr() + case Sum: + fieldName = fmt.Sprintf("sum_%s", agg.Child.String()) + accs[i] = newSumAggr() + case Avg: + fieldName = fmt.Sprintf("avg_%s", agg.Child.String()) + accs[i] = newAvgAggr() + + default: + return nil, ErrUnsupportedAggrFunc(int(agg.AggrFunc)) + } + fields[i] = arrow.Field{ + Name: fieldName, + Type: arrow.PrimitiveTypes.Float64, + Nullable: true, + } + } + return &AggrExec{ + input: child, + schema: arrow.NewSchema(fields, nil), + aggExpressions: aggExprs, + accumulators: accs, + }, nil +} + +// Next consumes all batches from the child operator, evaluates the aggregate expressions, +// updates the accumulators for each value, and returns a single output batch containing +// the final aggregation results. It returns io.EOF after producing the result batch. +func (a *AggrExec) Next(n uint16) (*operators.RecordBatch, error) { + if a.done { + return nil, io.EOF + } + for { + childBatch, err := a.input.Next(n) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return nil, err + } + for i, aggExpr := range a.aggExpressions { + agrArray, err := Expr.EvalExpression(aggExpr.Child, childBatch) + if err != nil { + return nil, err + } + agrArray, err = castArrayToFloat64(agrArray) + if err != nil { + return nil, err + } + valueArray := agrArray.(*array.Float64) + accumulator := a.accumulators[i] + for j := 0; j < valueArray.Len(); j++ { + if valueArray.IsNull(j) { + continue + } + accumulator.Update(valueArray.Value(j)) + } + + } + operators.ReleaseArrays(childBatch.Columns) + } + // build array with just the result of the column + resultColumns := make([]arrow.Array, len(a.accumulators)) + for i := range a.accumulators { + resultColumns[i] = operators.NewRecordBatchBuilder().GenFloatArray(a.accumulators[i].Finalize()) + } + a.done = true + return &operators.RecordBatch{ + Schema: a.schema, + Columns: resultColumns, + RowCount: 1, + }, nil +} + +func (a *AggrExec) Schema() *arrow.Schema { + return a.schema +} +func (a *AggrExec) Close() error { + return a.input.Close() +} + +func validAggrType(dt arrow.DataType) bool { + switch dt.ID() { + case arrow.UINT8, arrow.UINT16, arrow.UINT32, arrow.UINT64, + arrow.INT8, arrow.INT16, arrow.INT32, arrow.INT64, arrow.FLOAT16, arrow.FLOAT32, arrow.FLOAT64: + return true + default: + return false + } +} + +func castArrayToFloat64(arr arrow.Array) (arrow.Array, error) { + outDatum, err := compute.CastArray(context.TODO(), arr, compute.NewCastOptions(&arrow.Float64Type{}, true)) + if err != nil { + return nil, err + } + + return outDatum, nil +} +func aggrToString(t int) string { + switch AggrFunc(t) { + case Min: + return "MIN" + case Max: + return "MAX" + case Count: + return "COUNT" + case Sum: + return "SUM" + case Avg: + return "AVG" + default: + return "UNKNOWN_AGGREGATE_FUNCTION" + } +} diff --git a/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go b/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go new file mode 100644 index 0000000..9b5af24 --- /dev/null +++ b/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go @@ -0,0 +1,606 @@ +package aggr + +import ( + "errors" + "fmt" + "io" + "math" + "opti-sql-go/Expr" + "opti-sql-go/operators/project" + "testing" + + "github.com/apache/arrow/go/v15/arrow/memory" + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" +) + +func generateAggTestColumns() ([]string, []any) { + names := []string{ + "id", + "name", + "age", + "salary", + } + + columns := []any{ + // id: 1 to 25 + []int32{ + 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, + }, + + // name: 25 people + []string{ + "Alice", "Bob", "Charlie", "David", "Eve", + "Frank", "Grace", "Hannah", "Ivy", "Jake", + "Karen", "Leo", "Mona", "Nate", "Olive", + "Paul", "Quinn", "Rita", "Sam", "Tina", + "Uma", "Victor", "Wendy", "Xavier", "Yara", + }, + + // age: 25 numeric values + []int32{ + 28, 34, 45, 22, 31, + 29, 40, 36, 50, 26, + 33, 41, 27, 38, 24, + 46, 30, 35, 43, 32, + 39, 48, 29, 37, 42, + }, + + // salary: 25 numeric values + []float64{ + 70000.0, 82000.5, 54000.0, 91000.0, 60000.0, + 75000.0, 66000.0, 88000.0, 45000.0, 99000.0, + 72000.0, 81000.0, 53000.0, 86000.0, 64000.0, + 93000.0, 68000.0, 76000.0, 89000.0, 71000.0, + 83000.0, 94000.0, 55000.0, 87000.0, 91500.0, + }, + } + + return names, columns +} +func generateAggTestColumnsWithNulls(mem memory.Allocator) ([]string, []arrow.Array) { + names := []string{"id", "name", "age", "salary"} + + // ------------------------- + // id column (int32) + // ------------------------- + idB := array.NewInt32Builder(mem) + idVals := []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + idValid := []bool{ + true, true, false, true, true, + false, true, true, true, false, + } + idB.AppendValues(idVals, idValid) + idArr := idB.NewArray() + + // ------------------------- + // name column (string) + // ------------------------- + nameB := array.NewStringBuilder(mem) + nameVals := []string{ + "Alice", "Bob", "Charlie", "David", "Eve", + "Frank", "Grace", "Hannah", "Ivy", "Jake", + } + nameValid := []bool{ + true, true, true, false, true, + true, true, true, false, true, + } + nameB.AppendValues(nameVals, nameValid) + nameArr := nameB.NewArray() + + // ------------------------- + // age column (int32) + // ------------------------- + ageB := array.NewInt32Builder(mem) + ageVals := []int32{28, 34, 45, 22, 31, 29, 40, 36, 50, 26} + ageValid := []bool{ + true, false, true, true, true, + true, false, true, true, true, + } + ageB.AppendValues(ageVals, ageValid) + ageArr := ageB.NewArray() + + // ------------------------- + // salary column (float64) + // ------------------------- + salB := array.NewFloat64Builder(mem) + salVals := []float64{ + 70000, 82000, 54000, 91000, 60000, + 75000, 66000, 0, 45000, 99000, + } + + salaryValid := []bool{ + true, true, true, true, true, + true, true, false, true, true, + } + + salB.AppendValues(salVals, salaryValid) + salaryArr := salB.NewArray() + + return names, []arrow.Array{idArr, nameArr, ageArr, salaryArr} +} + +func aggProject() *project.InMemorySource { + names, cols := generateAggTestColumns() + p, _ := project.NewInMemoryProjectExec(names, cols) + return p +} + +func aggProjectNull() *project.InMemorySource { + names, arr := generateAggTestColumnsWithNulls(memory.NewGoAllocator()) + p, _ := project.NewInMemoryProjectExecFromArrays(names, arr) + return p +} + +func col(name string) Expr.Expression { + return Expr.NewColumnResolve(name) +} + +func TestNewAggrExec(t *testing.T) { + + // ----------------------------------------------------------------- + t.Run("valid_single_min", func(t *testing.T) { + child := aggProject() + + agg := []AggregateFunctions{ + {AggrFunc: Min, Child: col("age")}, + } + + exec, err := NewGlobalAggrExec(child, agg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if exec.Schema().NumFields() != 1 { + t.Fatalf("expected 1 schema field, got %d", exec.Schema().NumFields()) + } + + expectedName := "min_Column(age)" + if exec.Schema().Field(0).Name != expectedName { + t.Fatalf("expected name %s, got %s", + expectedName, exec.Schema().Field(0).Name) + } + }) + + // ----------------------------------------------------------------- + t.Run("multiple_aggregations_schema_names", func(t *testing.T) { + child := aggProject() + + agg := []AggregateFunctions{ + {AggrFunc: Min, Child: col("id")}, + {AggrFunc: Max, Child: col("salary")}, + {AggrFunc: Avg, Child: col("age")}, + } + + exec, err := NewGlobalAggrExec(child, agg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + schema := exec.Schema() + + expected := []string{ + "min_Column(id)", + "max_Column(salary)", + "avg_Column(age)", + } + + for i, f := range schema.Fields() { + if f.Name != expected[i] { + t.Fatalf("expected field %s, got %s", expected[i], f.Name) + } + } + }) + + // ----------------------------------------------------------------- + t.Run("invalid_type_detection_string_column", func(t *testing.T) { + child := aggProject() + + agg := []AggregateFunctions{ + {AggrFunc: Min, Child: col("name")}, // "name" is string → invalid + } + + _, err := NewGlobalAggrExec(child, agg) + if err == nil { + t.Fatalf("expected type error, got nil") + } + t.Logf("================\n invalid column err %v \n ============", err) + }) + + // ----------------------------------------------------------------- + t.Run("unsupported_aggregate_function", func(t *testing.T) { + child := aggProject() + + agg := []AggregateFunctions{ + {AggrFunc: 9999, Child: col("age")}, + } + + _, err := NewGlobalAggrExec(child, agg) + if err == nil { + t.Fatalf("expected unsupported aggr error") + } + }) + + // ----------------------------------------------------------------- + t.Run("schema_type_float64_for_all_numeric_aggs", func(t *testing.T) { + child := aggProject() + + agg := []AggregateFunctions{ + {AggrFunc: Min, Child: col("id")}, + {AggrFunc: Max, Child: col("salary")}, + {AggrFunc: Sum, Child: col("age")}, + {AggrFunc: Avg, Child: col("salary")}, + {AggrFunc: Count, Child: col("age")}, + } + + exec, err := NewGlobalAggrExec(child, agg) + if err != nil { + t.Fatalf("unexpected: %v", err) + } + + for _, f := range exec.Schema().Fields() { + if f.Type.ID() != arrow.FLOAT64 { + t.Fatalf("expected float64 output type, got %s", f.Type) + } + } + if err := exec.Close(); err != nil { + t.Fatalf("unexpected close error: %v", err) + } + }) + + // ----------------------------------------------------------------- + t.Run("check_all_valid_numeric_types_pass", func(t *testing.T) { + + // all numeric arrow types accepted by validAggrType() + validTypes := []arrow.DataType{ + arrow.PrimitiveTypes.Uint8, + arrow.PrimitiveTypes.Uint16, + arrow.PrimitiveTypes.Uint32, + arrow.PrimitiveTypes.Uint64, + arrow.PrimitiveTypes.Int8, + arrow.PrimitiveTypes.Int16, + arrow.PrimitiveTypes.Int32, + arrow.PrimitiveTypes.Int64, + arrow.PrimitiveTypes.Float32, + arrow.PrimitiveTypes.Float64, + } + + fieldNames := make([]string, len(validTypes)) + colData := make([]any, len(validTypes)) + + for i, dt := range validTypes { + name := fmt.Sprintf("col_%d", i) + fieldNames[i] = name + + switch dt.ID() { + case arrow.UINT8: + colData[i] = []uint8{1} + case arrow.UINT16: + colData[i] = []uint16{1} + case arrow.UINT32: + colData[i] = []uint32{1} + case arrow.UINT64: + colData[i] = []uint64{1} + case arrow.INT8: + colData[i] = []int8{1} + case arrow.INT16: + colData[i] = []int16{1} + case arrow.INT32: + colData[i] = []int32{1} + case arrow.INT64: + colData[i] = []int64{1} + case arrow.FLOAT16: + // float16 stored as float32 in Go + colData[i] = []float32{1} + case arrow.FLOAT32: + colData[i] = []float32{1} + case arrow.FLOAT64: + colData[i] = []float64{1} + } + } + + src, _ := project.NewInMemoryProjectExec(fieldNames, colData) + + for i := range fieldNames { + agg := []AggregateFunctions{ + {AggrFunc: Sum, Child: col(fieldNames[i])}, + } + + _, err := NewGlobalAggrExec(src, agg) + if err != nil { + t.Fatalf("unexpected error for type %s: %v", validTypes[i], err) + } + } + }) +} + +func TestCastArrayToFloat64(t *testing.T) { + + alloc := memory.NewGoAllocator + + // -------------------------------------------------------- + t.Run("cast_int32_to_float64", func(t *testing.T) { + b := array.NewInt32Builder(alloc()) + b.AppendValues([]int32{1, 2, 3, 4}, nil) + arr := b.NewArray() + + out, err := castArrayToFloat64(arr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + farr, ok := out.(*array.Float64) + if !ok { + t.Fatalf("expected Float64 array, got %T", out) + } + + expected := []float64{1, 2, 3, 4} + for i := range expected { + if farr.Value(i) != expected[i] { + t.Fatalf("expected %v at %d, got %v", expected[i], i, farr.Value(i)) + } + } + }) + + // -------------------------------------------------------- + t.Run("cast_float32_to_float64", func(t *testing.T) { + b := array.NewFloat32Builder(alloc()) + b.AppendValues([]float32{10.5, 20.5, 30.5}, nil) + arr := b.NewArray() + + out, err := castArrayToFloat64(arr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + farr, ok := out.(*array.Float64) + if !ok { + t.Fatalf("expected Float64 array, got %T", out) + } + + expected := []float64{10.5, 20.5, 30.5} + for i := range expected { + if farr.Value(i) != expected[i] { + t.Fatalf("expected %v at %d, got %v", expected[i], i, farr.Value(i)) + } + } + }) + + // -------------------------------------------------------- + t.Run("invalid_string_cast", func(t *testing.T) { + b := array.NewStringBuilder(alloc()) + b.AppendValues([]string{"a", "b", "c"}, nil) + arr := b.NewArray() + + _, err := castArrayToFloat64(arr) + if err == nil { + t.Fatalf("expected error when casting string array to float64") + } + }) + + // -------------------------------------------------------- + t.Run("empty_array_cast", func(t *testing.T) { + b := array.NewInt32Builder(alloc()) + // no values appended + arr := b.NewArray() + + out, err := castArrayToFloat64(arr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + _, ok := out.(*array.Float64) + if !ok { + t.Fatalf("expected Float64 array for empty cast, got %T", out) + } + + if out.Len() != 0 { + t.Fatalf("expected empty array, got length %d", out.Len()) + } + }) + +} + +func TestAggregateExecNext(t *testing.T) { + t.Run("validating done case early", func(t *testing.T) { + proj := aggProject() + agg := []AggregateFunctions{ + {AggrFunc: Min, Child: col("id")}} + aggrExec, err := NewGlobalAggrExec(proj, agg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + aggrExec.done = true + _, err = aggrExec.Next(10) + if err == nil || !errors.Is(err, io.EOF) { + t.Fatalf("expected io.EOF error, got nil") + } + }) + t.Run("Aggr minimum value on age", func(t *testing.T) { + proj := aggProject() + agg := []AggregateFunctions{ + {AggrFunc: Min, Child: col("age")}} + aggrExec, err := NewGlobalAggrExec(proj, agg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + resultBatch, _ := aggrExec.Next(100) + t.Logf("record batch: %v\n", resultBatch) + if resultBatch.Columns[0].(*array.Float64).Value(0) != 22 { + t.Fatalf("expected minimum age 22, got %v", resultBatch.Columns[0].(*array.Float64).Value(0)) + } + + }) + t.Run("Aggr maximum salary", func(t *testing.T) { + proj := aggProject() + agg := []AggregateFunctions{ + {AggrFunc: Max, Child: col("salary")}, + } + + aggrExec, err := NewGlobalAggrExec(proj, agg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + resultBatch, _ := aggrExec.Next(100) + + maxSalary := resultBatch.Columns[0].(*array.Float64).Value(0) + if maxSalary != 99000.0 && maxSalary != 94000.0 && maxSalary != 93000.0 { + // Real max is 99000 (Jake has 99000) + t.Fatalf("expected max salary 99000, got %v", maxSalary) + } + }) + t.Run("Aggr sum of id column", func(t *testing.T) { + proj := aggProject() + agg := []AggregateFunctions{ + {AggrFunc: Sum, Child: col("id")}, + } + + aggrExec, err := NewGlobalAggrExec(proj, agg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + resultBatch, _ := aggrExec.Next(200) + + sumIDs := resultBatch.Columns[0].(*array.Float64).Value(0) + expected := float64((25 * 26) / 2) // sum(1..25) = 325 + if sumIDs != expected { + t.Fatalf("expected sum 325, got %v", sumIDs) + } + }) + t.Run("Aggr count of age column", func(t *testing.T) { + proj := aggProject() + agg := []AggregateFunctions{ + NewAggregateFunctions(Count, col("age")), + } + + aggrExec, err := NewGlobalAggrExec(proj, agg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + resultBatch, _ := aggrExec.Next(300) + + count := resultBatch.Columns[0].(*array.Float64).Value(0) + if count != 25 { + t.Fatalf("expected count 25, got %v", count) + } + }) + t.Run("Aggr average of salary ", func(t *testing.T) { + proj := aggProject() + + agg := []AggregateFunctions{ + {AggrFunc: Avg, Child: col("salary")}, + } + + aggrExec, err := NewGlobalAggrExec(proj, agg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + resultBatch, _ := aggrExec.Next(500) + + avg := resultBatch.Columns[0].(*array.Float64).Value(0) + expected := 75740.02 + + if math.Abs(avg-expected) > 0.001 { + t.Fatalf("expected avg %v, got %v", expected, avg) + } + + }) + t.Run("Multiple aggregators in a single request", func(t *testing.T) { + proj := aggProject() + + agg := []AggregateFunctions{ + {AggrFunc: Min, Child: col("age")}, + {AggrFunc: Max, Child: col("salary")}, + {AggrFunc: Count, Child: col("id")}, + } + + aggrExec, err := NewGlobalAggrExec(proj, agg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + resultBatch, _ := aggrExec.Next(1000) + + minAge := resultBatch.Columns[0].(*array.Float64).Value(0) + maxSalary := resultBatch.Columns[1].(*array.Float64).Value(0) + countIDs := resultBatch.Columns[2].(*array.Float64).Value(0) + + if minAge != 22 { + t.Fatalf("expected min age 22, got %v", minAge) + } + if maxSalary != 99000.0 { + t.Fatalf("expected max salary 99000, got %v", maxSalary) + } + if countIDs != 25 { + t.Fatalf("expected count 25, got %v", countIDs) + } + }) + + // ========================================================== + t.Run("Schema correctness for multiple aggregates", func(t *testing.T) { + proj := aggProject() + + agg := []AggregateFunctions{ + {AggrFunc: Min, Child: col("id")}, + {AggrFunc: Sum, Child: col("age")}, + {AggrFunc: Count, Child: col("salary")}, + } + + aggrExec, err := NewGlobalAggrExec(proj, agg) + if err != nil { + t.Fatalf("unexpected: %v", err) + } + + s := aggrExec.Schema() + + expectedNames := []string{ + "min_Column(id)", + "sum_Column(age)", + "count_Column(salary)", + } + + for i, f := range s.Fields() { + if f.Name != expectedNames[i] { + t.Fatalf("expected field %s, got %s", expectedNames[i], f.Name) + } + if f.Type.ID() != arrow.FLOAT64 { + t.Fatalf("expected float64 fields only") + } + } + }) +} + +func TestAggregateExecNull(t *testing.T) { + + t.Run("Aggr count of age column", func(t *testing.T) { + proj := aggProjectNull() + agg := []AggregateFunctions{ + NewAggregateFunctions(Count, col("age")), + NewAggregateFunctions(Sum, col("id")), + } + + aggrExec, err := NewGlobalAggrExec(proj, agg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + resultBatch, _ := aggrExec.Next(100) + t.Logf("rb:%v\n", resultBatch) + count := resultBatch.Columns[0].(*array.Float64).Value(0) + if count != 8 { + t.Fatalf("expected count 7, got %v", count) + } + sumIDs := resultBatch.Columns[1].(*array.Float64).Value(0) + expectedSum := float64(1 + 2 + 4 + 5 + 7 + 8 + 9) // only non-null ids + if sumIDs != expectedSum { + t.Fatalf("expected sum %v, got %v", expectedSum, sumIDs) + } + }) +} diff --git a/src/Backend/opti-sql-go/operators/aggr/sort.go b/src/Backend/opti-sql-go/operators/aggr/sort.go index abd1ad5..60d0cb5 100644 --- a/src/Backend/opti-sql-go/operators/aggr/sort.go +++ b/src/Backend/opti-sql-go/operators/aggr/sort.go @@ -1 +1,376 @@ package aggr + +import ( + "context" + "errors" + "fmt" + "io" + "math" + "opti-sql-go/Expr" + "opti-sql-go/operators" + "sort" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/compute" + "github.com/apache/arrow/go/v17/arrow/memory" +) + +// order by col asc, col 2 desc .... etc +var ( + _ = (operators.Operator)(&SortExec{}) + _ = (operators.Operator)(&TopKSortExec{}) +) + +type SortKey struct { + Expr Expr.Expression + Ascending bool // by default false -- DESC (highest values first -> smaller values) + NullFirst bool // by default false -- nulls last +} + +func NewSortKey(expr Expr.Expression, options ...bool) *SortKey { + var asc, nullF bool + switch len(options) { + case 2: + asc = options[0] + nullF = options[1] + case 1: + asc = options[0] + } + return &SortKey{ + Expr: expr, + Ascending: asc, + NullFirst: nullF, + } +} +func CombineSortKeys(sk ...*SortKey) []SortKey { + var res []SortKey + for _, s := range sk { + res = append(res, *s) + } + return res +} + +type SortExec struct { + child operators.Operator + schema *arrow.Schema + sortKeys []SortKey // resolves to columns + // internal book keeping + totalColumns []arrow.Array + consumedOffset uint64 + totalRows uint64 + consumed bool // did we finish reading all of the child record batches? + done bool // have we already produced all the sorted record batches? +} + +func NewSortExec(child operators.Operator, sortKeys []SortKey) (*SortExec, error) { + fmt.Printf("sorts Keys %v\n", sortKeys) + return &SortExec{ + child: child, + schema: child.Schema(), + sortKeys: sortKeys, + }, nil +} + +// for now read everything into memory and sort -- next steps will be to do external merge + +// n is the number of records we will return,sortExec will read in 2^16-1 column entries from its child, this is more efficient that trusting the caller to pass in a reasonable +// n so that we avoid small/frequent IO operations +func (s *SortExec) Next(n uint16) (*operators.RecordBatch, error) { + if s.done { + return nil, io.EOF + } + if !s.consumed { + allColumns := make([]arrow.Array, len(s.schema.Fields())) // concated columns + mem := memory.NewGoAllocator() + var count uint64 + for { + childBatch, err := s.child.Next(math.MaxUint16) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return nil, err + } + for i := range childBatch.Columns { + if allColumns[i] == nil { + allColumns[i] = childBatch.Columns[i] + continue + } + largerArray, err := array.Concatenate([]arrow.Array{allColumns[i], childBatch.Columns[i]}, mem) + if err != nil { + return nil, err + } + allColumns[i] = largerArray + } + } + s.consumed = true + if len(allColumns) > 0 { + count = uint64(allColumns[0].Len()) + } + idx, err := sortBatches(&operators.RecordBatch{ + Schema: s.schema, + Columns: allColumns, + RowCount: count, + }, s.sortKeys) + if err != nil { + return nil, err + } + // now update all mappings + for i := range len(allColumns) { + arr, err := compute.TakeArray(context.TODO(), allColumns[i], idxToArrowArray(idx, mem)) + if err != nil { + return nil, err + } + allColumns[i] = arr + } + s.totalColumns = allColumns + s.totalRows = count + } + var readSize uint64 + remaining := s.totalRows - s.consumedOffset + if remaining < uint64(n) { + // if n is more than we have left just read up to remaining + readSize = uint64(remaining) + s.done = true + } else { + // remaining > n or remaining = n then just read n and return + readSize = uint64(n) + } + mem := memory.NewGoAllocator() + sortedColumns, err := s.consumeSortedBatch(readSize, mem) + if err != nil { + return nil, err + } + + return &operators.RecordBatch{ + Schema: s.schema, + Columns: sortedColumns, + RowCount: readSize, + }, nil +} +func (s *SortExec) Schema() *arrow.Schema { + return s.schema +} +func (s *SortExec) Close() error { + return s.child.Close() +} +func (s *SortExec) consumeSortedBatch(readsize uint64, mem memory.Allocator) ([]arrow.Array, error) { + ctx := context.TODO() + resultColumns := make([]arrow.Array, len(s.schema.Fields())) + offsetArray := genoffsetTakeIdx(s.consumedOffset, readsize, mem) + for i := range s.totalColumns { + sortArr := s.totalColumns[i] + arr, err := compute.TakeArray(ctx, sortArr, offsetArray) + if err != nil { + return nil, err + } + resultColumns[i] = arr + + } + s.consumedOffset += readsize + return resultColumns, nil +} + +/* +only sort and keep the top k elements in memory +*/ +type TopKSortExec struct { + child operators.Operator + schema *arrow.Schema + done bool + sortKeys []SortKey // resolves to columns + k uint16 // top k +} + +func NewTopKSortExec(child operators.Operator, sortKeys []SortKey, k uint16) (*TopKSortExec, error) { + fmt.Printf("sort keys %v\n", sortKeys) + return &TopKSortExec{ + child: child, + schema: child.Schema(), + sortKeys: sortKeys, + k: k, + }, nil +} + +// for now read everything into memory and sort -- next steps will be to do external merge +func (t *TopKSortExec) Next(n uint16) (*operators.RecordBatch, error) { + if t.done { + return nil, io.EOF + } + return nil, nil +} +func (t *TopKSortExec) Schema() *arrow.Schema { + return t.schema +} +func (t *TopKSortExec) Close() error { + return t.child.Close() +} + +/* +shared functions +*/ +func sortBatches(fullRC *operators.RecordBatch, sortKeys []SortKey) ([]uint64, error) { + keyColumns := make([]arrow.Array, len(sortKeys)) + for i, sk := range sortKeys { + arr, err := Expr.EvalExpression(sk.Expr, fullRC) + if err != nil { + return nil, fmt.Errorf("sort batches: failed to eval sort expression: %v", err) + } + keyColumns[i] = arr + } + idVector := make([]uint64, fullRC.RowCount) + for i := 0; uint64(i) < fullRC.RowCount; i++ { + idVector[i] = uint64(i) + } + sortIndexVector(idVector, keyColumns, sortKeys) + return idVector, nil +} + +// sortIndexVector sorts idVec based on keyColumns + sortKeys. +// keyColumns[i] corresponds to sortKeys[i]. +func sortIndexVector(idVec []uint64, keyColumns []arrow.Array, sortKeys []SortKey) { + sort.Slice(idVec, func(a, b int) bool { + i := idVec[a] + j := idVec[b] + + // lexicographic: go through each sort key + for k, col := range keyColumns { + sk := sortKeys[k] + cmp := compareArrowValues(col, i, j) + + if cmp == 0 { + continue // equal → move to next key + } + + if sk.Ascending { + return cmp < 0 + } else { + return cmp > 0 + } + } + + // completely equal for all keys + return false + }) +} + +func compareArrowValues(col arrow.Array, i, j uint64) int { + // Handle nulls (treat as lowest value for now) + if col.IsNull(int(i)) && col.IsNull(int(j)) { + return 0 + } + if col.IsNull(int(i)) { + return -1 + } + if col.IsNull(int(j)) { + return 1 + } + + switch arr := col.(type) { + + case *array.String: + vi := arr.Value(int(i)) + vj := arr.Value(int(j)) + switch { + case vi < vj: + return -1 + case vi > vj: + return 1 + default: + return 0 + } + + case *array.Int8: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + return compareNumeric(vi, vj) + + case *array.Int16: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + return compareNumeric(vi, vj) + + case *array.Int32: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + return compareNumeric(vi, vj) + + case *array.Int64: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + return compareNumeric(vi, vj) + + case *array.Uint8: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + return compareNumeric(vi, vj) + + case *array.Uint16: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + return compareNumeric(vi, vj) + + case *array.Uint32: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + return compareNumeric(vi, vj) + + case *array.Uint64: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + return compareNumeric(vi, vj) + + case *array.Float32: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + return compareFloat(vi, vj) + + case *array.Float64: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + return compareFloat(vi, vj) + + case *array.Boolean: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + if vi == vj { + return 0 + } + if !vi && vj { + return -1 + } + return 1 + + default: + panic("unsupported Arrow type in compareArrowValues") + } +} + +func compareNumeric[T int64 | int32 | int16 | int8 | uint64 | uint32 | uint16 | uint8](a, b T) int { + switch { + case a < b: + return -1 + case a > b: + return 1 + default: + return 0 + } +} + +func compareFloat[T float32 | float64](a, b T) int { + switch { + case a < b: + return -1 + case a > b: + return 1 + default: + return 0 + } +} +func idxToArrowArray(v []uint64, mem memory.Allocator) arrow.Array { + // turn to array first + b := array.NewUint64Builder(mem) + for _, val := range v { + b.Append(val) + } + arr := b.NewArray() + return arr +} +func genoffsetTakeIdx(offset, size uint64, mem memory.Allocator) arrow.Array { + b := array.NewUint64Builder(mem) + for i := range size { + b.Append(offset + i) + } + arr := b.NewArray() + return arr +} diff --git a/src/Backend/opti-sql-go/operators/aggr/sort_test.go b/src/Backend/opti-sql-go/operators/aggr/sort_test.go index b919b31..95754c8 100644 --- a/src/Backend/opti-sql-go/operators/aggr/sort_test.go +++ b/src/Backend/opti-sql-go/operators/aggr/sort_test.go @@ -1,7 +1,613 @@ package aggr -import "testing" +import ( + "context" + "errors" + "fmt" + "io" + "opti-sql-go/Expr" + "opti-sql-go/operators" + "opti-sql-go/operators/project" + "testing" -func TestSort(t *testing.T) { + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/compute" + "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/go-jose/go-jose/v4/testutils/require" +) + +func buildAggTestRecordBatch(t *testing.T) *operators.RecordBatch { + names, cols := generateAggTestColumns() + mem := memory.NewGoAllocator() + + arrowCols := make([]arrow.Array, len(cols)) + fields := make([]arrow.Field, len(cols)) + + for i, col := range cols { + switch v := col.(type) { + + case []int32: + b := array.NewInt32Builder(mem) + defer b.Release() + b.AppendValues(v, nil) + arrowCols[i] = b.NewArray() + + case []string: + b := array.NewStringBuilder(mem) + defer b.Release() + b.AppendValues(v, nil) + arrowCols[i] = b.NewArray() + + case []float64: + b := array.NewFloat64Builder(mem) + defer b.Release() + b.AppendValues(v, nil) + arrowCols[i] = b.NewArray() + + default: + t.Fatalf("unsupported type in generateAggTestColumns") + } + + fields[i] = arrow.Field{Name: names[i], Type: arrowCols[i].DataType()} + } + + return &operators.RecordBatch{ + Schema: arrow.NewSchema(fields, nil), + Columns: arrowCols, + RowCount: uint64(len(cols[0].([]int32))), + } +} + +func TestSortInit(t *testing.T) { // Simple passing test + t.Run("sort Exec init", func(t *testing.T) { + proj := aggProject() + sortExec, err := NewSortExec(proj, nil) + if err != nil { + t.Fatal(err) + } + if !sortExec.Schema().Equal(proj.Schema()) { + t.Fatalf("expected schema %v, got %v", proj.Schema(), sortExec.schema) + } + sortExec.done = true + _, err = sortExec.Next(100) + if err != io.EOF { + t.Fatalf("expected io.EOF error on done sortExec but got %v", err) + } + if sortExec.Close() != nil { + t.Fatalf("expected nil error on close but got %v", sortExec.Close()) + } + + }) + t.Run("SortKey options", func(t *testing.T) { + proj := aggProject() + _, err := NewSortExec(proj, []SortKey{*NewSortKey(col("-"), false, false)}) + if err != nil { + t.Fatal(err) + } + + }) + t.Run("tok k sort exec init", func(t *testing.T) { + proj := aggProject() + topKVal := 5 + topK, err := NewTopKSortExec(proj, nil, uint16(topKVal)) + if err != nil { + t.Fatal(err) + } + if !topK.Schema().Equal(proj.Schema()) { + t.Fatalf("expected schema %v, got %v", proj.Schema(), topK.schema) + } + if topK.k != 5 { + t.Fatalf("expected %v for top k but got %v", topKVal, topK.k) + } + topK.done = true + _, err = topK.Next(100) + if err != io.EOF { + t.Fatalf("expected io.EOF error on done topK but got %v", err) + } + if topK.Close() != nil { + t.Fatalf("expected nil error on close but got %v", topK.Close()) + } + + }) +} + +func TestBasicSortExpr(t *testing.T) { + t.Run("Sort", func(t *testing.T) { + proj := aggProject() + nameExpr := Expr.NewColumnResolve("name") + nameSK := NewSortKey(nameExpr, true) + ageExpr := Expr.NewColumnResolve("age") + ageSK := NewSortKey(ageExpr, false) + _, err := NewSortExec(proj, CombineSortKeys(nameSK, ageSK)) + if err != nil { + t.Fatalf("unexpected error from NewSortExec : %v\n", err) + } + //t.Logf("%v\n", sortExec) + }) + t.Run("Basic Next operation", func(t *testing.T) { + proj := aggProject() + nameExpr := Expr.NewColumnResolve("name") + nameSK := NewSortKey(nameExpr, true) + ageExpr := Expr.NewColumnResolve("age") + ageSK := NewSortKey(ageExpr, false) + sortExec, err := NewSortExec(proj, CombineSortKeys(ageSK, nameSK)) + if err != nil { + t.Fatalf("unexpected error from NewSortExec : %v\n", err) + } + for { + sortedBatch, err := sortExec.Next(5) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + t.Fatalf("unexpected error from sortExec Next : %v\n", err) + } + fmt.Println(sortedBatch.PrettyPrint()) + } + }) +} +func TestFullSortOverNetwork(t *testing.T) { + t.Run("Full Sort of large file", func(t *testing.T) { + const fileName = "country_full.csv" + nr, err := project.NewStreamReader(fileName) + if err != nil { + t.Fatalf("failed to create s3 object: %v", err) + } + pj, err := project.NewProjectCSVLeaf(nr.Stream()) + if err != nil { + t.Fatalf("failed to create csv project source from s3 object: %v", err) + } + nameExpr := Expr.NewColumnResolve("name") + nameSK := NewSortKey(nameExpr, true) + sortExec, err := NewSortExec(pj, CombineSortKeys(nameSK)) + if err != nil { + t.Fatalf("unexpected error %v\n", err) + } + rc, err := sortExec.Next(10) + if err != nil { + t.Fatalf("unexpected error %v\n", err) + } + fmt.Println(rc.PrettyPrint()) + + }) + +} + +func TestFullSortExec_Next(t *testing.T) { + t.Parallel() + + t.Run("sort_age_DESC", func(t *testing.T) { + proj := aggProject() + + ageExpr := Expr.NewColumnResolve("age") + ageSK := NewSortKey(ageExpr, false) // DESC + + sortExec, err := NewSortExec(proj, CombineSortKeys(ageSK)) + require.NoError(t, err) + + batch, err := sortExec.Next(5) + require.NoError(t, err) + require.Equal(t, uint64(5), batch.RowCount) + + ages := batch.Columns[2].(*array.Int32) + got := []int32{ + ages.Value(0), + ages.Value(1), + ages.Value(2), + ages.Value(3), + ages.Value(4), + } + + expected := []int32{50, 48, 46, 45, 43} + for i, v := range expected { + if got[i] != v { + t.Fatalf("expected %v at index %d, but got %v", v, i, got[i]) + } + } + }) + + t.Run("sort_name_ASC", func(t *testing.T) { + proj := aggProject() + + nameExpr := Expr.NewColumnResolve("name") + nameSK := NewSortKey(nameExpr, true) + + sortExec, err := NewSortExec(proj, CombineSortKeys(nameSK)) + require.NoError(t, err) + + batch, err := sortExec.Next(3) + require.NoError(t, err) + + names := batch.Columns[1].(*array.String) + got := []string{ + names.Value(0), + names.Value(1), + names.Value(2), + } + + expected := []string{"Alice", "Bob", "Charlie"} + for i, v := range expected { + if got[i] != v { + t.Fatalf("expected %v at index %d, but got %v", v, i, got[i]) + } + } + }) +} + +// ----------------------------------------------------------------------------- +// TEST 2: sortIndexVector() +// ----------------------------------------------------------------------------- + +func TestSortIndexVector(t *testing.T) { + t.Parallel() + + mem := memory.NewGoAllocator() + + t.Run("single_key_int", func(t *testing.T) { + b := array.NewInt32Builder(mem) + b.AppendValues([]int32{30, 10, 20}, nil) + arr := b.NewArray() + defer arr.Release() + + keys := []arrow.Array{arr} + idVec := []uint64{0, 1, 2} + + sks := []SortKey{ + {Expr: nil, Ascending: true}, + } + + sortIndexVector(idVec, keys, sks) + + expected := []uint64{1, 2, 0} + for i, v := range expected { + if idVec[i] != v { + t.Fatalf("expected %v at index %d, but got %v", v, i, idVec[i]) + } + } + }) + + t.Run("single_key_string", func(t *testing.T) { + b := array.NewStringBuilder(mem) + b.AppendValues([]string{"Charlie", "Alice", "Bob"}, nil) + arr := b.NewArray() + defer arr.Release() + + keys := []arrow.Array{arr} + idVec := []uint64{0, 1, 2} + + sks := []SortKey{{Ascending: true}} + + sortIndexVector(idVec, keys, sks) + + expected := []uint64{1, 2, 0} + for i, v := range expected { + if idVec[i] != v { + t.Fatalf("expected %v at index %d, but got %v", v, i, idVec[i]) + } + } + }) +} + +// ----------------------------------------------------------------------------- +// TEST 3: compareArrowValues() +// ----------------------------------------------------------------------------- + +func TestCompareArrowValues(t *testing.T) { + t.Parallel() + + mem := memory.NewGoAllocator() + + t.Run("int", func(t *testing.T) { + b := array.NewInt32Builder(mem) + b.AppendValues([]int32{10, 20}, nil) + arr := b.NewArray() + defer arr.Release() + + require.Equal(t, -1, compareArrowValues(arr, 0, 1)) + require.Equal(t, 1, compareArrowValues(arr, 1, 0)) + require.Equal(t, 0, compareArrowValues(arr, 0, 0)) + }) + + t.Run("uint", func(t *testing.T) { + b := array.NewUint32Builder(mem) + b.AppendValues([]uint32{5, 7}, nil) + arr := b.NewArray() + defer arr.Release() + + require.Equal(t, -1, compareArrowValues(arr, 0, 1)) + require.Equal(t, 1, compareArrowValues(arr, 1, 0)) + }) + + t.Run("float", func(t *testing.T) { + b := array.NewFloat64Builder(mem) + b.AppendValues([]float64{1.5, 1.7}, nil) + arr := b.NewArray() + defer arr.Release() + + require.Equal(t, -1, compareArrowValues(arr, 0, 1)) + require.Equal(t, 1, compareArrowValues(arr, 1, 0)) + }) + + t.Run("string", func(t *testing.T) { + b := array.NewStringBuilder(mem) + b.AppendValues([]string{"a", "b"}, nil) + arr := b.NewArray() + defer arr.Release() + + require.Equal(t, -1, compareArrowValues(arr, 0, 1)) + require.Equal(t, 1, compareArrowValues(arr, 1, 0)) + }) + + t.Run("bool", func(t *testing.T) { + b := array.NewBooleanBuilder(mem) + b.AppendValues([]bool{false, true}, nil) + arr := b.NewArray() + defer arr.Release() + + require.Equal(t, -1, compareArrowValues(arr, 0, 1)) + require.Equal(t, 1, compareArrowValues(arr, 1, 0)) + }) +} +func TestCompareArrowValues_AllTypes(t *testing.T) { + mem := memory.NewGoAllocator() + + // helper to assert cmp result + assert := func(name string, got, want int) { + if got != want { + t.Fatalf("%s: expected %d, got %d", name, want, got) + } + } + + // ---- STRING ---- + strB := array.NewStringBuilder(mem) + strB.Append("apple") + strB.Append("banana") + strArr := strB.NewArray().(*array.String) + + assert("string lt", compareArrowValues(strArr, 0, 1), -1) + assert("string gt", compareArrowValues(strArr, 1, 0), 1) + assert("string eq", compareArrowValues(strArr, 0, 0), 0) + + strArr.Release() + strB.Release() + + // ---- INT TYPES ---- + int8Arr := buildInt8(mem, []int8{1, 3}) + assert("int8 lt", compareArrowValues(int8Arr, 0, 1), -1) + assert("int8 gt", compareArrowValues(int8Arr, 1, 0), 1) + assert("int8 eq", compareArrowValues(int8Arr, 0, 0), 0) + int8Arr.Release() + + int16Arr := buildInt16(mem, []int16{5, 2}) + assert("int16 gt", compareArrowValues(int16Arr, 0, 1), 1) + int16Arr.Release() + + int32Arr := buildInt32(mem, []int32{10, 10}) + assert("int32 eq", compareArrowValues(int32Arr, 0, 1), 0) + int32Arr.Release() + + int64Arr := buildInt64(mem, []int64{-5, 7}) + assert("int64 lt", compareArrowValues(int64Arr, 0, 1), -1) + int64Arr.Release() + + // ---- UINT TYPES ---- + u8Arr := buildUint8(mem, []uint8{9, 3}) + assert("uint8 gt", compareArrowValues(u8Arr, 0, 1), 1) + u8Arr.Release() + + u16Arr := buildUint16(mem, []uint16{3, 3}) + assert("uint16 eq", compareArrowValues(u16Arr, 0, 1), 0) + u16Arr.Release() + + u32Arr := buildUint32(mem, []uint32{3, 10}) + assert("uint32 lt", compareArrowValues(u32Arr, 0, 1), -1) + u32Arr.Release() + + u64Arr := buildUint64(mem, []uint64{100, 2}) + assert("uint64 gt", compareArrowValues(u64Arr, 0, 1), 1) + u64Arr.Release() + + // ---- FLOAT TYPES ---- + f32Arr := buildFloat32(mem, []float32{1.5, 1.5}) + assert("float32 eq", compareArrowValues(f32Arr, 0, 1), 0) + f32Arr.Release() + + f64Arr := buildFloat64(mem, []float64{-1.0, 2.3}) + assert("float64 lt", compareArrowValues(f64Arr, 0, 1), -1) + f64Arr.Release() + + // ---- BOOLEAN ---- + boolArr := buildBool(mem, []bool{false, true}) + assert("bool lt", compareArrowValues(boolArr, 0, 1), -1) + assert("bool gt", compareArrowValues(boolArr, 1, 0), 1) + assert("bool eq", compareArrowValues(boolArr, 1, 1), 0) + boolArr.Release() + + // ---- NULL CASES ---- + nullB := array.NewInt32Builder(mem) + nullB.AppendNull() + nullB.Append(10) + nullArr := nullB.NewArray().(*array.Int32) + + assert("null < value", compareArrowValues(nullArr, 0, 1), -1) + assert("value > null", compareArrowValues(nullArr, 1, 0), 1) + assert("null == null", compareArrowValues(nullArr, 0, 0), 0) + + nullArr.Release() + nullB.Release() + + // ---- UNSUPPORTED TYPE PANIC ---- + // Build a fixed-size binary array to trigger panic + fsb := array.NewFixedSizeBinaryBuilder(mem, &arrow.FixedSizeBinaryType{ByteWidth: 2}) + fsb.Append([]byte{1, 2}) + fsb.Append([]byte{3, 4}) + fsArr := fsb.NewArray() + + didPanic := false + func() { + defer func() { + if recover() != nil { + didPanic = true + } + }() + _ = compareArrowValues(fsArr, 0, 1) + }() + if !didPanic { + t.Fatalf("expected panic for unsupported Arrow type") + } + + fsArr.Release() + fsb.Release() +} +func buildInt8(mem memory.Allocator, vals []int8) *array.Int8 { + b := array.NewInt8Builder(mem) + for _, v := range vals { + b.Append(v) + } + arr := b.NewArray().(*array.Int8) + b.Release() + return arr +} + +func buildInt16(mem memory.Allocator, vals []int16) *array.Int16 { + b := array.NewInt16Builder(mem) + for _, v := range vals { + b.Append(v) + } + arr := b.NewArray().(*array.Int16) + b.Release() + return arr +} + +func buildInt32(mem memory.Allocator, vals []int32) *array.Int32 { + b := array.NewInt32Builder(mem) + for _, v := range vals { + b.Append(v) + } + arr := b.NewArray().(*array.Int32) + b.Release() + return arr +} + +func buildInt64(mem memory.Allocator, vals []int64) *array.Int64 { + b := array.NewInt64Builder(mem) + for _, v := range vals { + b.Append(v) + } + arr := b.NewArray().(*array.Int64) + b.Release() + return arr +} + +func buildUint8(mem memory.Allocator, vals []uint8) *array.Uint8 { + b := array.NewUint8Builder(mem) + for _, v := range vals { + b.Append(v) + } + arr := b.NewArray().(*array.Uint8) + b.Release() + return arr +} + +func buildUint16(mem memory.Allocator, vals []uint16) *array.Uint16 { + b := array.NewUint16Builder(mem) + for _, v := range vals { + b.Append(v) + } + arr := b.NewArray().(*array.Uint16) + b.Release() + return arr +} + +func buildUint32(mem memory.Allocator, vals []uint32) *array.Uint32 { + b := array.NewUint32Builder(mem) + for _, v := range vals { + b.Append(v) + } + arr := b.NewArray().(*array.Uint32) + b.Release() + return arr +} + +func buildUint64(mem memory.Allocator, vals []uint64) *array.Uint64 { + b := array.NewUint64Builder(mem) + for _, v := range vals { + b.Append(v) + } + arr := b.NewArray().(*array.Uint64) + b.Release() + return arr +} + +func buildFloat32(mem memory.Allocator, vals []float32) *array.Float32 { + b := array.NewFloat32Builder(mem) + for _, v := range vals { + b.Append(v) + } + arr := b.NewArray().(*array.Float32) + b.Release() + return arr +} + +func buildFloat64(mem memory.Allocator, vals []float64) *array.Float64 { + b := array.NewFloat64Builder(mem) + for _, v := range vals { + b.Append(v) + } + arr := b.NewArray().(*array.Float64) + b.Release() + return arr +} + +func buildBool(mem memory.Allocator, vals []bool) *array.Boolean { + b := array.NewBooleanBuilder(mem) + for _, v := range vals { + b.Append(v) + } + arr := b.NewArray().(*array.Boolean) + b.Release() + return arr +} + +func TestBasicTopKSortExpr(t *testing.T) { + t.Run("TopK Sort", func(t *testing.T) { + proj := aggProject() + nameExpr := Expr.NewColumnResolve("name") + nameSK := NewSortKey(nameExpr, true) + ageExpr := Expr.NewColumnResolve("age") + ageSK := NewSortKey(ageExpr, false) + sortExec, err := NewTopKSortExec(proj, CombineSortKeys(nameSK, ageSK), 5) + if err != nil { + t.Fatalf("unexpected error from NewTopKSortExec : %v\n", err) + } + t.Logf("%v\n", sortExec) + + }) +} + +func TestOne(t *testing.T) { + v := compute.GetExecCtx(context.Background()) + names := v.Registry.GetFunctionNames() + for i, name := range names { + fmt.Printf("%d: %v\n", i, name) + } + /* + mem := memory.NewGoAllocator() + floatB := array.NewFloat64Builder(mem) + floatB.AppendValues([]float64{10.5, 20.3, 30.1, 40.7, 50.2}, []bool{true, true, true, true, true}) + pos := array.NewInt32Builder(mem) + pos.AppendValues([]int32{1, 3, 4}, []bool{true, true, true}) + + dat, err := compute.Take(context.TODO(), *compute.DefaultTakeOptions(), compute.NewDatum(floatB.NewArray()), compute.NewDatum(pos.NewArray())) + if err != nil { + t.Fatalf("Take failed: %v", err) + } + array, ok := dat.(*compute.ArrayDatum) + if !ok { + t.Logf("expected an array to be returned but got something else %T\n", dat) + } + t.Logf("data: %v\n", array.MakeArray()) + */ } diff --git a/src/Backend/opti-sql-go/operators/aggr/sum.go b/src/Backend/opti-sql-go/operators/aggr/sum.go deleted file mode 100644 index abd1ad5..0000000 --- a/src/Backend/opti-sql-go/operators/aggr/sum.go +++ /dev/null @@ -1 +0,0 @@ -package aggr diff --git a/src/Backend/opti-sql-go/operators/aggr/sum_test.go b/src/Backend/opti-sql-go/operators/aggr/sum_test.go deleted file mode 100644 index 485b9bb..0000000 --- a/src/Backend/opti-sql-go/operators/aggr/sum_test.go +++ /dev/null @@ -1,7 +0,0 @@ -package aggr - -import "testing" - -func TestSum(t *testing.T) { - // Simple passing test -} diff --git a/src/Backend/opti-sql-go/operators/filter/filter.go b/src/Backend/opti-sql-go/operators/filter/filter.go index ddd8c1b..6c30c8f 100644 --- a/src/Backend/opti-sql-go/operators/filter/filter.go +++ b/src/Backend/opti-sql-go/operators/filter/filter.go @@ -41,11 +41,11 @@ func (f *FilterExec) Next(n uint16) (*operators.RecordBatch, error) { if f.done { return nil, io.EOF } - batch, err := f.input.Next(n) + childBatch, err := f.input.Next(n) if err != nil { return nil, err } - booleanMask, err := Expr.EvalExpression(f.predicate, batch) + booleanMask, err := Expr.EvalExpression(f.predicate, childBatch) if err != nil { return nil, err } @@ -53,21 +53,20 @@ func (f *FilterExec) Next(n uint16) (*operators.RecordBatch, error) { if !ok { return nil, errors.New("predicate did not evaluate to boolean array") } - filteredCol := make([]arrow.Array, len(batch.Columns)) - for i, col := range batch.Columns { - filteredCol[i], err = applyBooleanMask(col, boolArr) + filteredCol := make([]arrow.Array, len(childBatch.Columns)) + for i, col := range childBatch.Columns { + filteredCol[i], err = ApplyBooleanMask(col, boolArr) if err != nil { return nil, err } } + booleanMask.Release() // release old columns - for _, c := range batch.Columns { - c.Release() - } + operators.ReleaseArrays(childBatch.Columns) size := uint64(filteredCol[0].Len()) return &operators.RecordBatch{ - Schema: batch.Schema, + Schema: childBatch.Schema, Columns: filteredCol, RowCount: size, }, nil @@ -80,7 +79,7 @@ func (f *FilterExec) Close() error { return f.input.Close() } -func applyBooleanMask(col arrow.Array, mask *array.Boolean) (arrow.Array, error) { +func ApplyBooleanMask(col arrow.Array, mask *array.Boolean) (arrow.Array, error) { datum, err := compute.Filter( context.TODO(), compute.NewDatum(col), diff --git a/src/Backend/opti-sql-go/operators/project/custom.go b/src/Backend/opti-sql-go/operators/project/custom.go index e36fa0c..0816600 100644 --- a/src/Backend/opti-sql-go/operators/project/custom.go +++ b/src/Backend/opti-sql-go/operators/project/custom.go @@ -73,6 +73,35 @@ func (ms *InMemorySource) withFields(names ...string) error { ms.columns = cols return nil } +func NewInMemoryProjectExecFromArrays(names []string, arrays []arrow.Array) (*InMemorySource, error) { + if len(names) != len(arrays) { + return nil, operators.ErrInvalidSchema("number of column names and arrays do not match") + } + + fields := make([]arrow.Field, len(names)) + fieldToColIdx := make(map[string]int, len(names)) + + for i, arr := range arrays { + if arr == nil { + return nil, operators.ErrInvalidSchema(fmt.Sprintf("nil array for column %s", names[i])) + } + + fields[i] = arrow.Field{ + Name: names[i], + Type: arr.DataType(), + Nullable: true, // Arrow arrays may have null bitmaps + } + + fieldToColIdx[names[i]] = i + } + + return &InMemorySource{ + schema: arrow.NewSchema(fields, nil), + columns: arrays, + fieldToColIDx: fieldToColIdx, + }, nil +} + func (ms *InMemorySource) Next(n uint16) (*operators.RecordBatch, error) { if len(ms.columns) == 0 || ms.pos >= uint16(ms.columns[0].Len()) { return nil, io.EOF // EOF diff --git a/src/Backend/opti-sql-go/operators/project/parquet.go b/src/Backend/opti-sql-go/operators/project/parquet.go index 94b6e1d..42d5c14 100644 --- a/src/Backend/opti-sql-go/operators/project/parquet.go +++ b/src/Backend/opti-sql-go/operators/project/parquet.go @@ -22,12 +22,10 @@ var ( ) type ParquetSource struct { - // existing fields schema *arrow.Schema projectionPushDown []string // columns to project up reader pqarrow.RecordReader - // for internal reading - done bool // if set to true always return io.EOF + done bool // if set to true always return io.EOF } func NewParquetSource(r parquet.ReaderAtSeeker) (*ParquetSource, error) { @@ -45,7 +43,7 @@ func NewParquetSource(r parquet.ReaderAtSeeker) (*ParquetSource, error) { arrowReader, err := pqarrow.NewFileReader( filerReader, - pqarrow.ArrowReadProperties{Parallel: true, BatchSize: int64(Config.Batch.Size)}, // TODO: Read in from config for this stuff + pqarrow.ArrowReadProperties{Parallel: true, BatchSize: int64(Config.Batch.Size)}, allocator, ) if err != nil { @@ -84,7 +82,7 @@ func NewParquetSourcePushDown(r parquet.ReaderAtSeeker, columns []string) (*Parq arrowReader, err := pqarrow.NewFileReader( filerReader, - pqarrow.ArrowReadProperties{Parallel: true, BatchSize: int64(Config.Batch.Size)}, // TODO: Read in from config for this stuff + pqarrow.ArrowReadProperties{Parallel: true, BatchSize: int64(Config.Batch.Size)}, allocator, ) if err != nil { diff --git a/src/Backend/opti-sql-go/operators/project/projectExec.go b/src/Backend/opti-sql-go/operators/project/projectExec.go index 9d93d96..033a58c 100644 --- a/src/Backend/opti-sql-go/operators/project/projectExec.go +++ b/src/Backend/opti-sql-go/operators/project/projectExec.go @@ -20,7 +20,7 @@ var ( ) type ProjectExec struct { - child operators.Operator + input operators.Operator outputschema arrow.Schema expr []Expr.Expression done bool @@ -60,7 +60,7 @@ func NewProjectExec(input operators.Operator, exprs []Expr.Expression) (*Project outputschema := arrow.NewSchema(fields, nil) // return new exec return &ProjectExec{ - child: input, + input: input, outputschema: *outputschema, expr: exprs, }, nil @@ -73,7 +73,7 @@ func (p *ProjectExec) Next(n uint16) (*operators.RecordBatch, error) { return nil, io.EOF } - childBatch, err := p.child.Next(n) + childBatch, err := p.input.Next(n) if err != nil { return nil, err } @@ -94,9 +94,7 @@ func (p *ProjectExec) Next(n uint16) (*operators.RecordBatch, error) { outPutCols[i] = arr arr.Retain() } - for _, c := range childBatch.Columns { - c.Release() - } + operators.ReleaseArrays(childBatch.Columns) return &operators.RecordBatch{ Schema: &p.outputschema, Columns: outPutCols, @@ -104,7 +102,7 @@ func (p *ProjectExec) Next(n uint16) (*operators.RecordBatch, error) { }, nil } func (p *ProjectExec) Close() error { - return p.child.Close() + return p.input.Close() } func (p *ProjectExec) Schema() *arrow.Schema { return &p.outputschema diff --git a/src/Backend/opti-sql-go/operators/record.go b/src/Backend/opti-sql-go/operators/record.go index 60f695b..6678ef4 100644 --- a/src/Backend/opti-sql-go/operators/record.go +++ b/src/Backend/opti-sql-go/operators/record.go @@ -24,7 +24,7 @@ type Operator interface { type RecordBatch struct { Schema *arrow.Schema Columns []arrow.Array - RowCount uint64 // TODO: update to actually use this, in all operators + RowCount uint64 // } type SchemaBuilder struct { @@ -129,6 +129,7 @@ func (rb *RecordBatch) ColumnByName(name string) (arrow.Array, error) { } return rb.Columns[indices[0]], nil } + func (rbb *RecordBatchBuilder) GenIntArray(values ...int) arrow.Array { mem := memory.NewGoAllocator() builder := array.NewInt32Builder(mem) @@ -289,3 +290,122 @@ func (rbb *RecordBatchBuilder) GenLargeBinaryArray(values ...[]byte) arrow.Array } return builder.NewArray() } +func ReleaseArrays(a []arrow.Array) { + for _, col := range a { + if col != nil { + col.Release() + } + } +} + +func (rb *RecordBatch) PrettyPrint() string { + if rb == nil { + return "" + } + + // ------------------------------- + // 1. Extract column names + // ------------------------------- + colNames := make([]string, len(rb.Schema.Fields())) + for i, f := range rb.Schema.Fields() { + colNames[i] = f.Name + } + + // ------------------------------- + // 2. Extract rows into [][]string + // ------------------------------- + rows := make([][]string, rb.RowCount) + for r := 0; r < int(rb.RowCount); r++ { + row := make([]string, len(rb.Columns)) + for c, arr := range rb.Columns { + row[c] = formatValue(arr, r) + } + rows[r] = row + } + + // ------------------------------- + // 3. Compute column widths + // ------------------------------- + colWidths := make([]int, len(colNames)) + for i, name := range colNames { + colWidths[i] = len(name) + } + for _, row := range rows { + for i, v := range row { + if len(v) > colWidths[i] { + colWidths[i] = len(v) + } + } + } + + // ------------------------------- + // 4. Build horizontal border line + // ------------------------------- + border := "+" + for _, w := range colWidths { + border += strings.Repeat("-", w+2) + "+" + } + + // ------------------------------- + // 5. Build the final output + // ------------------------------- + var b strings.Builder + + b.WriteString(border + "\n") + + // Header + b.WriteString("|") + for i, name := range colNames { + b.WriteString(" " + padRight(name, colWidths[i]) + " |") + } + b.WriteString("\n") + + b.WriteString(border + "\n") + + // Rows + for _, row := range rows { + b.WriteString("|") + for i, v := range row { + b.WriteString(" " + padRight(v, colWidths[i]) + " |") + } + b.WriteString("\n") + } + + b.WriteString(border) + + return b.String() +} + +// ------------------------------- +// Helper Functions +// ------------------------------- + +func padRight(s string, width int) string { + if len(s) >= width { + return s + } + return s + strings.Repeat(" ", width-len(s)) +} + +func formatValue(arr arrow.Array, row int) string { + if arr.IsNull(row) { + return "NULL" + } + + switch col := arr.(type) { + case *array.Int32: + return fmt.Sprintf("%d", col.Value(row)) + case *array.Int64: + return fmt.Sprintf("%d", col.Value(row)) + case *array.Float32: + return fmt.Sprintf("%g", col.Value(row)) + case *array.Float64: + return fmt.Sprintf("%g", col.Value(row)) + case *array.String: + return col.Value(row) + case *array.Boolean: + return fmt.Sprintf("%t", col.Value(row)) + default: + return "" + } +} diff --git a/src/Backend/opti-sql-go/operators/test/t1_test.go b/src/Backend/opti-sql-go/operators/test/t1_test.go new file mode 100644 index 0000000..a571421 --- /dev/null +++ b/src/Backend/opti-sql-go/operators/test/t1_test.go @@ -0,0 +1,3 @@ +package test + +// test for all operators together