diff --git a/client/column_decoder.go b/client/column_decoder.go index 3367911..d6cea20 100644 --- a/client/column_decoder.go +++ b/client/column_decoder.go @@ -100,6 +100,18 @@ func (decoder *Int32ArrayColumnDecoder) ReadColumn(reader *bytes.Reader, dataTyp // +---------------+-----------------+-------------+ // | byte | list[byte] | list[int32] | // +---------------+-----------------+-------------+ + + if positionCount == 0 { + switch dataType { + case INT32, DATE: + return NewIntColumn(0, 0, nil, []int32{}) + case FLOAT: + return NewFloatColumn(0, 0, nil, []float32{}) + default: + return nil, fmt.Errorf("invalid data type: %v", dataType) + } + } + nullIndicators, err := deserializeNullIndicators(reader, positionCount) if err != nil { return nil, err @@ -166,6 +178,18 @@ func (decoder *Int64ArrayColumnDecoder) ReadColumn(reader *bytes.Reader, dataTyp // +---------------+-----------------+-------------+ // | byte | list[byte] | list[int64] | // +---------------+-----------------+-------------+ + + if positionCount == 0 { + switch dataType { + case INT64, TIMESTAMP: + return NewLongColumn(0, 0, nil, []int64{}) + case DOUBLE: + return NewDoubleColumn(0, 0, nil, []float64{}) + default: + return nil, fmt.Errorf("invalid data type: %v", dataType) + } + } + nullIndicators, err := deserializeNullIndicators(reader, positionCount) if err != nil { return nil, err @@ -212,6 +236,11 @@ func (decoder *ByteArrayColumnDecoder) ReadColumn(reader *bytes.Reader, dataType if dataType != BOOLEAN { return nil, fmt.Errorf("invalid data type: %v", dataType) } + + if positionCount == 0 { + return NewBooleanColumn(0, 0, nil, []bool{}) + } + nullIndicators, err := deserializeNullIndicators(reader, positionCount) if err != nil { return nil, err @@ -245,6 +274,11 @@ func (decoder *BinaryArrayColumnDecoder) ReadColumn(reader *bytes.Reader, dataTy if TEXT != dataType { return nil, fmt.Errorf("invalid data type: %v", dataType) } + + if positionCount == 0 { + return NewBinaryColumn(0, 0, nil, []*Binary{}) + } + nullIndicators, err := deserializeNullIndicators(reader, positionCount) if err != nil { return nil, err @@ -259,12 +293,17 @@ func (decoder *BinaryArrayColumnDecoder) ReadColumn(reader *bytes.Reader, dataTy if err != nil { return nil, err } - value := make([]byte, length) - _, err = reader.Read(value) - if err != nil { - return nil, err + + if length == 0 { + values[i] = NewBinary([]byte{}) + } else { + value := make([]byte, length) + _, err = reader.Read(value) + if err != nil { + return nil, err + } + values[i] = NewBinary(value) } - values[i] = NewBinary(value) } return NewBinaryColumn(0, positionCount, nullIndicators, values) } diff --git a/client/column_decoder_test.go b/client/column_decoder_test.go new file mode 100644 index 0000000..dc5d433 --- /dev/null +++ b/client/column_decoder_test.go @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package client + +import ( + "bytes" + "encoding/binary" + "testing" +) + +func buildNullIndicatorBytes(nulls []bool) []byte { + var buf bytes.Buffer + hasNull := false + for _, n := range nulls { + if n { + hasNull = true + break + } + } + if !hasNull { + buf.WriteByte(0) + return buf.Bytes() + } + buf.WriteByte(1) + packed := make([]byte, (len(nulls)+7)/8) + for i, n := range nulls { + if n { + packed[i/8] |= 0b10000000 >> (uint(i) % 8) + } + } + buf.Write(packed) + return buf.Bytes() +} + +func TestBinaryArrayColumnDecoder_EmptyString(t *testing.T) { + var buf bytes.Buffer + buf.Write(buildNullIndicatorBytes([]bool{false})) + _ = binary.Write(&buf, binary.BigEndian, int32(0)) + + col, err := (&BinaryArrayColumnDecoder{}).ReadColumn(bytes.NewReader(buf.Bytes()), TEXT, 1) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.GetPositionCount() != 1 { + t.Fatalf("expected positionCount=1, got %d", col.GetPositionCount()) + } + if col.IsNull(0) { + t.Fatal("row 0 should not be null") + } + val, err := col.GetBinary(0) + if err != nil { + t.Fatalf("GetBinary(0) failed: %v", err) + } + if len(val.values) != 0 { + t.Fatalf("expected empty string, got %q", string(val.values)) + } +} + +func TestBinaryArrayColumnDecoder_NullThenEmptyString(t *testing.T) { + var buf bytes.Buffer + buf.Write(buildNullIndicatorBytes([]bool{true, false})) + _ = binary.Write(&buf, binary.BigEndian, int32(0)) + + col, err := (&BinaryArrayColumnDecoder{}).ReadColumn(bytes.NewReader(buf.Bytes()), TEXT, 2) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if !col.IsNull(0) { + t.Error("row 0 should be null") + } + if col.IsNull(1) { + t.Error("row 1 should not be null") + } + val, err := col.GetBinary(1) + if err != nil { + t.Fatalf("GetBinary(1) failed: %v", err) + } + if len(val.values) != 0 { + t.Fatalf("expected empty string, got %q", string(val.values)) + } +} + +func TestBinaryArrayColumnDecoder_WithNull(t *testing.T) { + var buf bytes.Buffer + buf.Write(buildNullIndicatorBytes([]bool{false, true, false})) + writeText := func(s string) { + _ = binary.Write(&buf, binary.BigEndian, int32(len(s))) + buf.WriteString(s) + } + writeText("hello") + writeText("world") + + col, err := (&BinaryArrayColumnDecoder{}).ReadColumn(bytes.NewReader(buf.Bytes()), TEXT, 3) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.IsNull(0) { + t.Error("row 0 should not be null") + } + if v, _ := col.GetBinary(0); string(v.values) != "hello" { + t.Errorf("row 0: expected \"hello\", got %q", string(v.values)) + } + if !col.IsNull(1) { + t.Error("row 1 should be null") + } + if col.IsNull(2) { + t.Error("row 2 should not be null") + } + if v, _ := col.GetBinary(2); string(v.values) != "world" { + t.Errorf("row 2: expected \"world\", got %q", string(v.values)) + } +} + +func TestInt64ArrayColumnDecoder_WithNull(t *testing.T) { + var buf bytes.Buffer + buf.Write(buildNullIndicatorBytes([]bool{false, true, false})) + _ = binary.Write(&buf, binary.BigEndian, int64(100)) + _ = binary.Write(&buf, binary.BigEndian, int64(200)) + + col, err := (&Int64ArrayColumnDecoder{}).ReadColumn(bytes.NewReader(buf.Bytes()), INT64, 3) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.IsNull(0) { + t.Error("row 0 should not be null") + } + if v, _ := col.GetLong(0); v != 100 { + t.Errorf("row 0: expected 100, got %d", v) + } + if !col.IsNull(1) { + t.Error("row 1 should be null") + } + if col.IsNull(2) { + t.Error("row 2 should not be null") + } + if v, _ := col.GetLong(2); v != 200 { + t.Errorf("row 2: expected 200, got %d", v) + } +} + +func TestColumnDecoder_ZeroPositionCount(t *testing.T) { + empty := func() *bytes.Reader { return bytes.NewReader([]byte{}) } + + t.Run("Int32ArrayColumnDecoder", func(t *testing.T) { + col, err := (&Int32ArrayColumnDecoder{}).ReadColumn(empty(), INT32, 0) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.GetPositionCount() != 0 { + t.Errorf("expected positionCount=0, got %d", col.GetPositionCount()) + } + }) + + t.Run("Int64ArrayColumnDecoder", func(t *testing.T) { + col, err := (&Int64ArrayColumnDecoder{}).ReadColumn(empty(), INT64, 0) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.GetPositionCount() != 0 { + t.Errorf("expected positionCount=0, got %d", col.GetPositionCount()) + } + }) + + t.Run("ByteArrayColumnDecoder", func(t *testing.T) { + col, err := (&ByteArrayColumnDecoder{}).ReadColumn(empty(), BOOLEAN, 0) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.GetPositionCount() != 0 { + t.Errorf("expected positionCount=0, got %d", col.GetPositionCount()) + } + }) + + t.Run("BinaryArrayColumnDecoder", func(t *testing.T) { + col, err := (&BinaryArrayColumnDecoder{}).ReadColumn(empty(), TEXT, 0) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.GetPositionCount() != 0 { + t.Errorf("expected positionCount=0, got %d", col.GetPositionCount()) + } + }) +} diff --git a/client/errors.go b/client/errors.go index 66ead4f..0f1aabb 100644 --- a/client/errors.go +++ b/client/errors.go @@ -21,7 +21,6 @@ package client import ( "fmt" - "github.com/apache/iotdb-client-go/common" ) diff --git a/client/session.go b/client/session.go index 98a9fee..fbf1011 100644 --- a/client/session.go +++ b/client/session.go @@ -519,10 +519,15 @@ func (s *Session) ExecuteQueryStatement(sql string, timeoutMs *int64) (*SessionD request.SessionId = s.sessionId request.StatementId = s.requestStatementId resp, err = s.client.ExecuteQueryStatementV2(context.Background(), &request) - if statusErr := VerifySuccess(resp.Status); statusErr == nil { - return NewSessionDataSet(sql, resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.requestStatementId, s.client, s.sessionId, resp.QueryResult_, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, timeoutMs, *resp.MoreData, s.config.FetchSize) - } else { - return nil, statusErr + if err == nil { + if resp == nil { + return nil, fmt.Errorf("received nil response after reconnect") + } + if statusErr := VerifySuccess(resp.Status); statusErr == nil { + return NewSessionDataSet(sql, resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.requestStatementId, s.client, s.sessionId, resp.QueryResult_, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, timeoutMs, *resp.MoreData, s.config.FetchSize) + } else { + return nil, statusErr + } } } return nil, err @@ -545,10 +550,15 @@ func (s *Session) ExecuteAggregationQuery(paths []string, aggregations []common. if s.reconnect() { request.SessionId = s.sessionId resp, err = s.client.ExecuteAggregationQueryV2(context.Background(), &request) - if statusErr := VerifySuccess(resp.Status); statusErr == nil { - return NewSessionDataSet("", resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.requestStatementId, s.client, s.sessionId, resp.QueryResult_, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, timeoutMs, *resp.MoreData, s.config.FetchSize) - } else { - return nil, statusErr + if err == nil { + if resp == nil { + return nil, fmt.Errorf("received nil response after reconnect") + } + if statusErr := VerifySuccess(resp.Status); statusErr == nil { + return NewSessionDataSet("", resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.requestStatementId, s.client, s.sessionId, resp.QueryResult_, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, timeoutMs, *resp.MoreData, s.config.FetchSize) + } else { + return nil, statusErr + } } } return nil, err @@ -572,10 +582,15 @@ func (s *Session) ExecuteAggregationQueryWithLegalNodes(paths []string, aggregat if s.reconnect() { request.SessionId = s.sessionId resp, err = s.client.ExecuteAggregationQueryV2(context.Background(), &request) - if statusErr := VerifySuccess(resp.Status); statusErr == nil { - return NewSessionDataSet("", resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.requestStatementId, s.client, s.sessionId, resp.QueryResult_, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, timeoutMs, *resp.MoreData, s.config.FetchSize) - } else { - return nil, statusErr + if err == nil { + if resp == nil { + return nil, fmt.Errorf("received nil response after reconnect") + } + if statusErr := VerifySuccess(resp.Status); statusErr == nil { + return NewSessionDataSet("", resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.requestStatementId, s.client, s.sessionId, resp.QueryResult_, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, timeoutMs, *resp.MoreData, s.config.FetchSize) + } else { + return nil, statusErr + } } } return nil, err diff --git a/client/sessionpool.go b/client/sessionpool.go index 48b322a..6054e7d 100644 --- a/client/sessionpool.go +++ b/client/sessionpool.go @@ -83,7 +83,10 @@ func (spool *SessionPool) GetSession() (session Session, err error) { } default: config := spool.config - session, err := spool.ConstructSession(config) + session, err = spool.ConstructSession(config) + if err != nil { + <-spool.sem + } return session, err } case <-time.After(time.Millisecond * time.Duration(spool.waitToGetSessionTimeoutInMs)): @@ -137,7 +140,12 @@ func getClusterSessionConfig(config *PoolConfig) *ClusterConfig { } func (spool *SessionPool) PutBack(session Session) { - if session.trans.IsOpen() { + defer func() { + if r := recover(); r != nil { + session.Close() + } + }() + if session.trans != nil && session.trans.IsOpen() { spool.ch <- session } <-spool.sem