From 81fe8335c2aa89a9bfe28b47485553b04e46213c Mon Sep 17 00:00:00 2001 From: shuwenwei <55970239+shuwenwei@users.noreply.github.com> Date: Thu, 26 Feb 2026 17:13:48 +0800 Subject: [PATCH 1/4] Fix multiple issues in SessionPool and PooledTableSession (#153) --- client/sessionpool.go | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/client/sessionpool.go b/client/sessionpool.go index 48b322a..a05cc4e 100644 --- a/client/sessionpool.go +++ b/client/sessionpool.go @@ -83,7 +83,10 @@ func (spool *SessionPool) GetSession() (session Session, err error) { } default: config := spool.config - session, err := spool.ConstructSession(config) + session, err = spool.ConstructSession(config) + if err != nil { + <-spool.sem + } return session, err } case <-time.After(time.Millisecond * time.Duration(spool.waitToGetSessionTimeoutInMs)): @@ -137,12 +140,33 @@ func getClusterSessionConfig(config *PoolConfig) *ClusterConfig { } func (spool *SessionPool) PutBack(session Session) { - if session.trans.IsOpen() { + defer func() { + if r := recover(); r != nil { + session.Close() + } + }() + if session.trans != nil && session.trans.IsOpen() { spool.ch <- session } <-spool.sem } +func (spool *SessionPool) dropSession(session Session) { + defer func() { + if e := recover(); e != nil { + if session.trans != nil && session.trans.IsOpen() { + session.Close() + } + } + }() + err := session.Close() + if err != nil { + log.Println("Failed to close session ", session) + } + <-spool.sem +} + +>>>>>>> f00cf99 (Fix multiple issues in SessionPool and PooledTableSession (#153)) func (spool *SessionPool) Close() { close(spool.ch) for s := range spool.ch { From 35e9dc4496c3ed6c5e849a2ed1af1025bef4c53a Mon Sep 17 00:00:00 2001 From: Haonan Date: Thu, 12 Feb 2026 18:03:50 +0800 Subject: [PATCH 2/4] Call VerifySuccess before return to user (#151) * Move VerifySuccess * fix missing code * fix missing code * fix copilot review --- client/errors.go | 29 ++- client/session.go | 252 ++++++++++++------- client/sessionpool.go | 16 -- client/utils.go | 8 +- example/session_example.go | 16 +- example/session_pool/session_pool_example.go | 17 +- test/e2e/e2e_test.go | 29 +-- 7 files changed, 208 insertions(+), 159 deletions(-) diff --git a/client/errors.go b/client/errors.go index 2c54bde..0f1aabb 100644 --- a/client/errors.go +++ b/client/errors.go @@ -20,28 +20,33 @@ package client import ( - "bytes" + "fmt" "github.com/apache/iotdb-client-go/common" ) +// ExecutionError represents an error returned by the server via TSStatus. +// It is NOT a connection error and should not cause session drops. +type ExecutionError struct { + Code int32 + Message string +} + +func (e *ExecutionError) Error() string { + if e.Message != "" { + return fmt.Sprintf("error code: %d, message: %v", e.Code, e.Message) + } + return fmt.Sprintf("error code: %d", e.Code) +} + type BatchError struct { statuses []*common.TSStatus + Message string } func (e *BatchError) Error() string { - buff := bytes.Buffer{} - for _, status := range e.statuses { - buff.WriteString(*status.Message + ";") - } - return buff.String() + return e.Message } func (e *BatchError) GetStatuses() []*common.TSStatus { return e.statuses } - -func NewBatchError(statuses []*common.TSStatus) *BatchError { - return &BatchError{ - statuses: statuses, - } -} diff --git a/client/session.go b/client/session.go index fff4cf0..98a9fee 100644 --- a/client/session.go +++ b/client/session.go @@ -186,14 +186,17 @@ func (s *Session) Close() error { *return *error: correctness of operation */ -func (s *Session) SetStorageGroup(storageGroupId string) (r *common.TSStatus, err error) { - r, err = s.client.SetStorageGroup(context.Background(), s.sessionId, storageGroupId) +func (s *Session) SetStorageGroup(storageGroupId string) error { + r, err := s.client.SetStorageGroup(context.Background(), s.sessionId, storageGroupId) if err != nil && r == nil { if s.reconnect() { r, err = s.client.SetStorageGroup(context.Background(), s.sessionId, storageGroupId) } } - return r, err + if err != nil { + return err + } + return VerifySuccess(r) } /* @@ -203,14 +206,17 @@ func (s *Session) SetStorageGroup(storageGroupId string) (r *common.TSStatus, er *return *error: correctness of operation */ -func (s *Session) DeleteStorageGroup(storageGroupId string) (r *common.TSStatus, err error) { - r, err = s.client.DeleteStorageGroups(context.Background(), s.sessionId, []string{storageGroupId}) +func (s *Session) DeleteStorageGroup(storageGroupId string) error { + r, err := s.client.DeleteStorageGroups(context.Background(), s.sessionId, []string{storageGroupId}) if err != nil && r == nil { if s.reconnect() { r, err = s.client.DeleteStorageGroups(context.Background(), s.sessionId, []string{storageGroupId}) } } - return r, err + if err != nil { + return err + } + return VerifySuccess(r) } /* @@ -220,14 +226,17 @@ func (s *Session) DeleteStorageGroup(storageGroupId string) (r *common.TSStatus, *return *error: correctness of operation */ -func (s *Session) DeleteStorageGroups(storageGroupIds ...string) (r *common.TSStatus, err error) { - r, err = s.client.DeleteStorageGroups(context.Background(), s.sessionId, storageGroupIds) +func (s *Session) DeleteStorageGroups(storageGroupIds ...string) error { + r, err := s.client.DeleteStorageGroups(context.Background(), s.sessionId, storageGroupIds) if err != nil && r == nil { if s.reconnect() { r, err = s.client.DeleteStorageGroups(context.Background(), s.sessionId, storageGroupIds) } } - return r, err + if err != nil { + return err + } + return VerifySuccess(r) } /* @@ -240,9 +249,11 @@ func (s *Session) DeleteStorageGroups(storageGroupIds ...string) (r *common.TSSt *return *error: correctness of operation */ -func (s *Session) CreateTimeseries(path string, dataType TSDataType, encoding TSEncoding, compressor TSCompressionType, attributes map[string]string, tags map[string]string) (r *common.TSStatus, err error) { - request := rpc.TSCreateTimeseriesReq{SessionId: s.sessionId, Path: path, DataType: int32(dataType), Encoding: int32(encoding), - Compressor: int32(compressor), Attributes: attributes, Tags: tags} +func (s *Session) CreateTimeseries(path string, dataType TSDataType, encoding TSEncoding, compressor TSCompressionType, attributes map[string]string, tags map[string]string) error { + request := rpc.TSCreateTimeseriesReq{ + SessionId: s.sessionId, Path: path, DataType: int32(dataType), Encoding: int32(encoding), + Compressor: int32(compressor), Attributes: attributes, Tags: tags, + } status, err := s.client.CreateTimeseries(context.Background(), &request) if err != nil && status == nil { if s.reconnect() { @@ -250,7 +261,10 @@ func (s *Session) CreateTimeseries(path string, dataType TSDataType, encoding TS status, err = s.client.CreateTimeseries(context.Background(), &request) } } - return status, err + if err != nil { + return err + } + return VerifySuccess(status) } /* @@ -265,7 +279,7 @@ func (s *Session) CreateTimeseries(path string, dataType TSDataType, encoding TS *return *error: correctness of operation */ -func (s *Session) CreateAlignedTimeseries(prefixPath string, measurements []string, dataTypes []TSDataType, encodings []TSEncoding, compressors []TSCompressionType, measurementAlias []string) (r *common.TSStatus, err error) { +func (s *Session) CreateAlignedTimeseries(prefixPath string, measurements []string, dataTypes []TSDataType, encodings []TSEncoding, compressors []TSCompressionType, measurementAlias []string) error { destTypes := make([]int32, len(dataTypes)) for i, t := range dataTypes { destTypes[i] = int32(t) @@ -297,7 +311,10 @@ func (s *Session) CreateAlignedTimeseries(prefixPath string, measurements []stri status, err = s.client.CreateAlignedTimeseries(context.Background(), &request) } } - return status, err + if err != nil { + return err + } + return VerifySuccess(status) } /* @@ -310,7 +327,7 @@ func (s *Session) CreateAlignedTimeseries(prefixPath string, measurements []stri *return *error: correctness of operation */ -func (s *Session) CreateMultiTimeseries(paths []string, dataTypes []TSDataType, encodings []TSEncoding, compressors []TSCompressionType) (r *common.TSStatus, err error) { +func (s *Session) CreateMultiTimeseries(paths []string, dataTypes []TSDataType, encodings []TSEncoding, compressors []TSCompressionType) error { destTypes := make([]int32, len(dataTypes)) for i, t := range dataTypes { destTypes[i] = int32(t) @@ -326,9 +343,11 @@ func (s *Session) CreateMultiTimeseries(paths []string, dataTypes []TSDataType, destCompressions[i] = int32(e) } - request := rpc.TSCreateMultiTimeseriesReq{SessionId: s.sessionId, Paths: paths, DataTypes: destTypes, - Encodings: destEncodings, Compressors: destCompressions} - r, err = s.client.CreateMultiTimeseries(context.Background(), &request) + request := rpc.TSCreateMultiTimeseriesReq{ + SessionId: s.sessionId, Paths: paths, DataTypes: destTypes, + Encodings: destEncodings, Compressors: destCompressions, + } + r, err := s.client.CreateMultiTimeseries(context.Background(), &request) if err != nil && r == nil { if s.reconnect() { @@ -337,7 +356,10 @@ func (s *Session) CreateMultiTimeseries(paths []string, dataTypes []TSDataType, } } - return r, err + if err != nil { + return err + } + return VerifySuccess(r) } /* @@ -347,14 +369,17 @@ func (s *Session) CreateMultiTimeseries(paths []string, dataTypes []TSDataType, *return *error: correctness of operation */ -func (s *Session) DeleteTimeseries(paths []string) (r *common.TSStatus, err error) { - r, err = s.client.DeleteTimeseries(context.Background(), s.sessionId, paths) +func (s *Session) DeleteTimeseries(paths []string) error { + r, err := s.client.DeleteTimeseries(context.Background(), s.sessionId, paths) if err != nil && r == nil { if s.reconnect() { r, err = s.client.DeleteTimeseries(context.Background(), s.sessionId, paths) } } - return r, err + if err != nil { + return err + } + return VerifySuccess(r) } /* @@ -366,16 +391,19 @@ func (s *Session) DeleteTimeseries(paths []string) (r *common.TSStatus, err erro *return *error: correctness of operation */ -func (s *Session) DeleteData(paths []string, startTime int64, endTime int64) (r *common.TSStatus, err error) { +func (s *Session) DeleteData(paths []string, startTime int64, endTime int64) error { request := rpc.TSDeleteDataReq{SessionId: s.sessionId, Paths: paths, StartTime: startTime, EndTime: endTime} - r, err = s.client.DeleteData(context.Background(), &request) + r, err := s.client.DeleteData(context.Background(), &request) if err != nil && r == nil { if s.reconnect() { request.SessionId = s.sessionId r, err = s.client.DeleteData(context.Background(), &request) } } - return r, err + if err != nil { + return err + } + return VerifySuccess(r) } /* @@ -388,17 +416,22 @@ func (s *Session) DeleteData(paths []string, startTime int64, endTime int64) (r *return *error: correctness of operation */ -func (s *Session) InsertStringRecord(deviceId string, measurements []string, values []string, timestamp int64) (r *common.TSStatus, err error) { - request := rpc.TSInsertStringRecordReq{SessionId: s.sessionId, PrefixPath: deviceId, Measurements: measurements, - Values: values, Timestamp: timestamp} - r, err = s.client.InsertStringRecord(context.Background(), &request) +func (s *Session) InsertStringRecord(deviceId string, measurements []string, values []string, timestamp int64) error { + request := rpc.TSInsertStringRecordReq{ + SessionId: s.sessionId, PrefixPath: deviceId, Measurements: measurements, + Values: values, Timestamp: timestamp, + } + r, err := s.client.InsertStringRecord(context.Background(), &request) if err != nil && r == nil { if s.reconnect() { request.SessionId = s.sessionId r, err = s.client.InsertStringRecord(context.Background(), &request) } } - return r, err + if err != nil { + return err + } + return VerifySuccess(r) } func (s *Session) GetTimeZone() (string, error) { @@ -409,11 +442,17 @@ func (s *Session) GetTimeZone() (string, error) { return resp.TimeZone, nil } -func (s *Session) SetTimeZone(timeZone string) (r *common.TSStatus, err error) { +func (s *Session) SetTimeZone(timeZone string) error { request := rpc.TSSetTimeZoneReq{SessionId: s.sessionId, TimeZone: timeZone} - r, err = s.client.SetTimeZone(context.Background(), &request) + r, err := s.client.SetTimeZone(context.Background(), &request) + if err != nil { + return err + } + if err := VerifySuccess(r); err != nil { + return err + } s.config.TimeZone = timeZone - return r, err + return nil } func (s *Session) ExecuteStatementWithContext(ctx context.Context, sql string) (*SessionDataSet, error) { @@ -444,7 +483,7 @@ func (s *Session) ExecuteStatement(sql string) (*SessionDataSet, error) { return s.ExecuteStatementWithContext(context.Background(), sql) } -func (s *Session) ExecuteNonQueryStatement(sql string) (r *common.TSStatus, err error) { +func (s *Session) ExecuteNonQueryStatement(sql string) error { request := rpc.TSExecuteStatementReq{ SessionId: s.sessionId, Statement: sql, @@ -460,8 +499,10 @@ func (s *Session) ExecuteNonQueryStatement(sql string) (r *common.TSStatus, err resp, err = s.client.ExecuteStatementV2(context.Background(), &request) } } - - return resp.Status, err + if err != nil { + return err + } + return VerifySuccess(resp.Status) } func (s *Session) ExecuteQueryStatement(sql string, timeoutMs *int64) (*SessionDataSet, error) { @@ -560,12 +601,12 @@ func (s *Session) genTSInsertRecordReq(deviceId string, time int64, return request, nil } -func (s *Session) InsertRecord(deviceId string, measurements []string, dataTypes []TSDataType, values []interface{}, timestamp int64) (r *common.TSStatus, err error) { +func (s *Session) InsertRecord(deviceId string, measurements []string, dataTypes []TSDataType, values []interface{}, timestamp int64) error { request, err := s.genTSInsertRecordReq(deviceId, timestamp, measurements, dataTypes, values, false) if err != nil { - return nil, err + return err } - r, err = s.client.InsertRecord(context.Background(), request) + r, err := s.client.InsertRecord(context.Background(), request) if err != nil && r == nil { if s.reconnect() { @@ -574,15 +615,18 @@ func (s *Session) InsertRecord(deviceId string, measurements []string, dataTypes } } - return r, err + if err != nil { + return err + } + return VerifySuccess(r) } -func (s *Session) InsertAlignedRecord(deviceId string, measurements []string, dataTypes []TSDataType, values []interface{}, timestamp int64) (r *common.TSStatus, err error) { +func (s *Session) InsertAlignedRecord(deviceId string, measurements []string, dataTypes []TSDataType, values []interface{}, timestamp int64) error { request, err := s.genTSInsertRecordReq(deviceId, timestamp, measurements, dataTypes, values, true) if err != nil { - return nil, err + return err } - r, err = s.client.InsertRecord(context.Background(), request) + r, err := s.client.InsertRecord(context.Background(), request) if err != nil && r == nil { if s.reconnect() { @@ -591,7 +635,10 @@ func (s *Session) InsertAlignedRecord(deviceId string, measurements []string, da } } - return r, err + if err != nil { + return err + } + return VerifySuccess(r) } type deviceData struct { @@ -620,11 +667,11 @@ func (d *deviceData) Swap(i, j int) { // InsertRecordsOfOneDevice Insert multiple rows, which can reduce the overhead of network. This method is just like jdbc // executeBatch, we pack some insert request in batch and send them to server. If you want improve // your performance, please see insertTablet method -// Each row is independent, which could have different deviceId, time, number of measurements -func (s *Session) InsertRecordsOfOneDevice(deviceId string, timestamps []int64, measurementsSlice [][]string, dataTypesSlice [][]TSDataType, valuesSlice [][]interface{}, sorted bool) (r *common.TSStatus, err error) { +// Each row is independent, which could have different insertTargetName, time, number of measurements +func (s *Session) InsertRecordsOfOneDevice(deviceId string, timestamps []int64, measurementsSlice [][]string, dataTypesSlice [][]TSDataType, valuesSlice [][]interface{}, sorted bool) error { length := len(timestamps) if len(measurementsSlice) != length || len(dataTypesSlice) != length || len(valuesSlice) != length { - return nil, errors.New("timestamps, measurementsSlice and valuesSlice's size should be equal") + return errors.New("timestamps, measurementsSlice and valuesSlice's size should be equal") } if !sorted { @@ -636,10 +683,11 @@ func (s *Session) InsertRecordsOfOneDevice(deviceId string, timestamps []int64, }) } + var err error valuesList := make([][]byte, length) for i := 0; i < length; i++ { if valuesList[i], err = valuesToBytes(dataTypesSlice[i], valuesSlice[i], measurementsSlice[i]); err != nil { - return nil, err + return err } } @@ -651,7 +699,7 @@ func (s *Session) InsertRecordsOfOneDevice(deviceId string, timestamps []int64, ValuesList: valuesList, } - r, err = s.client.InsertRecordsOfOneDevice(context.Background(), request) + r, err := s.client.InsertRecordsOfOneDevice(context.Background(), request) if err != nil && r == nil { if s.reconnect() { @@ -660,13 +708,16 @@ func (s *Session) InsertRecordsOfOneDevice(deviceId string, timestamps []int64, } } - return r, err + if err != nil { + return err + } + return VerifySuccess(r) } -func (s *Session) InsertAlignedRecordsOfOneDevice(deviceId string, timestamps []int64, measurementsSlice [][]string, dataTypesSlice [][]TSDataType, valuesSlice [][]interface{}, sorted bool) (r *common.TSStatus, err error) { +func (s *Session) InsertAlignedRecordsOfOneDevice(deviceId string, timestamps []int64, measurementsSlice [][]string, dataTypesSlice [][]TSDataType, valuesSlice [][]interface{}, sorted bool) error { length := len(timestamps) if len(measurementsSlice) != length || len(dataTypesSlice) != length || len(valuesSlice) != length { - return nil, errors.New("timestamps, measurementsSlice and valuesSlice's size should be equal") + return errors.New("timestamps, measurementsSlice and valuesSlice's size should be equal") } if !sorted { @@ -678,10 +729,11 @@ func (s *Session) InsertAlignedRecordsOfOneDevice(deviceId string, timestamps [] }) } + var err error valuesList := make([][]byte, length) for i := 0; i < length; i++ { if valuesList[i], err = valuesToBytes(dataTypesSlice[i], valuesSlice[i], measurementsSlice[i]); err != nil { - return nil, err + return err } } var isAligned = true @@ -694,7 +746,7 @@ func (s *Session) InsertAlignedRecordsOfOneDevice(deviceId string, timestamps [] IsAligned: &isAligned, } - r, err = s.client.InsertRecordsOfOneDevice(context.Background(), request) + r, err := s.client.InsertRecordsOfOneDevice(context.Background(), request) if err != nil && r == nil { if s.reconnect() { @@ -703,7 +755,10 @@ func (s *Session) InsertAlignedRecordsOfOneDevice(deviceId string, timestamps [] } } - return r, err + if err != nil { + return err + } + return VerifySuccess(r) } /* @@ -719,36 +774,44 @@ func (s *Session) InsertAlignedRecordsOfOneDevice(deviceId string, timestamps [] * */ func (s *Session) InsertRecords(deviceIds []string, measurements [][]string, dataTypes [][]TSDataType, values [][]interface{}, - timestamps []int64) (r *common.TSStatus, err error) { + timestamps []int64, +) error { request, err := s.genInsertRecordsReq(deviceIds, measurements, dataTypes, values, timestamps, false) if err != nil { - return nil, err + return err } else { - r, err = s.client.InsertRecords(context.Background(), request) + r, err := s.client.InsertRecords(context.Background(), request) if err != nil && r == nil { if s.reconnect() { request.SessionId = s.sessionId r, err = s.client.InsertRecords(context.Background(), request) } } - return r, err + if err != nil { + return err + } + return VerifySuccess(r) } } func (s *Session) InsertAlignedRecords(deviceIds []string, measurements [][]string, dataTypes [][]TSDataType, values [][]interface{}, - timestamps []int64) (r *common.TSStatus, err error) { + timestamps []int64, +) error { request, err := s.genInsertRecordsReq(deviceIds, measurements, dataTypes, values, timestamps, true) if err != nil { - return nil, err + return err } else { - r, err = s.client.InsertRecords(context.Background(), request) + r, err := s.client.InsertRecords(context.Background(), request) if err != nil && r == nil { if s.reconnect() { request.SessionId = s.sessionId r, err = s.client.InsertRecords(context.Background(), request) } } - return r, err + if err != nil { + return err + } + return VerifySuccess(r) } } @@ -757,63 +820,72 @@ func (s *Session) InsertAlignedRecords(deviceIds []string, measurements [][]stri *params *tablets: []*client.Tablet, list of tablets */ -func (s *Session) InsertTablets(tablets []*Tablet, sorted bool) (r *common.TSStatus, err error) { +func (s *Session) InsertTablets(tablets []*Tablet, sorted bool) error { if !sorted { for _, t := range tablets { if err := t.Sort(); err != nil { - return nil, err + return err } } } request, err := s.genInsertTabletsReq(tablets, false) if err != nil { - return nil, err + return err } - r, err = s.client.InsertTablets(context.Background(), request) + r, err := s.client.InsertTablets(context.Background(), request) if err != nil && r == nil { if s.reconnect() { request.SessionId = s.sessionId r, err = s.client.InsertTablets(context.Background(), request) } } - return r, err + if err != nil { + return err + } + return VerifySuccess(r) } -func (s *Session) InsertAlignedTablets(tablets []*Tablet, sorted bool) (r *common.TSStatus, err error) { +func (s *Session) InsertAlignedTablets(tablets []*Tablet, sorted bool) error { if !sorted { for _, t := range tablets { if err := t.Sort(); err != nil { - return nil, err + return err } } } request, err := s.genInsertTabletsReq(tablets, true) if err != nil { - return nil, err + return err } - r, err = s.client.InsertTablets(context.Background(), request) + r, err := s.client.InsertTablets(context.Background(), request) if err != nil && r == nil { if s.reconnect() { request.SessionId = s.sessionId r, err = s.client.InsertTablets(context.Background(), request) } } - return r, err + if err != nil { + return err + } + return VerifySuccess(r) } -func (s *Session) ExecuteBatchStatement(inserts []string) (r *common.TSStatus, err error) { +func (s *Session) ExecuteBatchStatement(inserts []string) error { request := rpc.TSExecuteBatchStatementReq{ SessionId: s.sessionId, Statements: inserts, } - r, err = s.client.ExecuteBatchStatement(context.Background(), &request) + r, err := s.client.ExecuteBatchStatement(context.Background(), &request) if err != nil && r == nil { if s.reconnect() { request.SessionId = s.sessionId r, err = s.client.ExecuteBatchStatement(context.Background(), &request) } } - return r, err + if err != nil { + return err + } + return VerifySuccess(r) } func (s *Session) ExecuteRawDataQuery(paths []string, startTime int64, endTime int64) (*SessionDataSet, error) { @@ -1020,18 +1092,18 @@ func valuesToBytes(dataTypes []TSDataType, values []interface{}, measurementName return buff.Bytes(), nil } -func (s *Session) InsertTablet(tablet *Tablet, sorted bool) (r *common.TSStatus, err error) { +func (s *Session) InsertTablet(tablet *Tablet, sorted bool) error { if !sorted { if err := tablet.Sort(); err != nil { - return nil, err + return err } } request, err := s.genTSInsertTabletReq(tablet, false) if err != nil { - return nil, err + return err } - r, err = s.client.InsertTablet(context.Background(), request) + r, err := s.client.InsertTablet(context.Background(), request) if err != nil && r == nil { if s.reconnect() { @@ -1040,21 +1112,24 @@ func (s *Session) InsertTablet(tablet *Tablet, sorted bool) (r *common.TSStatus, } } - return r, err + if err != nil { + return err + } + return VerifySuccess(r) } -func (s *Session) InsertAlignedTablet(tablet *Tablet, sorted bool) (r *common.TSStatus, err error) { +func (s *Session) InsertAlignedTablet(tablet *Tablet, sorted bool) error { if !sorted { if err := tablet.Sort(); err != nil { - return nil, err + return err } } request, err := s.genTSInsertTabletReq(tablet, true) if err != nil { - return nil, err + return err } - r, err = s.client.InsertTablet(context.Background(), request) + r, err := s.client.InsertTablet(context.Background(), request) if err != nil && r == nil { if s.reconnect() { @@ -1063,7 +1138,10 @@ func (s *Session) InsertAlignedTablet(tablet *Tablet, sorted bool) (r *common.TS } } - return r, err + if err != nil { + return err + } + return VerifySuccess(r) } func (s *Session) genTSInsertTabletReq(tablet *Tablet, isAligned bool) (*rpc.TSInsertTabletReq, error) { diff --git a/client/sessionpool.go b/client/sessionpool.go index a05cc4e..6054e7d 100644 --- a/client/sessionpool.go +++ b/client/sessionpool.go @@ -151,22 +151,6 @@ func (spool *SessionPool) PutBack(session Session) { <-spool.sem } -func (spool *SessionPool) dropSession(session Session) { - defer func() { - if e := recover(); e != nil { - if session.trans != nil && session.trans.IsOpen() { - session.Close() - } - } - }() - err := session.Close() - if err != nil { - log.Println("Failed to close session ", session) - } - <-spool.sem -} - ->>>>>>> f00cf99 (Fix multiple issues in SessionPool and PooledTableSession (#153)) func (spool *SessionPool) Close() { close(spool.ch) for s := range spool.ch { diff --git a/client/utils.go b/client/utils.go index 41cc783..0f05dac 100644 --- a/client/utils.go +++ b/client/utils.go @@ -124,7 +124,7 @@ func verifySuccesses(statuses []*common.TSStatus) error { } errMsg := buff.String() if len(errMsg) > 0 { - return NewBatchError(statuses) + return &BatchError{statuses, errMsg} } return nil } @@ -141,11 +141,11 @@ func VerifySuccess(status *common.TSStatus) error { return nil } if status.Code != SuccessStatus { + msg := "" if status.Message != nil { - return fmt.Errorf("error code: %d, message: %v", status.Code, *status.Message) - } else { - return fmt.Errorf("error code: %d", status.Code) + msg = *status.Message } + return &ExecutionError{Code: status.Code, Message: msg} } return nil } diff --git a/example/session_example.go b/example/session_example.go index 1b31514..8ebe918 100644 --- a/example/session_example.go +++ b/example/session_example.go @@ -465,9 +465,9 @@ func deleteData() { func insertTablet() { if tablet, err := createTablet(12); err == nil { - status, err := session.InsertTablet(tablet, false) + err = session.InsertTablet(tablet, false) tablet.Reset() - checkError(status, err) + checkError(err) } else { log.Fatal(err) } @@ -475,9 +475,9 @@ func insertTablet() { func insertAlignedTablet() { if tablet, err := createTablet(12); err == nil { - status, err := session.InsertAlignedTablet(tablet, false) + err = session.InsertAlignedTablet(tablet, false) tablet.Reset() - checkError(status, err) + checkError(err) } else { log.Fatal(err) } @@ -642,14 +642,8 @@ func executeBatchStatement() { } } -func checkError(status *common.TSStatus, err error) { +func checkError(err error) { if err != nil { log.Fatal(err) } - - if status != nil { - if err = client.VerifySuccess(status); err != nil { - log.Println(err) - } - } } diff --git a/example/session_pool/session_pool_example.go b/example/session_pool/session_pool_example.go index 6f52ffe..2dfb5bb 100644 --- a/example/session_pool/session_pool_example.go +++ b/example/session_pool/session_pool_example.go @@ -22,7 +22,6 @@ package main import ( "flag" "fmt" - "github.com/apache/iotdb-client-go/common" "log" "math/rand" "strings" @@ -410,9 +409,9 @@ func insertTablet() { defer sessionPool.PutBack(session) if err == nil { if tablet, err := createTablet(12); err == nil { - status, err := session.InsertTablet(tablet, false) + err := session.InsertTablet(tablet, false) tablet.Reset() - checkError(status, err) + checkError(err) } else { log.Fatal(err) } @@ -425,9 +424,9 @@ func insertAlignedTablet() { defer sessionPool.PutBack(session) if err == nil { if tablet, err := createTablet(12); err == nil { - status, err := session.InsertAlignedTablet(tablet, false) + err := session.InsertAlignedTablet(tablet, false) tablet.Reset() - checkError(status, err) + checkError(err) } else { log.Fatal(err) } @@ -725,14 +724,8 @@ func printDataSet2(sds *client.SessionDataSet) { } } -func checkError(status *common.TSStatus, err error) { +func checkError(err error) { if err != nil { log.Fatal(err) } - - if status != nil { - if err = client.VerifySuccess(status); err != nil { - log.Println(err) - } - } } diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go index a225546..be83bbf 100644 --- a/test/e2e/e2e_test.go +++ b/test/e2e/e2e_test.go @@ -28,8 +28,6 @@ import ( "testing" "time" - "github.com/apache/iotdb-client-go/common" - "github.com/apache/iotdb-client-go/client" "github.com/stretchr/testify/suite" ) @@ -61,20 +59,17 @@ func (s *e2eTestSuite) TearDownSuite() { } func (s *e2eTestSuite) SetupTest() { - r, err := s.session.SetStorageGroup("root.tsg1") - s.checkError(r, err) + err := s.session.SetStorageGroup("root.tsg1") + s.checkError(err) } func (s *e2eTestSuite) TearDownTest() { - r, err := s.session.DeleteStorageGroup("root.tsg1") - s.checkError(r, err) + err := s.session.DeleteStorageGroup("root.tsg1") + s.checkError(err) } -func (s *e2eTestSuite) checkError(status *common.TSStatus, err error) { +func (s *e2eTestSuite) checkError(err error) { s.Require().NoError(err) - if status != nil { - s.Require().NoError(client.VerifySuccess(status)) - } } func (s *e2eTestSuite) Test_NonQuery() { @@ -174,7 +169,7 @@ func (s *e2eTestSuite) Test_InsertRecordsWithWrongType() { values = [][]interface{}{{100.0, true}, {"aaa"}} timestamp = []int64{1, 2} ) - _, err := s.session.InsertRecords(deviceId, measurements, dataTypes, values, timestamp) + err := s.session.InsertRecords(deviceId, measurements, dataTypes, values, timestamp) assert := s.Require() assert.NotNil(err) assert.Equal("measurement s1 values[0] 100(float64) must be bool", err.Error()) @@ -255,8 +250,8 @@ func (s *e2eTestSuite) Test_InsertAlignedTablet() { var timeseries = []string{"root.ln.device1.**"} s.session.DeleteTimeseries(timeseries) if tablet, err := createTablet(12); err == nil { - status, err := s.session.InsertAlignedTablet(tablet, false) - s.checkError(status, err) + err := s.session.InsertAlignedTablet(tablet, false) + s.checkError(err) tablet.Reset() } else { log.Fatal(err) @@ -277,8 +272,8 @@ func (s *e2eTestSuite) Test_InsertAlignedTabletWithNilValue() { var timeseries = []string{"root.ln.device1.**"} s.session.DeleteTimeseries(timeseries) if tablet, err := createTabletWithNil(12); err == nil { - status, err := s.session.InsertAlignedTablet(tablet, false) - s.checkError(status, err) + err := s.session.InsertAlignedTablet(tablet, false) + s.checkError(err) tablet.Reset() } else { log.Fatal(err) @@ -499,8 +494,8 @@ func (s *e2eTestSuite) Test_QueryAllDataType() { tablet.SetValueAt("string", 9, 0) tablet.RowSize = 1 - r, err := s.session.InsertAlignedTablet(tablet, true) - s.checkError(r, err) + err = s.session.InsertAlignedTablet(tablet, true) + s.checkError(err) sessionDataSet, err := s.session.ExecuteQueryStatement("select s0, s1, s2, s3, s4, s5, s6, s7, s8, s9 from root.tsg1.d1 limit 1", nil) for { From 5268914a9eb6e130d0bbba8dd609ca8aa958eaba Mon Sep 17 00:00:00 2001 From: Zane <44780287+betterlmy@users.noreply.github.com> Date: Thu, 12 Mar 2026 19:00:51 +0800 Subject: [PATCH 3/4] =?UTF-8?q?fix:=20fix=20EOF=20error=20when=20decoding?= =?UTF-8?q?=20columns=20with=20empty=20string=20or=20zero=20po=E2=80=A6=20?= =?UTF-8?q?(#155)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: fix EOF error when decoding columns with empty string or zero positionCount * Update column_decoder_test.go 增加license header * fix: return error when resp is nil after reconnect * fix: GetCurrentRowTime returns time.Time to avoid precision ambiguity --- client/column_decoder.go | 49 ++++++++- client/column_decoder_test.go | 200 ++++++++++++++++++++++++++++++++++ client/session.go | 39 +++++-- client/sessiondataset.go | 4 + 4 files changed, 275 insertions(+), 17 deletions(-) create mode 100644 client/column_decoder_test.go diff --git a/client/column_decoder.go b/client/column_decoder.go index 3367911..d6cea20 100644 --- a/client/column_decoder.go +++ b/client/column_decoder.go @@ -100,6 +100,18 @@ func (decoder *Int32ArrayColumnDecoder) ReadColumn(reader *bytes.Reader, dataTyp // +---------------+-----------------+-------------+ // | byte | list[byte] | list[int32] | // +---------------+-----------------+-------------+ + + if positionCount == 0 { + switch dataType { + case INT32, DATE: + return NewIntColumn(0, 0, nil, []int32{}) + case FLOAT: + return NewFloatColumn(0, 0, nil, []float32{}) + default: + return nil, fmt.Errorf("invalid data type: %v", dataType) + } + } + nullIndicators, err := deserializeNullIndicators(reader, positionCount) if err != nil { return nil, err @@ -166,6 +178,18 @@ func (decoder *Int64ArrayColumnDecoder) ReadColumn(reader *bytes.Reader, dataTyp // +---------------+-----------------+-------------+ // | byte | list[byte] | list[int64] | // +---------------+-----------------+-------------+ + + if positionCount == 0 { + switch dataType { + case INT64, TIMESTAMP: + return NewLongColumn(0, 0, nil, []int64{}) + case DOUBLE: + return NewDoubleColumn(0, 0, nil, []float64{}) + default: + return nil, fmt.Errorf("invalid data type: %v", dataType) + } + } + nullIndicators, err := deserializeNullIndicators(reader, positionCount) if err != nil { return nil, err @@ -212,6 +236,11 @@ func (decoder *ByteArrayColumnDecoder) ReadColumn(reader *bytes.Reader, dataType if dataType != BOOLEAN { return nil, fmt.Errorf("invalid data type: %v", dataType) } + + if positionCount == 0 { + return NewBooleanColumn(0, 0, nil, []bool{}) + } + nullIndicators, err := deserializeNullIndicators(reader, positionCount) if err != nil { return nil, err @@ -245,6 +274,11 @@ func (decoder *BinaryArrayColumnDecoder) ReadColumn(reader *bytes.Reader, dataTy if TEXT != dataType { return nil, fmt.Errorf("invalid data type: %v", dataType) } + + if positionCount == 0 { + return NewBinaryColumn(0, 0, nil, []*Binary{}) + } + nullIndicators, err := deserializeNullIndicators(reader, positionCount) if err != nil { return nil, err @@ -259,12 +293,17 @@ func (decoder *BinaryArrayColumnDecoder) ReadColumn(reader *bytes.Reader, dataTy if err != nil { return nil, err } - value := make([]byte, length) - _, err = reader.Read(value) - if err != nil { - return nil, err + + if length == 0 { + values[i] = NewBinary([]byte{}) + } else { + value := make([]byte, length) + _, err = reader.Read(value) + if err != nil { + return nil, err + } + values[i] = NewBinary(value) } - values[i] = NewBinary(value) } return NewBinaryColumn(0, positionCount, nullIndicators, values) } diff --git a/client/column_decoder_test.go b/client/column_decoder_test.go new file mode 100644 index 0000000..dc5d433 --- /dev/null +++ b/client/column_decoder_test.go @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package client + +import ( + "bytes" + "encoding/binary" + "testing" +) + +func buildNullIndicatorBytes(nulls []bool) []byte { + var buf bytes.Buffer + hasNull := false + for _, n := range nulls { + if n { + hasNull = true + break + } + } + if !hasNull { + buf.WriteByte(0) + return buf.Bytes() + } + buf.WriteByte(1) + packed := make([]byte, (len(nulls)+7)/8) + for i, n := range nulls { + if n { + packed[i/8] |= 0b10000000 >> (uint(i) % 8) + } + } + buf.Write(packed) + return buf.Bytes() +} + +func TestBinaryArrayColumnDecoder_EmptyString(t *testing.T) { + var buf bytes.Buffer + buf.Write(buildNullIndicatorBytes([]bool{false})) + _ = binary.Write(&buf, binary.BigEndian, int32(0)) + + col, err := (&BinaryArrayColumnDecoder{}).ReadColumn(bytes.NewReader(buf.Bytes()), TEXT, 1) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.GetPositionCount() != 1 { + t.Fatalf("expected positionCount=1, got %d", col.GetPositionCount()) + } + if col.IsNull(0) { + t.Fatal("row 0 should not be null") + } + val, err := col.GetBinary(0) + if err != nil { + t.Fatalf("GetBinary(0) failed: %v", err) + } + if len(val.values) != 0 { + t.Fatalf("expected empty string, got %q", string(val.values)) + } +} + +func TestBinaryArrayColumnDecoder_NullThenEmptyString(t *testing.T) { + var buf bytes.Buffer + buf.Write(buildNullIndicatorBytes([]bool{true, false})) + _ = binary.Write(&buf, binary.BigEndian, int32(0)) + + col, err := (&BinaryArrayColumnDecoder{}).ReadColumn(bytes.NewReader(buf.Bytes()), TEXT, 2) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if !col.IsNull(0) { + t.Error("row 0 should be null") + } + if col.IsNull(1) { + t.Error("row 1 should not be null") + } + val, err := col.GetBinary(1) + if err != nil { + t.Fatalf("GetBinary(1) failed: %v", err) + } + if len(val.values) != 0 { + t.Fatalf("expected empty string, got %q", string(val.values)) + } +} + +func TestBinaryArrayColumnDecoder_WithNull(t *testing.T) { + var buf bytes.Buffer + buf.Write(buildNullIndicatorBytes([]bool{false, true, false})) + writeText := func(s string) { + _ = binary.Write(&buf, binary.BigEndian, int32(len(s))) + buf.WriteString(s) + } + writeText("hello") + writeText("world") + + col, err := (&BinaryArrayColumnDecoder{}).ReadColumn(bytes.NewReader(buf.Bytes()), TEXT, 3) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.IsNull(0) { + t.Error("row 0 should not be null") + } + if v, _ := col.GetBinary(0); string(v.values) != "hello" { + t.Errorf("row 0: expected \"hello\", got %q", string(v.values)) + } + if !col.IsNull(1) { + t.Error("row 1 should be null") + } + if col.IsNull(2) { + t.Error("row 2 should not be null") + } + if v, _ := col.GetBinary(2); string(v.values) != "world" { + t.Errorf("row 2: expected \"world\", got %q", string(v.values)) + } +} + +func TestInt64ArrayColumnDecoder_WithNull(t *testing.T) { + var buf bytes.Buffer + buf.Write(buildNullIndicatorBytes([]bool{false, true, false})) + _ = binary.Write(&buf, binary.BigEndian, int64(100)) + _ = binary.Write(&buf, binary.BigEndian, int64(200)) + + col, err := (&Int64ArrayColumnDecoder{}).ReadColumn(bytes.NewReader(buf.Bytes()), INT64, 3) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.IsNull(0) { + t.Error("row 0 should not be null") + } + if v, _ := col.GetLong(0); v != 100 { + t.Errorf("row 0: expected 100, got %d", v) + } + if !col.IsNull(1) { + t.Error("row 1 should be null") + } + if col.IsNull(2) { + t.Error("row 2 should not be null") + } + if v, _ := col.GetLong(2); v != 200 { + t.Errorf("row 2: expected 200, got %d", v) + } +} + +func TestColumnDecoder_ZeroPositionCount(t *testing.T) { + empty := func() *bytes.Reader { return bytes.NewReader([]byte{}) } + + t.Run("Int32ArrayColumnDecoder", func(t *testing.T) { + col, err := (&Int32ArrayColumnDecoder{}).ReadColumn(empty(), INT32, 0) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.GetPositionCount() != 0 { + t.Errorf("expected positionCount=0, got %d", col.GetPositionCount()) + } + }) + + t.Run("Int64ArrayColumnDecoder", func(t *testing.T) { + col, err := (&Int64ArrayColumnDecoder{}).ReadColumn(empty(), INT64, 0) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.GetPositionCount() != 0 { + t.Errorf("expected positionCount=0, got %d", col.GetPositionCount()) + } + }) + + t.Run("ByteArrayColumnDecoder", func(t *testing.T) { + col, err := (&ByteArrayColumnDecoder{}).ReadColumn(empty(), BOOLEAN, 0) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.GetPositionCount() != 0 { + t.Errorf("expected positionCount=0, got %d", col.GetPositionCount()) + } + }) + + t.Run("BinaryArrayColumnDecoder", func(t *testing.T) { + col, err := (&BinaryArrayColumnDecoder{}).ReadColumn(empty(), TEXT, 0) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.GetPositionCount() != 0 { + t.Errorf("expected positionCount=0, got %d", col.GetPositionCount()) + } + }) +} diff --git a/client/session.go b/client/session.go index 98a9fee..fbf1011 100644 --- a/client/session.go +++ b/client/session.go @@ -519,10 +519,15 @@ func (s *Session) ExecuteQueryStatement(sql string, timeoutMs *int64) (*SessionD request.SessionId = s.sessionId request.StatementId = s.requestStatementId resp, err = s.client.ExecuteQueryStatementV2(context.Background(), &request) - if statusErr := VerifySuccess(resp.Status); statusErr == nil { - return NewSessionDataSet(sql, resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.requestStatementId, s.client, s.sessionId, resp.QueryResult_, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, timeoutMs, *resp.MoreData, s.config.FetchSize) - } else { - return nil, statusErr + if err == nil { + if resp == nil { + return nil, fmt.Errorf("received nil response after reconnect") + } + if statusErr := VerifySuccess(resp.Status); statusErr == nil { + return NewSessionDataSet(sql, resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.requestStatementId, s.client, s.sessionId, resp.QueryResult_, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, timeoutMs, *resp.MoreData, s.config.FetchSize) + } else { + return nil, statusErr + } } } return nil, err @@ -545,10 +550,15 @@ func (s *Session) ExecuteAggregationQuery(paths []string, aggregations []common. if s.reconnect() { request.SessionId = s.sessionId resp, err = s.client.ExecuteAggregationQueryV2(context.Background(), &request) - if statusErr := VerifySuccess(resp.Status); statusErr == nil { - return NewSessionDataSet("", resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.requestStatementId, s.client, s.sessionId, resp.QueryResult_, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, timeoutMs, *resp.MoreData, s.config.FetchSize) - } else { - return nil, statusErr + if err == nil { + if resp == nil { + return nil, fmt.Errorf("received nil response after reconnect") + } + if statusErr := VerifySuccess(resp.Status); statusErr == nil { + return NewSessionDataSet("", resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.requestStatementId, s.client, s.sessionId, resp.QueryResult_, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, timeoutMs, *resp.MoreData, s.config.FetchSize) + } else { + return nil, statusErr + } } } return nil, err @@ -572,10 +582,15 @@ func (s *Session) ExecuteAggregationQueryWithLegalNodes(paths []string, aggregat if s.reconnect() { request.SessionId = s.sessionId resp, err = s.client.ExecuteAggregationQueryV2(context.Background(), &request) - if statusErr := VerifySuccess(resp.Status); statusErr == nil { - return NewSessionDataSet("", resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.requestStatementId, s.client, s.sessionId, resp.QueryResult_, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, timeoutMs, *resp.MoreData, s.config.FetchSize) - } else { - return nil, statusErr + if err == nil { + if resp == nil { + return nil, fmt.Errorf("received nil response after reconnect") + } + if statusErr := VerifySuccess(resp.Status); statusErr == nil { + return NewSessionDataSet("", resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.requestStatementId, s.client, s.sessionId, resp.QueryResult_, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, timeoutMs, *resp.MoreData, s.config.FetchSize) + } else { + return nil, statusErr + } } } return nil, err diff --git a/client/sessiondataset.go b/client/sessiondataset.go index d177a44..2cc2282 100644 --- a/client/sessiondataset.go +++ b/client/sessiondataset.go @@ -124,3 +124,7 @@ func (s *SessionDataSet) GetColumnNames() []string { func (s *SessionDataSet) GetColumnTypes() []string { return s.ioTDBRpcDataSet.columnTypeList } + +func (s *SessionDataSet) GetCurrentRowTime() time.Time { + return convertToTimestamp(s.ioTDBRpcDataSet.GetCurrentRowTime(), s.ioTDBRpcDataSet.timeFactor) +} From 310e1e5f2a8f3715a4f182bab3a4abc54245af23 Mon Sep 17 00:00:00 2001 From: shuwenwei Date: Thu, 26 Mar 2026 14:34:33 +0800 Subject: [PATCH 4/4] fix --- client/sessiondataset.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/client/sessiondataset.go b/client/sessiondataset.go index 2cc2282..d177a44 100644 --- a/client/sessiondataset.go +++ b/client/sessiondataset.go @@ -124,7 +124,3 @@ func (s *SessionDataSet) GetColumnNames() []string { func (s *SessionDataSet) GetColumnTypes() []string { return s.ioTDBRpcDataSet.columnTypeList } - -func (s *SessionDataSet) GetCurrentRowTime() time.Time { - return convertToTimestamp(s.ioTDBRpcDataSet.GetCurrentRowTime(), s.ioTDBRpcDataSet.timeFactor) -}