diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7c2888e..95ac189 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,9 +2,9 @@ name: CI on: push: - branches: [ main, master, develop ] + branches: [ main, dev ] pull_request: - branches: [ main, master, develop ] + branches: [ main, dev ] jobs: test: diff --git a/rsql/ast.go b/rsql/ast.go index fb5c0f9..3fd82cf 100644 --- a/rsql/ast.go +++ b/rsql/ast.go @@ -8,6 +8,7 @@ import ( "github.com/rulego/streamsql/functions" "github.com/rulego/streamsql/types" + "github.com/rulego/streamsql/utils/cast" "github.com/rulego/streamsql/window" "github.com/rulego/streamsql/aggregator" @@ -58,9 +59,16 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) { windowType = window.TypeSession } - params, err := parseWindowParamsWithType(s.Window.Params, windowType) - if err != nil { - return nil, "", fmt.Errorf("failed to parse window parameters: %w", err) + // Parse window parameters - now returns array directly + params := s.Window.Params + + // Validate and convert parameters based on window type + if len(params) > 0 { + var err error + params, err = validateWindowParams(params, windowType) + if err != nil { + return nil, "", fmt.Errorf("failed to validate window parameters: %w", err) + } } // Check if window processing is needed @@ -80,16 +88,7 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) { if !needWindow && hasAggregation { needWindow = true windowType = window.TypeTumbling - params = map[string]interface{}{ - "size": 10 * time.Second, // Default 10-second window - } - } - - // Handle special configuration for SessionWindow - var groupByKey string - if windowType == window.TypeSession && len(s.GroupBy) > 0 { - // For session window, use the first GROUP BY field as session key - groupByKey = s.GroupBy[0] + params = []interface{}{10 * time.Second} // Default 10-second window } // If no aggregation functions, collect simple fields @@ -105,10 +104,10 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) { simpleFields = append(simpleFields, fieldName+":"+field.Alias) } else { // For fields without alias, check if it's a string literal - _, n, _, _, err := ParseAggregateTypeWithExpression(fieldName) - if err != nil { - return nil, "", err - } + _, n, _, _, err := ParseAggregateTypeWithExpression(fieldName) + if err != nil { + return nil, "", err + } if n != "" { // If string literal, use parsed field name (remove quotes) simpleFields = append(simpleFields, n) @@ -137,11 +136,11 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) { // Build Stream configuration config := types.Config{ WindowConfig: types.WindowConfig{ - Type: windowType, - Params: params, - TsProp: s.Window.TsProp, - TimeUnit: s.Window.TimeUnit, - GroupByKey: groupByKey, + Type: windowType, + Params: params, + TsProp: s.Window.TsProp, + TimeUnit: s.Window.TimeUnit, + GroupByKeys: extractGroupFields(s), }, GroupFields: extractGroupFields(s), SelectFields: aggs, @@ -245,9 +244,9 @@ func buildSelectFields(fields []Field) (aggMap map[string]aggregator.AggregateTy for _, f := range fields { if alias := f.Alias; alias != "" { t, n, _, _, parseErr := ParseAggregateTypeWithExpression(f.Expression) - if parseErr != nil { - return nil, nil, parseErr - } + if parseErr != nil { + return nil, nil, parseErr + } if t != "" { // Use alias as key for aggregator, not field name selectFields[alias] = t @@ -287,11 +286,11 @@ func detectNestedAggregationRecursive(expr string, inAggregation bool) error { // 使用正则表达式匹配函数调用模式 pattern := regexp.MustCompile(`(?i)([a-z_]+)\s*\(`) matches := pattern.FindAllStringSubmatchIndex(expr, -1) - + for _, match := range matches { funcStart := match[0] funcName := strings.ToLower(expr[match[2]:match[3]]) - + // 检查函数是否为聚合函数 if fn, exists := functions.Get(funcName); exists { switch fn.GetType() { @@ -300,14 +299,14 @@ func detectNestedAggregationRecursive(expr string, inAggregation bool) error { if inAggregation { return fmt.Errorf("aggregate function calls cannot be nested") } - + // 找到该函数的参数部分 funcEnd := findMatchingParenInternal(expr, funcStart+len(funcName)) if funcEnd > funcStart { // 提取函数参数 paramStart := funcStart + len(funcName) + 1 params := expr[paramStart:funcEnd] - + // 在聚合函数参数内部递归检查 if err := detectNestedAggregationRecursive(params, true); err != nil { return err @@ -316,7 +315,7 @@ func detectNestedAggregationRecursive(expr string, inAggregation bool) error { } } } - + return nil } @@ -697,43 +696,103 @@ func extractSimpleField(fieldExpr string) string { return fieldExpr } -func parseWindowParams(params []interface{}) (map[string]interface{}, error) { - return parseWindowParamsWithType(params, "") -} +// validateWindowParams validates and converts window parameters based on window type +// Returns validated parameters array with proper types +func validateWindowParams(params []interface{}, windowType string) ([]interface{}, error) { + if len(params) == 0 { + return params, nil + } -func parseWindowParamsWithType(params []interface{}, windowType string) (map[string]interface{}, error) { - result := make(map[string]interface{}) - var key string + validated := make([]interface{}, 0, len(params)) + + if windowType == window.TypeCounting { + // CountingWindow expects integer count as first parameter + if len(params) == 0 { + return nil, fmt.Errorf("counting window requires at least one parameter") + } + + // Convert first parameter to int using cast utility + count, err := cast.ToIntE(params[0]) + if err != nil { + return nil, fmt.Errorf("invalid count parameter: %w", err) + } + + if count <= 0 { + return nil, fmt.Errorf("counting window count must be positive, got: %d", count) + } + + validated = append(validated, count) + + // Add any additional parameters + if len(params) > 1 { + validated = append(validated, params[1:]...) + } + + return validated, nil + } + + // Helper function to convert a value to time.Duration + // For numeric types, treats them as seconds + // For strings, uses time.ParseDuration + convertToDuration := func(val interface{}) (time.Duration, error) { + switch v := val.(type) { + case time.Duration: + return v, nil + case string: + // Use ToDurationE which handles string parsing + return cast.ToDurationE(v) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + // Treat numeric integers as seconds + return time.Duration(cast.ToInt(v)) * time.Second, nil + case float32, float64: + // Treat numeric floats as seconds + return time.Duration(int(cast.ToFloat64(v))) * time.Second, nil + default: + // Try ToDurationE as fallback + return cast.ToDurationE(v) + } + } + + if windowType == window.TypeSession { + // SessionWindow expects timeout duration as first parameter + if len(params) == 0 { + return nil, fmt.Errorf("session window requires at least one parameter") + } + + timeout, err := convertToDuration(params[0]) + if err != nil { + return nil, fmt.Errorf("invalid timeout duration: %w", err) + } + + if timeout <= 0 { + return nil, fmt.Errorf("session window timeout must be positive, got: %v", timeout) + } + + validated = append(validated, timeout) + + // Add any additional parameters + if len(params) > 1 { + validated = append(validated, params[1:]...) + } + + return validated, nil + } + + // For TumblingWindow and SlidingWindow, convert parameters to time.Duration for index, v := range params { - if windowType == window.TypeSession { - // First parameter for SessionWindow is timeout - if index == 0 { - key = "timeout" - } else { - key = fmt.Sprintf("param%d", index) - } - } else { - // Parameters for other window types - if index == 0 { - key = "size" - } else if index == 1 { - key = "slide" - } else { - key = "offset" - } + dur, err := convertToDuration(v) + if err != nil { + return nil, fmt.Errorf("invalid duration parameter at index %d: %w", index, err) } - if s, ok := v.(string); ok { - dur, err := time.ParseDuration(s) - if err != nil { - return nil, fmt.Errorf("invalid %s duration: %w", s, err) - } - result[key] = dur - } else { - return nil, fmt.Errorf("%s parameter must be string format (like '5s')", s) + + if dur <= 0 { + return nil, fmt.Errorf("duration parameter at index %d must be positive, got: %v", index, dur) } + + validated = append(validated, dur) } - return result, nil + return validated, nil } func parseAggregateExpression(expr string) string { @@ -958,7 +1017,7 @@ func parseComplexAggExpressionInternal(expr string) ([]types.AggregationFieldInf if err := detectNestedAggregation(expr); err != nil { return nil, "", err } - + // 使用改进的递归解析方法 aggFields, exprTemplate := parseNestedFunctionsInternal(expr, make([]types.AggregationFieldInfo, 0)) return aggFields, exprTemplate, nil diff --git a/rsql/ast_test.go b/rsql/ast_test.go index 4cdc1cf..3090212 100644 --- a/rsql/ast_test.go +++ b/rsql/ast_test.go @@ -251,8 +251,8 @@ func TestSelectStatementEdgeCases(t *testing.T) { if config2.WindowConfig.Type != window.TypeSession { t.Errorf("Expected session window, got %v", config2.WindowConfig.Type) } - if config2.WindowConfig.GroupByKey != "user_id" { - t.Errorf("Expected GroupByKey to be 'user_id', got %s", config2.WindowConfig.GroupByKey) + if len(config2.WindowConfig.GroupByKeys) == 0 || config2.WindowConfig.GroupByKeys[0] != "user_id" { + t.Errorf("Expected GroupByKeys to contain 'user_id', got %v", config2.WindowConfig.GroupByKeys) } } diff --git a/rsql/coverage_test.go b/rsql/coverage_test.go index d0efecf..8c75911 100644 --- a/rsql/coverage_test.go +++ b/rsql/coverage_test.go @@ -6,6 +6,7 @@ import ( "github.com/rulego/streamsql/aggregator" "github.com/rulego/streamsql/types" + "github.com/rulego/streamsql/window" ) // TestParseSmartParameters 测试智能参数解析函数 @@ -202,6 +203,12 @@ func TestParseWindowParams(t *testing.T) { windowType: "SLIDINGWINDOW", expectError: false, }, + { + name: "计数窗口参数", + params: []interface{}{100}, + windowType: "COUNTINGWINDOW", + expectError: false, + }, { name: "无效持续时间", params: []interface{}{"invalid"}, @@ -212,7 +219,7 @@ func TestParseWindowParams(t *testing.T) { name: "非字符串参数", params: []interface{}{123}, windowType: "TUMBLINGWINDOW", - expectError: true, + expectError: false, // 整数参数会被视为秒数,这是有效的 }, { name: "空参数", @@ -224,15 +231,24 @@ func TestParseWindowParams(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var result map[string]interface{} + var result []interface{} var err error - if tt.windowType == "SESSIONWINDOW" { - result, err = parseWindowParamsWithType(tt.params, "SESSIONWINDOW") - } else { - result, err = parseWindowParams(tt.params) + // Convert window type to internal format + windowType := "" + switch tt.windowType { + case "SESSIONWINDOW": + windowType = window.TypeSession + case "TUMBLINGWINDOW": + windowType = window.TypeTumbling + case "SLIDINGWINDOW": + windowType = window.TypeSliding + case "COUNTINGWINDOW": + windowType = window.TypeCounting } + result, err = validateWindowParams(tt.params, windowType) + if tt.expectError { if err == nil { t.Errorf("Expected error but got none") diff --git a/rsql/parser.go b/rsql/parser.go index 20b9334..26fda55 100644 --- a/rsql/parser.go +++ b/rsql/parser.go @@ -506,44 +506,46 @@ func (p *Parser) parseWhere(stmt *SelectStatement) error { } func (p *Parser) parseWindowFunction(stmt *SelectStatement, winType string) error { - p.lexer.NextToken() // 跳过( + nextTok := p.lexer.NextToken() // 读取下一个 token,应该是 '(' + if nextTok.Type != TokenLParen { + return fmt.Errorf("expected '(' after window function %s, got %s (type: %v)", winType, nextTok.Value, nextTok.Type) + } + var params []interface{} - - // 设置最大次数限制,防止无限循环 maxIterations := 100 iterations := 0 - for p.lexer.peekChar() != ')' { + // Parse parameters until we find the closing parenthesis + for { iterations++ - // 安全检查:防止无限循环 if iterations > maxIterations { - return errors.New("window function parameter parsing exceeded maximum iterations, possible syntax error") + return fmt.Errorf("window function parameter parsing exceeded maximum iterations") } + // Read the next token first valTok := p.lexer.NextToken() + + // If we hit the closing parenthesis or EOF, break if valTok.Type == TokenRParen || valTok.Type == TokenEOF { break } + + // Skip commas if valTok.Type == TokenComma { continue } - //valTok := p.lexer.NextToken() + // Handle quoted values if strings.HasPrefix(valTok.Value, "'") && strings.HasSuffix(valTok.Value, "'") { valTok.Value = strings.Trim(valTok.Value, "'") } + + // Add the parameter value params = append(params, convertValue(valTok.Value)) } - if &stmt.Window != nil { stmt.Window.Params = params stmt.Window.Type = winType - } else { - stmt.Window = WindowDefinition{ - Type: winType, - Params: params, - } - } return nil } @@ -593,7 +595,9 @@ func (p *Parser) parseGroupBy(stmt *SelectStatement) error { hasWindowFunction := false if tok.Type == TokenTumbling || tok.Type == TokenSliding || tok.Type == TokenCounting || tok.Type == TokenSession { hasWindowFunction = true - _ = p.parseWindowFunction(stmt, tok.Value) + if err := p.parseWindowFunction(stmt, tok.Value); err != nil { + return err + } } hasGroupBy := false @@ -633,7 +637,15 @@ func (p *Parser) parseGroupBy(stmt *SelectStatement) error { continue } if tok.Type == TokenTumbling || tok.Type == TokenSliding || tok.Type == TokenCounting || tok.Type == TokenSession { - _ = p.parseWindowFunction(stmt, tok.Value) + if err := p.parseWindowFunction(stmt, tok.Value); err != nil { + return err + } + // After parsing window function, skip adding it to GroupBy and continue + continue + } + + // Skip right parenthesis tokens (they should be consumed by parseWindowFunction) + if tok.Type == TokenRParen { continue } diff --git a/stream/coverage_test.go b/stream/coverage_test.go index 97348e2..9e739e6 100644 --- a/stream/coverage_test.go +++ b/stream/coverage_test.go @@ -37,7 +37,7 @@ func TestDataProcessor_ApplyDistinct(t *testing.T) { }, WindowConfig: types.WindowConfig{ Type: "tumbling", - Params: map[string]interface{}{"size": 1 * time.Second}, + Params: []interface{}{1 * time.Second}, }, } stream, err := NewStream(config) @@ -79,7 +79,7 @@ func TestDataProcessor_ApplyHavingFilter(t *testing.T) { Having: "temperature > 25", WindowConfig: types.WindowConfig{ Type: "tumbling", - Params: map[string]interface{}{"size": 1 * time.Second}, + Params: []interface{}{1 * time.Second}, }, } stream, err := NewStream(config) @@ -120,7 +120,7 @@ func TestDataProcessor_ApplyHavingWithCaseExpression(t *testing.T) { Having: "CASE WHEN temperature > 30 THEN 1 WHEN status = 'active' THEN 1 ELSE 0 END", WindowConfig: types.WindowConfig{ Type: "tumbling", - Params: map[string]interface{}{"size": 1 * time.Second}, + Params: []interface{}{1 * time.Second}, }, } stream, err := NewStream(config) @@ -161,7 +161,7 @@ func TestDataProcessor_ApplyHavingWithCondition(t *testing.T) { Having: "temperature > 25", WindowConfig: types.WindowConfig{ Type: "tumbling", - Params: map[string]interface{}{"size": 1 * time.Second}, + Params: []interface{}{1 * time.Second}, }, } stream, err := NewStream(config) @@ -541,7 +541,7 @@ func TestStream_ProcessSync(t *testing.T) { }, WindowConfig: types.WindowConfig{ Type: "tumbling", - Params: map[string]interface{}{"size": 1 * time.Second}, + Params: []interface{}{1 * time.Second}, }, } aggStream, err := NewStream(aggConfig) diff --git a/stream/processor_data_test.go b/stream/processor_data_test.go index 8225c24..35e070c 100644 --- a/stream/processor_data_test.go +++ b/stream/processor_data_test.go @@ -56,7 +56,7 @@ func TestDataProcessor_InitializeAggregator(t *testing.T) { }, WindowConfig: types.WindowConfig{ Type: "tumbling", - Params: map[string]interface{}{"size": 1 * time.Second}, + Params: []interface{}{1 * time.Second}, }, } stream, err := NewStream(config) @@ -90,7 +90,7 @@ func TestDataProcessor_RegisterExpressionCalculator(t *testing.T) { }, WindowConfig: types.WindowConfig{ Type: "tumbling", - Params: map[string]interface{}{"size": 1 * time.Second}, + Params: []interface{}{1 * time.Second}, }, } stream, err := NewStream(config) diff --git a/stream/stream_factory.go b/stream/stream_factory.go index 4e12e3c..9558c75 100644 --- a/stream/stream_factory.go +++ b/stream/stream_factory.go @@ -99,11 +99,8 @@ func (sf *StreamFactory) createStreamWithUnifiedConfig(config types.Config) (*St func (sf *StreamFactory) createWindow(config types.Config) (window.Window, error) { // Pass unified performance configuration to window windowConfig := config.WindowConfig - if windowConfig.Params == nil { - windowConfig.Params = make(map[string]interface{}) - } - // Pass complete performance configuration to window - windowConfig.Params[PerformanceConfigKey] = config.PerformanceConfig + // Set performance configuration directly + windowConfig.PerformanceConfig = config.PerformanceConfig return window.CreateWindow(windowConfig) } diff --git a/stream/stream_test.go b/stream/stream_test.go index d48aa24..f879f13 100644 --- a/stream/stream_test.go +++ b/stream/stream_test.go @@ -53,7 +53,7 @@ func TestStreamBasicOperations(t *testing.T) { WindowConfig: types.WindowConfig{ Type: "tumbling", TimeUnit: 1000, - Params: map[string]interface{}{"size": 1 * time.Second}, + Params: []interface{}{1 * time.Second}, }, }, testFunc: "withWindow", @@ -146,7 +146,7 @@ func TestStreamBasicFunctionality(t *testing.T) { config: types.Config{ WindowConfig: types.WindowConfig{ Type: "tumbling", - Params: map[string]interface{}{"size": 500 * time.Millisecond}, + Params: []interface{}{500 * time.Millisecond}, }, GroupFields: []string{"device"}, SelectFields: map[string]aggregator.AggregateType{ @@ -170,7 +170,7 @@ func TestStreamBasicFunctionality(t *testing.T) { config: types.Config{ WindowConfig: types.WindowConfig{ Type: "tumbling", - Params: map[string]interface{}{"size": 500 * time.Millisecond}, + Params: []interface{}{500 * time.Millisecond}, }, GroupFields: []string{"device"}, SelectFields: map[string]aggregator.AggregateType{ @@ -255,7 +255,7 @@ func TestStreamWithoutFilter(t *testing.T) { config := types.Config{ WindowConfig: types.WindowConfig{ Type: "sliding", - Params: map[string]interface{}{"size": 2 * time.Second, "slide": 1 * time.Second}, + Params: []interface{}{2 * time.Second, 1 * time.Second}, }, GroupFields: []string{"device"}, SelectFields: map[string]aggregator.AggregateType{ @@ -510,10 +510,8 @@ func TestStreamAggregationQuery(t *testing.T) { }, NeedWindow: true, WindowConfig: types.WindowConfig{ - Type: "tumbling", - Params: map[string]interface{}{ - "size": "5s", - }, + Type: "tumbling", + Params: []interface{}{5 * time.Second}, }, } stream, err := NewStream(config) @@ -730,7 +728,7 @@ func TestStreamWithWindowAndAggregation(t *testing.T) { SimpleFields: []string{"name", "age"}, WindowConfig: types.WindowConfig{ Type: "tumbling", - Params: map[string]interface{}{"size": 100 * time.Millisecond}, + Params: []interface{}{100 * time.Millisecond}, }, SelectFields: map[string]aggregator.AggregateType{ "avg_age": aggregator.Avg, @@ -1214,10 +1212,8 @@ func TestStreamWindowEdgeCasesEnhanced(t *testing.T) { config: func() types.Config { c := types.NewConfig() c.WindowConfig = types.WindowConfig{ - Type: "tumbling", - Params: map[string]interface{}{ - "size": 1 * time.Nanosecond, // 极小时间窗口 - }, + Type: "tumbling", + Params: []interface{}{1 * time.Nanosecond}, // 极小时间窗口 TimeUnit: 1 * time.Nanosecond, } c.NeedWindow = true @@ -1230,10 +1226,8 @@ func TestStreamWindowEdgeCasesEnhanced(t *testing.T) { name: "极大时间窗口", config: types.Config{ WindowConfig: types.WindowConfig{ - Type: "tumbling", - Params: map[string]interface{}{ - "size": 8760 * time.Hour, // 1年 - }, + Type: "tumbling", + Params: []interface{}{8760 * time.Hour}, // 1年 TimeUnit: 8760 * time.Hour, }, NeedWindow: true, @@ -1245,11 +1239,8 @@ func TestStreamWindowEdgeCasesEnhanced(t *testing.T) { name: "滑动窗口零滑动", config: types.Config{ WindowConfig: types.WindowConfig{ - Type: "sliding", - Params: map[string]interface{}{ - "size": 1 * time.Second, - "slide": 1 * time.Millisecond, // 很小的滑动间隔 - }, + Type: "sliding", + Params: []interface{}{1 * time.Second, 1 * time.Millisecond}, // 很小的滑动间隔 TimeUnit: 1 * time.Second, }, NeedWindow: true, @@ -1298,10 +1289,8 @@ func TestStreamUnifiedConfigIntegration(t *testing.T) { config := types.Config{ NeedWindow: true, WindowConfig: types.WindowConfig{ - Type: "tumbling", - Params: map[string]interface{}{ - "size": "5s", - }, + Type: "tumbling", + Params: []interface{}{5 * time.Second}, }, SelectFields: map[string]aggregator.AggregateType{ "value": aggregator.Count, @@ -1338,10 +1327,8 @@ func TestStreamUnifiedConfigPerformanceImpact(t *testing.T) { config := types.Config{ NeedWindow: true, WindowConfig: types.WindowConfig{ - Type: "tumbling", - Params: map[string]interface{}{ - "size": "1s", - }, + Type: "tumbling", + Params: []interface{}{time.Second}, }, SelectFields: map[string]aggregator.AggregateType{ "value": aggregator.Sum, @@ -1410,10 +1397,8 @@ func TestStreamUnifiedConfigErrorHandling(t *testing.T) { config: types.Config{ NeedWindow: true, WindowConfig: types.WindowConfig{ - Type: "invalid_window_type", - Params: map[string]interface{}{ - "size": "5s", - }, + Type: "invalid_window_type", + Params: []interface{}{5 * time.Second}, }, SelectFields: map[string]aggregator.AggregateType{ "value": aggregator.Count, @@ -1429,7 +1414,7 @@ func TestStreamUnifiedConfigErrorHandling(t *testing.T) { NeedWindow: true, WindowConfig: types.WindowConfig{ Type: "tumbling", - Params: map[string]interface{}{}, + Params: []interface{}{}, }, SelectFields: map[string]aggregator.AggregateType{ "value": aggregator.Count, @@ -1444,10 +1429,8 @@ func TestStreamUnifiedConfigErrorHandling(t *testing.T) { config: types.Config{ NeedWindow: true, WindowConfig: types.WindowConfig{ - Type: "tumbling", - Params: map[string]interface{}{ - "size": "5s", - }, + Type: "tumbling", + Params: []interface{}{5 * time.Second}, }, SelectFields: map[string]aggregator.AggregateType{ "value": aggregator.Count, @@ -1802,10 +1785,8 @@ func TestStreamFactory_CreateStreamWithWindow(t *testing.T) { SimpleFields: []string{"name", "age"}, NeedWindow: true, WindowConfig: types.WindowConfig{ - Type: "tumbling", - Params: map[string]interface{}{ - "size": "5s", - }, + Type: "tumbling", + Params: []interface{}{5 * time.Second}, }, } @@ -1843,10 +1824,8 @@ func TestStreamFactory_CreateWindow(t *testing.T) { factory := NewStreamFactory() config := types.Config{ WindowConfig: types.WindowConfig{ - Type: "tumbling", - Params: map[string]interface{}{ - "size": "5s", - }, + Type: "tumbling", + Params: []interface{}{5 * time.Second}, }, PerformanceConfig: types.DefaultPerformanceConfig(), } diff --git a/stream/stream_window_test.go b/stream/stream_window_test.go index 4a4e6d0..cd677df 100644 --- a/stream/stream_window_test.go +++ b/stream/stream_window_test.go @@ -16,7 +16,7 @@ func TestWindowSlotAggregation(t *testing.T) { config := types.Config{ WindowConfig: types.WindowConfig{ Type: "sliding", - Params: map[string]interface{}{"size": 2 * time.Second, "slide": 1 * time.Second}, + Params: []interface{}{2 * time.Second, 1 * time.Second}, TsProp: "ts", }, GroupFields: []string{"device"}, @@ -103,44 +103,37 @@ func TestWindowTypes(t *testing.T) { tests := []struct { name string windowType string - windowParams map[string]interface{} + windowParams []interface{} expectError bool }{ { - name: "Tumbling Window", - windowType: "tumbling", - windowParams: map[string]interface{}{ - "size": "5s", - }, - expectError: false, + name: "Tumbling Window", + windowType: "tumbling", + windowParams: []interface{}{5 * time.Second}, + expectError: false, }, { - name: "Sliding Window", - windowType: "sliding", - windowParams: map[string]interface{}{ - "size": "10s", - "slide": "5s", - }, - expectError: false, + name: "Sliding Window", + windowType: "sliding", + windowParams: []interface{}{10 * time.Second, 5 * time.Second}, + expectError: false, }, { - name: "Session Window", - windowType: "session", - windowParams: map[string]interface{}{ - "timeout": "30s", - }, - expectError: false, + name: "Session Window", + windowType: "session", + windowParams: []interface{}{30 * time.Second}, + expectError: false, }, { name: "Invalid Window Type", windowType: "invalid_window_type", - windowParams: map[string]interface{}{"size": "5s"}, + windowParams: []interface{}{5 * time.Second}, expectError: true, }, { name: "Missing Size Parameter", windowType: "tumbling", - windowParams: map[string]interface{}{}, + windowParams: []interface{}{}, expectError: true, }, } @@ -195,7 +188,7 @@ func TestAggregationTypes(t *testing.T) { config := types.Config{ WindowConfig: types.WindowConfig{ Type: "tumbling", - Params: map[string]interface{}{"size": 500 * time.Millisecond}, + Params: []interface{}{500 * time.Millisecond}, }, GroupFields: []string{"group"}, SelectFields: map[string]aggregator.AggregateType{ diff --git a/streamsql_counting_window_test.go b/streamsql_counting_window_test.go new file mode 100644 index 0000000..1efa16c --- /dev/null +++ b/streamsql_counting_window_test.go @@ -0,0 +1,147 @@ +package streamsql + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSQLCountingWindow_GroupByDevice(t *testing.T) { + ssql := New() + defer ssql.Stop() + + sql := ` + SELECT deviceId, + COUNT(*) as cnt + FROM stream + GROUP BY deviceId, CountingWindow(10) + ` + err := ssql.Execute(sql) + require.NoError(t, err) + + ch := make(chan []map[string]interface{}, 4) + ssql.AddSink(func(results []map[string]interface{}) { + ch <- results + }) + + for i := 0; i < 30; i++ { + ssql.Emit(map[string]interface{}{ + "deviceId": "sensor001", + "temperature": i, + "timestamp": time.Now(), + }) + } + + // Expect 3 batches, each with one row for deviceId=sensor001 + for batch := 0; batch < 3; batch++ { + select { + case res := <-ch: + require.Len(t, res, 1) + row := res[0] + assert.Equal(t, "sensor001", row["deviceId"]) + assert.Equal(t, float64(10), row["cnt"]) + case <-time.After(5 * time.Second): + t.Fatalf("timeout waiting for batch %d", batch+1) + } + } +} + +func TestSQLCountingWindow_GroupedCounting_MixedDevices(t *testing.T) { + ssql := New() + defer ssql.Stop() + + sql := ` + SELECT deviceId, + AVG(temperature) as avg_temp + FROM stream + GROUP BY deviceId, CountingWindow(10) + ` + err := ssql.Execute(sql) + require.NoError(t, err) + + ch := make(chan []map[string]interface{}, 8) + ssql.AddSink(func(results []map[string]interface{}) { ch <- results }) + + for i := 0; i < 10; i++ { + ssql.Emit(map[string]interface{}{"deviceId": "A", "temperature": i, "timestamp": time.Now()}) + ssql.Emit(map[string]interface{}{"deviceId": "B", "temperature": i, "timestamp": time.Now()}) + } + + ids := make(map[string]bool) + for k := 0; k < 2; k++ { + select { + case res := <-ch: + require.Len(t, res, 1) + id := res[0]["deviceId"].(string) + ids[id] = true + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } + } + assert.True(t, ids["A"]) + assert.True(t, ids["B"]) +} + +func TestSQLCountingWindow_MultiKeyGroupedCounting(t *testing.T) { + ssql := New() + defer ssql.Stop() + + sql := ` + SELECT deviceId, region, + COUNT(*) as cnt, + AVG(temperature) as avg_temp, + MIN(temperature) as min_temp + FROM stream + GROUP BY deviceId, region, CountingWindow(5) + ` + err := ssql.Execute(sql) + require.NoError(t, err) + + ch := make(chan []map[string]interface{}, 8) + ssql.AddSink(func(results []map[string]interface{}) { ch <- results }) + + for i := 0; i < 5; i++ { + ssql.Emit(map[string]interface{}{"deviceId": "A", "region": "R1", "temperature": i, "timestamp": time.Now()}) + ssql.Emit(map[string]interface{}{"deviceId": "B", "region": "R1", "temperature": i + 10, "timestamp": time.Now()}) + ssql.Emit(map[string]interface{}{"deviceId": "A", "region": "R2", "temperature": i + 20, "timestamp": time.Now()}) + ssql.Emit(map[string]interface{}{"deviceId": "B", "region": "R2", "temperature": i + 30, "timestamp": time.Now()}) + } + + type agg struct { + cnt float64 + avg float64 + min float64 + } + got := make(map[string]agg) + for k := 0; k < 4; k++ { + select { + case res := <-ch: + require.Len(t, res, 1) + id := res[0]["deviceId"].(string) + region := res[0]["region"].(string) + cnt := res[0]["cnt"].(float64) + avg := res[0]["avg_temp"].(float64) + min := res[0]["min_temp"].(float64) + got[id+"|"+region] = agg{cnt: cnt, avg: avg, min: min} + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } + } + // Expect 4 combinations all counted to 5, with known avg/min + assert.Equal(t, float64(5), got["A|R1"].cnt) + assert.Equal(t, float64(5), got["B|R1"].cnt) + assert.Equal(t, float64(5), got["A|R2"].cnt) + assert.Equal(t, float64(5), got["B|R2"].cnt) + + assert.InEpsilon(t, 2.0, got["A|R1"].avg, 0.0001) + assert.InEpsilon(t, 12.0, got["B|R1"].avg, 0.0001) + assert.InEpsilon(t, 22.0, got["A|R2"].avg, 0.0001) + assert.InEpsilon(t, 32.0, got["B|R2"].avg, 0.0001) + + assert.Equal(t, 0.0, got["A|R1"].min) + assert.InEpsilon(t, 10.0, got["B|R1"].min, 0.0001) + assert.InEpsilon(t, 20.0, got["A|R2"].min, 0.0001) + assert.InEpsilon(t, 30.0, got["B|R2"].min, 0.0001) +} diff --git a/streamsql_session_window_test.go b/streamsql_session_window_test.go new file mode 100644 index 0000000..8617cd7 --- /dev/null +++ b/streamsql_session_window_test.go @@ -0,0 +1,176 @@ +package streamsql + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSQLSessionWindow_SingleKey(t *testing.T) { + ssql := New() + defer ssql.Stop() + + sql := ` + SELECT deviceId, + COUNT(*) as cnt + FROM stream + GROUP BY deviceId, SessionWindow('300ms') + ` + err := ssql.Execute(sql) + require.NoError(t, err) + + ch := make(chan []map[string]interface{}, 4) + ssql.AddSink(func(results []map[string]interface{}) { ch <- results }) + + for i := 0; i < 5; i++ { + ssql.Emit(map[string]interface{}{"deviceId": "sensor001", "timestamp": time.Now()}) + time.Sleep(50 * time.Millisecond) + } + + time.Sleep(600 * time.Millisecond) + + select { + case res := <-ch: + require.Len(t, res, 1) + row := res[0] + assert.Equal(t, "sensor001", row["deviceId"]) + assert.Equal(t, float64(5), row["cnt"]) + case <-time.After(2 * time.Second): + t.Fatal("timeout") + } +} + +func TestSQLSessionWindow_GroupedSession_MixedDevices(t *testing.T) { + ssql := New() + defer ssql.Stop() + + sql := ` + SELECT deviceId, + AVG(temperature) as avg_temp + FROM stream + GROUP BY deviceId, SessionWindow('200ms') + ` + err := ssql.Execute(sql) + require.NoError(t, err) + + ch := make(chan []map[string]interface{}, 8) + ssql.AddSink(func(results []map[string]interface{}) { ch <- results }) + + // Emit data for two different devices in interleaved pattern + for i := 0; i < 5; i++ { + ssql.Emit(map[string]interface{}{"deviceId": "A", "temperature": float64(i), "timestamp": time.Now()}) + ssql.Emit(map[string]interface{}{"deviceId": "B", "temperature": float64(i + 10), "timestamp": time.Now()}) + time.Sleep(30 * time.Millisecond) + } + + // Wait for session timeout + time.Sleep(400 * time.Millisecond) + + ids := make(map[string]bool) + avgTemps := make(map[string]float64) + for k := 0; k < 2; k++ { + select { + case res := <-ch: + require.Len(t, res, 1) + id := res[0]["deviceId"].(string) + avgTemp := res[0]["avg_temp"].(float64) + ids[id] = true + avgTemps[id] = avgTemp + case <-time.After(2 * time.Second): + t.Fatal("timeout") + } + } + assert.True(t, ids["A"]) + assert.True(t, ids["B"]) + // Verify average temperatures: A should have avg of 0-4 = 2.0, B should have avg of 10-14 = 12.0 + assert.InEpsilon(t, 2.0, avgTemps["A"], 0.1) + assert.InEpsilon(t, 12.0, avgTemps["B"], 0.1) +} + +func TestSQLSessionWindow_MultiKeyGroupedSession(t *testing.T) { + ssql := New() + defer ssql.Stop() + + sql := ` + SELECT deviceId, region, + COUNT(*) as cnt, + AVG(temperature) as avg_temp, + MIN(temperature) as min_temp, + MAX(temperature) as max_temp + FROM stream + GROUP BY deviceId, region, SessionWindow('200ms') + ` + err := ssql.Execute(sql) + require.NoError(t, err) + + ch := make(chan []map[string]interface{}, 8) + ssql.AddSink(func(results []map[string]interface{}) { ch <- results }) + + // Emit data for 4 different combinations: A|R1, B|R1, A|R2, B|R2 + for i := 0; i < 4; i++ { + ssql.Emit(map[string]interface{}{"deviceId": "A", "region": "R1", "temperature": float64(i), "timestamp": time.Now()}) + ssql.Emit(map[string]interface{}{"deviceId": "B", "region": "R1", "temperature": float64(i + 10), "timestamp": time.Now()}) + ssql.Emit(map[string]interface{}{"deviceId": "A", "region": "R2", "temperature": float64(i + 20), "timestamp": time.Now()}) + ssql.Emit(map[string]interface{}{"deviceId": "B", "region": "R2", "temperature": float64(i + 30), "timestamp": time.Now()}) + time.Sleep(30 * time.Millisecond) + } + + // Wait for session timeout + time.Sleep(400 * time.Millisecond) + + type agg struct { + cnt float64 + avg float64 + min float64 + max float64 + } + got := make(map[string]agg) + for k := 0; k < 4; k++ { + select { + case res := <-ch: + require.Len(t, res, 1) + id := res[0]["deviceId"].(string) + region := res[0]["region"].(string) + cnt := res[0]["cnt"].(float64) + avg := res[0]["avg_temp"].(float64) + min := res[0]["min_temp"].(float64) + max := res[0]["max_temp"].(float64) + got[id+"|"+region] = agg{cnt: cnt, avg: avg, min: min, max: max} + case <-time.After(2 * time.Second): + t.Fatal("timeout") + } + } + + // Verify all 4 combinations are present + require.Contains(t, got, "A|R1") + require.Contains(t, got, "B|R1") + require.Contains(t, got, "A|R2") + require.Contains(t, got, "B|R2") + + // Verify counts: each combination should have 4 records + assert.Equal(t, float64(4), got["A|R1"].cnt) + assert.Equal(t, float64(4), got["B|R1"].cnt) + assert.Equal(t, float64(4), got["A|R2"].cnt) + assert.Equal(t, float64(4), got["B|R2"].cnt) + + // Verify averages: A|R1: (0+1+2+3)/4 = 1.5, B|R1: (10+11+12+13)/4 = 11.5 + // A|R2: (20+21+22+23)/4 = 21.5, B|R2: (30+31+32+33)/4 = 31.5 + assert.InEpsilon(t, 1.5, got["A|R1"].avg, 0.1) + assert.InEpsilon(t, 11.5, got["B|R1"].avg, 0.1) + assert.InEpsilon(t, 21.5, got["A|R2"].avg, 0.1) + assert.InEpsilon(t, 31.5, got["B|R2"].avg, 0.1) + + // Verify minimums: A|R1: 0, B|R1: 10, A|R2: 20, B|R2: 30 + assert.Equal(t, 0.0, got["A|R1"].min) + assert.Equal(t, 10.0, got["B|R1"].min) + assert.Equal(t, 20.0, got["A|R2"].min) + assert.Equal(t, 30.0, got["B|R2"].min) + + // Verify maximums: A|R1: 3, B|R1: 13, A|R2: 23, B|R2: 33 + assert.Equal(t, 3.0, got["A|R1"].max) + assert.Equal(t, 13.0, got["B|R1"].max) + assert.Equal(t, 23.0, got["A|R2"].max) + assert.Equal(t, 33.0, got["B|R2"].max) +} diff --git a/streamsql_sliding_window_test.go b/streamsql_sliding_window_test.go new file mode 100644 index 0000000..1282fb6 --- /dev/null +++ b/streamsql_sliding_window_test.go @@ -0,0 +1,355 @@ +package streamsql + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSQLSlidingWindow_Basic(t *testing.T) { + ssql := New() + defer ssql.Stop() + + sql := ` + SELECT deviceId, + COUNT(*) as cnt + FROM stream + GROUP BY deviceId, SlidingWindow('10s', '2s') + ` + err := ssql.Execute(sql) + require.NoError(t, err) + + ch := make(chan []map[string]interface{}, 10) + ssql.AddSink(func(results []map[string]interface{}) { + ch <- results + }) + + baseTime := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + for i := 0; i < 12; i++ { + ssql.Emit(map[string]interface{}{ + "deviceId": "sensor001", + "temperature": i, + "timestamp": baseTime.Add(time.Duration(i) * time.Second), + }) + time.Sleep(10 * time.Millisecond) + } + + results := make([][]map[string]interface{}, 0) + timeout := time.After(15 * time.Second) + for { + select { + case res := <-ch: + if len(res) > 0 { + results = append(results, res) + } + case <-timeout: + goto END + } + } + +END: + assert.Greater(t, len(results), 0) + if len(results) > 0 { + firstWindow := results[0] + require.Len(t, firstWindow, 1) + cnt := firstWindow[0]["cnt"].(float64) + assert.Greater(t, cnt, 0.0) + } +} + +func TestSQLSlidingWindow_WithAggregations(t *testing.T) { + ssql := New() + defer ssql.Stop() + + sql := ` + SELECT deviceId, + COUNT(*) as cnt, + AVG(temperature) as avg_temp, + MIN(temperature) as min_temp, + MAX(temperature) as max_temp + FROM stream + GROUP BY deviceId, SlidingWindow('10s', '2s') + ` + err := ssql.Execute(sql) + require.NoError(t, err) + + ch := make(chan []map[string]interface{}, 20) + ssql.AddSink(func(results []map[string]interface{}) { + ch <- results + }) + + baseTime := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + for i := 0; i < 15; i++ { + timestamp := baseTime.Add(time.Duration(i) * time.Second) + temperature := float64(i) + + ssql.Emit(map[string]interface{}{ + "deviceId": "sensor001", + "temperature": temperature, + "timestamp": timestamp, + }) + time.Sleep(10 * time.Millisecond) + } + + time.Sleep(5 * time.Second) + + results := make([][]map[string]interface{}, 0) + timeout := time.After(3 * time.Second) + for { + select { + case res := <-ch: + if len(res) > 0 { + results = append(results, res) + } + case <-timeout: + goto END + } + } + +END: + require.Greater(t, len(results), 0, "至少应该有一个窗口被触发") + + maxCnt := 0.0 + for _, res := range results { + if len(res) > 0 { + cnt := res[0]["cnt"].(float64) + if cnt > maxCnt { + maxCnt = cnt + } + } + } + assert.GreaterOrEqual(t, maxCnt, 8.0, "至少应该有一个窗口包含接近10条数据") + + for i, res := range results { + require.Len(t, res, 1, "每个窗口应该只有一行聚合结果") + row := res[0] + + cnt := row["cnt"].(float64) + avgTemp := row["avg_temp"].(float64) + minTemp := row["min_temp"].(float64) + maxTemp := row["max_temp"].(float64) + + assert.Greater(t, cnt, 0.0, "窗口 %d 计数应该大于0", i+1) + assert.LessOrEqual(t, minTemp, maxTemp, "窗口 %d 最小值应该小于等于最大值", i+1) + assert.LessOrEqual(t, minTemp, avgTemp, "窗口 %d 最小值应该小于等于平均值", i+1) + assert.LessOrEqual(t, avgTemp, maxTemp, "窗口 %d 平均值应该小于等于最大值", i+1) + + if cnt >= 2 { + expectedAvg := (minTemp + maxTemp) / 2.0 + allowedError := (maxTemp - minTemp) / 2.0 + assert.InDelta(t, expectedAvg, avgTemp, allowedError+0.1, + "窗口 %d 平均值应该在最小值和最大值的中间", i+1) + } + } +} + +func TestSQLSlidingWindow_MultipleWindowsAlignment(t *testing.T) { + ssql := New() + defer ssql.Stop() + + sql := ` + SELECT deviceId, + COUNT(*) as cnt, + AVG(temperature) as avg_temp, + MIN(temperature) as min_temp, + MAX(temperature) as max_temp + FROM stream + GROUP BY deviceId, SlidingWindow('10s', '2s') + ` + err := ssql.Execute(sql) + require.NoError(t, err) + + ch := make(chan []map[string]interface{}, 20) + windowResults := make([][]map[string]interface{}, 0) + ssql.AddSink(func(results []map[string]interface{}) { + ch <- results + }) + + baseTime := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + for i := 0; i < 15; i++ { + ssql.Emit(map[string]interface{}{ + "deviceId": "sensor001", + "temperature": float64(i), + "timestamp": baseTime.Add(time.Duration(i) * time.Second), + }) + time.Sleep(10 * time.Millisecond) + } + + time.Sleep(8 * time.Second) + + timeout := time.After(2 * time.Second) + for { + select { + case res := <-ch: + if len(res) > 0 { + windowResults = append(windowResults, res) + } + case <-timeout: + goto END + } + } + +END: + require.Greater(t, len(windowResults), 0, "应该至少触发一个窗口") + + for i, res := range windowResults { + require.Len(t, res, 1, "窗口 %d 应该只有一行聚合结果", i+1) + row := res[0] + + cnt := row["cnt"].(float64) + avgTemp := row["avg_temp"].(float64) + minTemp := row["min_temp"].(float64) + maxTemp := row["max_temp"].(float64) + + assert.Greater(t, cnt, 0.0, "窗口 %d 计数应该大于0", i+1) + assert.LessOrEqual(t, minTemp, maxTemp, "窗口 %d 最小值应该小于等于最大值", i+1) + assert.LessOrEqual(t, minTemp, avgTemp, "窗口 %d 最小值应该小于等于平均值", i+1) + assert.LessOrEqual(t, avgTemp, maxTemp, "窗口 %d 平均值应该小于等于最大值", i+1) + + if cnt >= 2 { + expectedAvg := (minTemp + maxTemp) / 2.0 + allowedError := (maxTemp - minTemp) / 2.0 + assert.InDelta(t, expectedAvg, avgTemp, allowedError+0.1, + "窗口 %d 平均值应该在最小值和最大值的中间", i+1) + } + + assert.LessOrEqual(t, minTemp, 14.0, "窗口 %d 最小值不应该超过14", i+1) + assert.GreaterOrEqual(t, maxTemp, 0.0, "窗口 %d 最大值不应该小于0", i+1) + assert.LessOrEqual(t, cnt, 15.0, "窗口 %d 计数不应该超过15", i+1) + } + + if len(windowResults) > 1 { + firstWindow := windowResults[0] + lastWindow := windowResults[len(windowResults)-1] + + firstCnt := firstWindow[0]["cnt"].(float64) + lastCnt := lastWindow[0]["cnt"].(float64) + firstMin := firstWindow[0]["min_temp"].(float64) + lastMin := lastWindow[0]["min_temp"].(float64) + + assert.GreaterOrEqual(t, firstCnt, lastCnt, + "第一个窗口应该包含不少于最后一个窗口的数据") + assert.LessOrEqual(t, firstMin, lastMin, + "第一个窗口的最小值应该小于等于最后一个窗口的最小值") + } + + allCounts := make([]float64, len(windowResults)) + for i, res := range windowResults { + allCounts[i] = res[0]["cnt"].(float64) + } + + for i := 1; i < len(allCounts); i++ { + prevCnt := allCounts[i-1] + currCnt := allCounts[i] + assert.GreaterOrEqual(t, prevCnt, currCnt, + "窗口计数应该递减或保持不变(由于窗口对齐,可能不完全递减)") + } +} + +func TestSQLSlidingWindow_MultiKeyGrouped(t *testing.T) { + ssql := New() + defer ssql.Stop() + + sql := ` + SELECT deviceId, region, + COUNT(*) as cnt, + AVG(temperature) as avg_temp, + MIN(temperature) as min_temp, + MAX(temperature) as max_temp + FROM stream + GROUP BY deviceId, region, SlidingWindow('5s', '2s') + ` + err := ssql.Execute(sql) + require.NoError(t, err) + + ch := make(chan []map[string]interface{}, 20) + ssql.AddSink(func(results []map[string]interface{}) { + ch <- results + }) + + baseTime := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + for i := 0; i < 8; i++ { + ssql.Emit(map[string]interface{}{ + "deviceId": "A", + "region": "R1", + "temperature": float64(i), + "timestamp": baseTime.Add(time.Duration(i) * time.Second), + }) + ssql.Emit(map[string]interface{}{ + "deviceId": "B", + "region": "R1", + "temperature": float64(i + 10), + "timestamp": baseTime.Add(time.Duration(i) * time.Second), + }) + ssql.Emit(map[string]interface{}{ + "deviceId": "A", + "region": "R2", + "temperature": float64(i + 20), + "timestamp": baseTime.Add(time.Duration(i) * time.Second), + }) + ssql.Emit(map[string]interface{}{ + "deviceId": "B", + "region": "R2", + "temperature": float64(i + 30), + "timestamp": baseTime.Add(time.Duration(i) * time.Second), + }) + time.Sleep(10 * time.Millisecond) + } + + time.Sleep(3 * time.Second) + + type agg struct { + cnt float64 + avg float64 + min float64 + max float64 + } + got := make(map[string][]agg) + + timeout := time.After(2 * time.Second) + for { + select { + case res := <-ch: + if len(res) > 0 { + for _, row := range res { + id := row["deviceId"].(string) + region := row["region"].(string) + key := id + "|" + region + got[key] = append(got[key], agg{ + cnt: row["cnt"].(float64), + avg: row["avg_temp"].(float64), + min: row["min_temp"].(float64), + max: row["max_temp"].(float64), + }) + } + } + case <-timeout: + goto END + } + } + +END: + require.Contains(t, got, "A|R1") + require.Contains(t, got, "B|R1") + require.Contains(t, got, "A|R2") + require.Contains(t, got, "B|R2") + + for key, windows := range got { + assert.Greater(t, len(windows), 0, "组合 %s 应该至少有一个窗口", key) + for i, w := range windows { + assert.Greater(t, w.cnt, 0.0, "组合 %s 窗口 %d 计数应该大于0", key, i+1) + assert.LessOrEqual(t, w.min, w.max, "组合 %s 窗口 %d 最小值应该小于等于最大值", key, i+1) + assert.LessOrEqual(t, w.min, w.avg, "组合 %s 窗口 %d 最小值应该小于等于平均值", key, i+1) + assert.LessOrEqual(t, w.avg, w.max, "组合 %s 窗口 %d 平均值应该小于等于最大值", key, i+1) + + if w.cnt >= 2 { + expectedAvg := (w.min + w.max) / 2.0 + allowedError := (w.max - w.min) / 2.0 + assert.InDelta(t, expectedAvg, w.avg, allowedError+0.1, + "组合 %s 窗口 %d 平均值应该在最小值和最大值的中间", key, i+1) + } + } + } +} diff --git a/types/config.go b/types/config.go index 7a15d90..b1cc8a1 100644 --- a/types/config.go +++ b/types/config.go @@ -34,11 +34,13 @@ type Config struct { // WindowConfig window configuration type WindowConfig struct { - Type string `json:"type"` - Params map[string]interface{} `json:"params"` - TsProp string `json:"tsProp"` - TimeUnit time.Duration `json:"timeUnit"` - GroupByKey string `json:"groupByKey"` // Session window grouping key + Type string `json:"type"` + Params []interface{} `json:"params"` // Window function parameters array + TsProp string `json:"tsProp"` + TimeUnit time.Duration `json:"timeUnit"` + GroupByKeys []string `json:"groupByKeys"` // Multiple grouping keys for keyed windows + PerformanceConfig PerformanceConfig `json:"performanceConfig"` // Performance configuration + Callback func([]Row) `json:"-"` // Callback function (not serialized) } // FieldExpression field expression configuration diff --git a/types/config_test.go b/types/config_test.go index 77186f9..f85fa82 100644 --- a/types/config_test.go +++ b/types/config_test.go @@ -29,10 +29,10 @@ func TestConfig(t *testing.T) { config := &Config{ WindowConfig: WindowConfig{ Type: "tumbling", - Params: map[string]interface{}{"size": "1m"}, + Params: []interface{}{time.Minute}, TsProp: "timestamp", TimeUnit: time.Minute, - GroupByKey: "user_id", + GroupByKeys: []string{"user_id"}, }, GroupFields: []string{"user_id", "category"}, SelectFields: map[string]aggregator.AggregateType{"count": aggregator.Count, "sum": aggregator.Sum}, @@ -127,10 +127,10 @@ func TestConfig(t *testing.T) { func TestWindowConfig(t *testing.T) { windowConfig := WindowConfig{ Type: "sliding", - Params: map[string]interface{}{"size": "5m", "interval": "1m"}, + Params: []interface{}{5 * time.Minute, time.Minute}, TsProp: "event_time", TimeUnit: time.Minute, - GroupByKey: "session_id", + GroupByKeys: []string{"session_id"}, } if windowConfig.Type != "sliding" { @@ -145,20 +145,22 @@ func TestWindowConfig(t *testing.T) { t.Errorf("Expected time unit 'Minute', got '%v'", windowConfig.TimeUnit) } - if windowConfig.GroupByKey != "session_id" { - t.Errorf("Expected group by key 'session_id', got '%s'", windowConfig.GroupByKey) + if len(windowConfig.GroupByKeys) == 0 || windowConfig.GroupByKeys[0] != "session_id" { + t.Errorf("Expected group by keys to contain 'session_id', got %v", windowConfig.GroupByKeys) } if len(windowConfig.Params) != 2 { t.Errorf("Expected 2 parameters, got %d", len(windowConfig.Params)) } - if windowConfig.Params["size"] != "5m" { - t.Errorf("Expected size parameter '5m', got '%v'", windowConfig.Params["size"]) + // Check first parameter (size) + if size, ok := windowConfig.Params[0].(time.Duration); !ok || size != 5*time.Minute { + t.Errorf("Expected size parameter 5m, got '%v'", windowConfig.Params[0]) } - if windowConfig.Params["interval"] != "1m" { - t.Errorf("Expected interval parameter '1m', got '%v'", windowConfig.Params["interval"]) + // Check second parameter (slide/interval) + if slide, ok := windowConfig.Params[1].(time.Duration); !ok || slide != time.Minute { + t.Errorf("Expected slide parameter 1m, got '%v'", windowConfig.Params[1]) } } @@ -447,10 +449,10 @@ func TestComplexConfig(t *testing.T) { config := Config{ WindowConfig: WindowConfig{ Type: "sliding", - Params: map[string]interface{}{"size": "5m", "slide": "1m"}, + Params: []interface{}{5 * time.Minute, time.Minute}, TsProp: "event_time", TimeUnit: time.Minute, - GroupByKey: "session_id", + GroupByKeys: []string{"session_id"}, }, GroupFields: []string{"user_id", "product_category", "region"}, SelectFields: map[string]aggregator.AggregateType{ diff --git a/window/counting_window.go b/window/counting_window.go index 53ccc04..4780b15 100644 --- a/window/counting_window.go +++ b/window/counting_window.go @@ -19,8 +19,9 @@ package window import ( "context" "fmt" + "reflect" + "strings" "sync" - "time" "github.com/rulego/streamsql/utils/cast" "github.com/rulego/streamsql/utils/timex" @@ -31,60 +32,82 @@ import ( var _ Window = (*CountingWindow)(nil) type CountingWindow struct { - config types.WindowConfig - threshold int - count int - mu sync.Mutex - callback func([]types.Row) - dataBuffer []types.Row - outputChan chan []types.Row - ctx context.Context - cancelFunc context.CancelFunc - ticker *time.Ticker - triggerChan chan types.Row + config types.WindowConfig + threshold int + count int + mu sync.Mutex + callback func([]types.Row) + dataBuffer []types.Row + outputChan chan []types.Row + ctx context.Context + cancelFunc context.CancelFunc + triggerChan chan types.Row + keyedBuffer map[string][]types.Row + keyedCount map[string]int + sentCount int64 + droppedCount int64 + stopped bool } func NewCountingWindow(config types.WindowConfig) (*CountingWindow, error) { ctx, cancel := context.WithCancel(context.Background()) - threshold := cast.ToInt(config.Params["count"]) + defer func() { + if cancel != nil { + // cancel will be used in the returned struct + } + }() + + // Get count parameter from params array + if len(config.Params) == 0 { + cancel() + return nil, fmt.Errorf("counting window requires 'count' parameter") + } + + countVal := config.Params[0] + threshold := cast.ToInt(countVal) if threshold <= 0 { - return nil, fmt.Errorf("threshold must be a positive integer") + return nil, fmt.Errorf("threshold must be a positive integer, got: %v", countVal) } // Use unified performance config to get window output buffer size bufferSize := 100 // Default value, counting windows usually have smaller buffers - if perfConfig, exists := config.Params["performanceConfig"]; exists { - if pc, ok := perfConfig.(types.PerformanceConfig); ok { - bufferSize = pc.BufferConfig.WindowOutputSize / 10 // Counting window uses 1/10 of buffer - if bufferSize < 10 { - bufferSize = 10 // Minimum value - } + if (config.PerformanceConfig != types.PerformanceConfig{}) { + bufferSize = config.PerformanceConfig.BufferConfig.WindowOutputSize / 10 // Counting window uses 1/10 of buffer + if bufferSize < 10 { + bufferSize = 10 // Minimum value } } cw := &CountingWindow{ + config: config, threshold: threshold, dataBuffer: make([]types.Row, 0, threshold), outputChan: make(chan []types.Row, bufferSize), ctx: ctx, cancelFunc: cancel, - triggerChan: make(chan types.Row, 3), + triggerChan: make(chan types.Row, bufferSize), + keyedBuffer: make(map[string][]types.Row), + keyedCount: make(map[string]int), } - if callback, ok := config.Params["callback"].(func([]types.Row)); ok { - cw.SetCallback(callback) + // Set callback if provided + if config.Callback != nil { + cw.SetCallback(config.Callback) } return cw, nil } func (cw *CountingWindow) Add(data interface{}) { - // Add data to window data list t := GetTimestamp(data, cw.config.TsProp, cw.config.TimeUnit) row := types.Row{ Data: data, Timestamp: t, } - cw.triggerChan <- row + + select { + case cw.triggerChan <- row: + case <-cw.ctx.Done(): + } } func (cw *CountingWindow) Start() { go func() { @@ -97,39 +120,44 @@ func (cw *CountingWindow) Start() { // Channel closed, exit loop return } + key := cw.getKey(row.Data) cw.mu.Lock() - cw.dataBuffer = append(cw.dataBuffer, row) - cw.count++ - shouldTrigger := cw.count >= cw.threshold - if shouldTrigger { - // Process immediately while holding lock - slot := cw.createSlot(cw.dataBuffer[:cw.threshold]) + buf := append(cw.keyedBuffer[key], row) + cw.keyedBuffer[key] = buf + cw.keyedCount[key] = len(buf) + if cw.keyedCount[key] >= cw.threshold { + slot := cw.createSlot(buf[:cw.threshold]) data := make([]types.Row, cw.threshold) - copy(data, cw.dataBuffer[:cw.threshold]) - // Set Slot field to copied data to avoid modifying original dataBuffer + copy(data, buf[:cw.threshold]) for i := range data { data[i].Slot = slot } - - if len(cw.dataBuffer) > cw.threshold { - remaining := len(cw.dataBuffer) - cw.threshold - newBuffer := make([]types.Row, remaining, cw.threshold) - copy(newBuffer, cw.dataBuffer[cw.threshold:]) - cw.dataBuffer = newBuffer + if len(buf) > cw.threshold { + rem := make([]types.Row, len(buf)-cw.threshold, cw.threshold) + copy(rem, buf[cw.threshold:]) + cw.keyedBuffer[key] = rem } else { - cw.dataBuffer = make([]types.Row, 0, cw.threshold) + cw.keyedBuffer[key] = make([]types.Row, 0, cw.threshold) } - // Reset count - cw.count = len(cw.dataBuffer) + cw.keyedCount[key] = len(cw.keyedBuffer[key]) cw.mu.Unlock() - // Handle callback after releasing lock - go func(data []types.Row) { - if cw.callback != nil { - cw.callback(data) - } - cw.outputChan <- data - }(data) + if cw.callback != nil { + cw.callback(data) + } + + select { + case cw.outputChan <- data: + cw.mu.Lock() + cw.sentCount++ + cw.mu.Unlock() + case <-cw.ctx.Done(): + return + default: + cw.mu.Lock() + cw.droppedCount++ + cw.mu.Unlock() + } } else { cw.mu.Unlock() } @@ -146,11 +174,42 @@ func (cw *CountingWindow) Trigger() { // This method is kept to satisfy Window interface requirements, but actual triggering is handled in Start method } +func (cw *CountingWindow) Stop() { + cw.mu.Lock() + stopped := cw.stopped + if !stopped { + cw.stopped = true + } + cw.mu.Unlock() + + if !stopped { + close(cw.triggerChan) + cw.cancelFunc() + } +} + func (cw *CountingWindow) Reset() { cw.mu.Lock() defer cw.mu.Unlock() + cw.count = 0 cw.dataBuffer = nil + cw.keyedBuffer = make(map[string][]types.Row) + cw.keyedCount = make(map[string]int) + cw.sentCount = 0 + cw.droppedCount = 0 +} + +func (cw *CountingWindow) GetStats() map[string]int64 { + cw.mu.Lock() + defer cw.mu.Unlock() + + return map[string]int64{ + "sentCount": cw.sentCount, + "droppedCount": cw.droppedCount, + "bufferSize": int64(cap(cw.outputChan)), + "bufferUsed": int64(len(cw.outputChan)), + } } func (cw *CountingWindow) OutputChan() <-chan []types.Row { @@ -177,3 +236,32 @@ func (cw *CountingWindow) createSlot(data []types.Row) *types.TimeSlot { return slot } } + +func (cw *CountingWindow) getKey(data interface{}) string { + // Use GroupByKeys array + keys := cw.config.GroupByKeys + if len(keys) == 0 { + return "__global__" + } + v := reflect.ValueOf(data) + keyParts := make([]string, 0, len(keys)) + for _, k := range keys { + var part string + switch v.Kind() { + case reflect.Map: + if v.Type().Key().Kind() == reflect.String { + mv := v.MapIndex(reflect.ValueOf(k)) + if mv.IsValid() { + part = cast.ToString(mv.Interface()) + } + } + case reflect.Struct: + f := v.FieldByName(k) + if f.IsValid() { + part = cast.ToString(f.Interface()) + } + } + keyParts = append(keyParts, part) + } + return strings.Join(keyParts, "|") +} diff --git a/window/counting_window_test.go b/window/counting_window_test.go index a190eba..31a47cc 100644 --- a/window/counting_window_test.go +++ b/window/counting_window_test.go @@ -17,11 +17,9 @@ func TestCountingWindow(t *testing.T) { // Test case 1: Normal operation cw, _ := NewCountingWindow(types.WindowConfig{ - Params: map[string]interface{}{ - "count": 3, - "callback": func(results []interface{}) { + Params: []interface{}{3}, + Callback: func(results []types.Row) { t.Logf("Received results: %v", results) - }, }, }) go cw.Start() @@ -85,9 +83,7 @@ func TestCountingWindow(t *testing.T) { func TestCountingWindowBadThreshold(t *testing.T) { _, err := CreateWindow(types.WindowConfig{ Type: "counting", - Params: map[string]interface{}{ - "count": 0, - }, + Params: []interface{}{0}, }) require.Error(t, err) } diff --git a/window/performance_test.go b/window/performance_test.go index fe41099..febfa8d 100644 --- a/window/performance_test.go +++ b/window/performance_test.go @@ -17,10 +17,7 @@ func TestTumblingWindowPerformance(t *testing.T) { t.Run(fmt.Sprintf("BufferSize_%d", bufferSize), func(t *testing.T) { tw, _ := NewTumblingWindow(types.WindowConfig{ Type: "TumblingWindow", - Params: map[string]interface{}{ - "size": "100ms", - "outputBufferSize": bufferSize, - }, + Params: []interface{}{100 * time.Millisecond}, TsProp: "Ts", }) @@ -47,14 +44,14 @@ func TestTumblingWindowPerformance(t *testing.T) { t.Logf("缓冲区大小: %d", bufferSize) t.Logf("处理时间: %v", elapsed) - t.Logf("发送成功: %d", stats["sent_count"]) - t.Logf("丢弃数量: %d", stats["dropped_count"]) - t.Logf("缓冲区利用率: %d/%d", stats["buffer_used"], stats["buffer_size"]) + t.Logf("发送成功: %d", stats["sentCount"]) + t.Logf("丢弃数量: %d", stats["droppedCount"]) + t.Logf("缓冲区利用率: %d/%d", stats["bufferUsed"], stats["bufferSize"]) // 验证没有严重的数据丢失 if bufferSize >= 1000 { - if stats["dropped_count"] > int64(dataCount/10) { // 允许最多10%的丢失 - t.Errorf("丢失数据过多: %d (总数: %d)", stats["dropped_count"], dataCount) + if stats["droppedCount"] > int64(dataCount/10) { // 允许最多10%的丢失 + t.Errorf("丢失数据过多: %d (总数: %d)", stats["droppedCount"], dataCount) } } @@ -73,10 +70,7 @@ type TestData struct { func BenchmarkTumblingWindowThroughput(b *testing.B) { tw, _ := NewTumblingWindow(types.WindowConfig{ Type: "TumblingWindow", - Params: map[string]interface{}{ - "size": "10ms", - "outputBufferSize": 5000, - }, + Params: []interface{}{10 * time.Millisecond}, TsProp: "Ts", }) @@ -107,7 +101,7 @@ func BenchmarkTumblingWindowThroughput(b *testing.B) { // 获取最终统计 stats := tw.GetStats() - b.Logf("发送成功: %d, 丢弃: %d", stats["sent_count"], stats["dropped_count"]) + b.Logf("发送成功: %d, 丢弃: %d", stats["sentCount"], stats["droppedCount"]) tw.Stop() } @@ -117,10 +111,7 @@ func TestWindowBufferOverflow(t *testing.T) { // 创建一个小缓冲区的窗口 tw, _ := NewTumblingWindow(types.WindowConfig{ Type: "TumblingWindow", - Params: map[string]interface{}{ - "size": "50ms", - "outputBufferSize": 5, // 很小的缓冲区 - }, + Params: []interface{}{50 * time.Millisecond}, TsProp: "Ts", }) @@ -141,15 +132,15 @@ func TestWindowBufferOverflow(t *testing.T) { time.Sleep(200 * time.Millisecond) stats := tw.GetStats() - t.Logf("缓冲区溢出测试 - 发送: %d, 丢弃: %d", stats["sent_count"], stats["dropped_count"]) + t.Logf("缓冲区溢出测试 - 发送: %d, 丢弃: %d", stats["sentCount"], stats["droppedCount"]) // 应该有数据被丢弃 - if stats["dropped_count"] == 0 { + if stats["droppedCount"] == 0 { t.Log("预期会有数据丢弃,但实际没有丢弃") } // 验证系统仍然运行正常(没有阻塞) - if stats["sent_count"] == 0 { + if stats["sentCount"] == 0 { t.Error("应该至少发送了一些数据") } diff --git a/window/session_window.go b/window/session_window.go index 628f940..5550999 100644 --- a/window/session_window.go +++ b/window/session_window.go @@ -19,6 +19,7 @@ package window import ( "context" "fmt" + "strings" "sync" "time" @@ -68,19 +69,24 @@ type session struct { func NewSessionWindow(config types.WindowConfig) (*SessionWindow, error) { // Create a cancellable context ctx, cancel := context.WithCancel(context.Background()) - timeout, err := cast.ToDurationE(config.Params["timeout"]) + + // Get timeout parameter from params array + if len(config.Params) == 0 { + return nil, fmt.Errorf("session window requires 'timeout' parameter") + } + + timeoutVal := config.Params[0] + timeout, err := cast.ToDurationE(timeoutVal) if err != nil { return nil, fmt.Errorf("invalid timeout for session window: %v", err) } // Use unified performance configuration to get window output buffer size bufferSize := 100 // Default value, session windows typically have smaller buffers - if perfConfig, exists := config.Params["performanceConfig"]; exists { - if pc, ok := perfConfig.(types.PerformanceConfig); ok { - bufferSize = pc.BufferConfig.WindowOutputSize / 10 // Session window uses 1/10 of buffer - if bufferSize < 10 { - bufferSize = 10 // Minimum value - } + if (config.PerformanceConfig != types.PerformanceConfig{}) { + bufferSize = config.PerformanceConfig.BufferConfig.WindowOutputSize / 10 // Session window uses 1/10 of buffer + if bufferSize < 10 { + bufferSize = 10 // Minimum value } } @@ -115,9 +121,8 @@ func (sw *SessionWindow) Add(data interface{}) { Timestamp: timestamp, } - // Extract session key - // If groupby is configured, use groupby field as session key - key := extractSessionKey(data, sw.config.GroupByKey) + // Extract session key (supports multiple group by keys) + key := extractSessionCompositeKey(data, sw.config.GroupByKeys) // Get or create session s, exists := sw.sessionMap[key] @@ -208,7 +213,6 @@ func (sw *SessionWindow) Stop() { // checkExpiredSessions checks and triggers expired sessions func (sw *SessionWindow) checkExpiredSessions() { sw.mu.Lock() - defer sw.mu.Unlock() now := time.Now() expiredKeys := []string{} @@ -221,49 +225,74 @@ func (sw *SessionWindow) checkExpiredSessions() { } // Process expired sessions + resultsToSend := make([][]types.Row, 0) for _, key := range expiredKeys { s := sw.sessionMap[key] if len(s.data) > 0 { // Trigger session window result := make([]types.Row, len(s.data)) copy(result, s.data) - - // If callback function is set, execute it - if sw.callback != nil { - sw.callback(result) - } - - // Send data to output channel - sw.outputChan <- result + resultsToSend = append(resultsToSend, result) } // Delete expired session delete(sw.sessionMap, key) } + + // Release lock before sending to channel and calling callback to avoid blocking + sw.mu.Unlock() + + // Send results and call callbacks outside of lock to avoid blocking + for _, result := range resultsToSend { + // If callback function is set, execute it + if sw.callback != nil { + sw.callback(result) + } + + // Non-blocking send to output channel + select { + case sw.outputChan <- result: + // Successfully sent + default: + // Channel full, drop result (could add statistics here if needed) + } + } } // Trigger manually triggers all session windows func (sw *SessionWindow) Trigger() { sw.mu.Lock() - defer sw.mu.Unlock() - // Iterate through all sessions + // Collect all results first + resultsToSend := make([][]types.Row, 0) for _, s := range sw.sessionMap { if len(s.data) > 0 { // Trigger session window result := make([]types.Row, len(s.data)) copy(result, s.data) - - // If callback function is set, execute it - if sw.callback != nil { - sw.callback(result) - } - - // Send data to output channel - sw.outputChan <- result + resultsToSend = append(resultsToSend, result) } } // Clear all sessions sw.sessionMap = make(map[string]*session) + + // Release lock before sending to channel and calling callback to avoid blocking + sw.mu.Unlock() + + // Send results and call callbacks outside of lock to avoid blocking + for _, result := range resultsToSend { + // If callback function is set, execute it + if sw.callback != nil { + sw.callback(result) + } + + // Non-blocking send to output channel + select { + case sw.outputChan <- result: + // Successfully sent + default: + // Channel full, drop result (could add statistics here if needed) + } + } } // Reset resets session window data @@ -297,19 +326,18 @@ func (sw *SessionWindow) SetCallback(callback func([]types.Row)) { sw.callback = callback } -// extractSessionKey extracts session key from data -// If no key is specified, returns default key -func extractSessionKey(data interface{}, keyField string) string { - if keyField == "" { - return "default" // Default session key +// extractSessionCompositeKey builds composite session key from multiple group fields +// If GroupByKeys is empty, returns default key +func extractSessionCompositeKey(data interface{}, keys []string) string { + if len(keys) == 0 { + return "default" } - - // Try to extract from map + parts := make([]string, 0, len(keys)) if m, ok := data.(map[string]interface{}); ok { - if val, exists := m[keyField]; exists { - return fmt.Sprintf("%v", val) + for _, k := range keys { + parts = append(parts, fmt.Sprintf("%v", m[k])) } + return strings.Join(parts, "|") } - return "default" } diff --git a/window/sliding_window.go b/window/sliding_window.go index 67ae00a..fcce7fd 100644 --- a/window/sliding_window.go +++ b/window/sliding_window.go @@ -73,25 +73,36 @@ type SlidingWindow struct { // NewSlidingWindow creates a new sliding window instance // size parameter represents the total window size, slide represents the sliding interval func NewSlidingWindow(config types.WindowConfig) (*SlidingWindow, error) { - // Create a cancellable context - ctx, cancel := context.WithCancel(context.Background()) - size, err := cast.ToDurationE(config.Params["size"]) + // Get size parameter from params array + if len(config.Params) < 1 { + return nil, fmt.Errorf("sliding window requires at least 'size' parameter") + } + + sizeVal := config.Params[0] + size, err := cast.ToDurationE(sizeVal) if err != nil { return nil, fmt.Errorf("invalid size for sliding window: %v", err) } - slide, err := cast.ToDurationE(config.Params["slide"]) + + // Get slide parameter from params array + if len(config.Params) < 2 { + return nil, fmt.Errorf("sliding window requires 'slide' parameter") + } + + slideVal := config.Params[1] + slide, err := cast.ToDurationE(slideVal) if err != nil { return nil, fmt.Errorf("invalid slide for sliding window: %v", err) } // Use unified performance config to get window output buffer size bufferSize := 1000 // Default value - if perfConfig, exists := config.Params["performanceConfig"]; exists { - if pc, ok := perfConfig.(types.PerformanceConfig); ok { - bufferSize = pc.BufferConfig.WindowOutputSize - } + if (config.PerformanceConfig != types.PerformanceConfig{}) { + bufferSize = config.PerformanceConfig.BufferConfig.WindowOutputSize } + // Create a cancellable context + ctx, cancel := context.WithCancel(context.Background()) return &SlidingWindow{ config: config, size: size, @@ -195,29 +206,24 @@ func (sw *SlidingWindow) Stop() { func (sw *SlidingWindow) Trigger() { // Lock to ensure thread safety sw.mu.Lock() - defer sw.mu.Unlock() // Return directly if no data in window if len(sw.data) == 0 { + sw.mu.Unlock() return } if !sw.initialized { + sw.mu.Unlock() return } - // Calculate cutoff time (current time minus window size) + // Calculate next slot for sliding window next := sw.NextSlot() - // Retain data for next window - tms := next.Start.Add(-sw.size) - tme := next.End.Add(sw.size) - temp := types.NewTimeSlot(&tms, &tme) - newData := make([]types.Row, 0) - for _, item := range sw.data { - if temp.Contains(item.Timestamp) { - newData = append(newData, item) - } + if next == nil { + sw.mu.Unlock() + return } - // Extract Data fields to form []interface{} type data + // Extract Data fields to form []interface{} type data for current window resultData := make([]types.Row, 0) for _, item := range sw.data { if sw.currentSlot.Contains(item.Timestamp) { @@ -226,24 +232,55 @@ func (sw *SlidingWindow) Trigger() { } } - // Execute callback function if set - if sw.callback != nil { - sw.callback(resultData) + // Retain data that could be in future windows + // For sliding windows, we need to keep data that falls within: + // - Current window end + size (for overlapping windows) + // - Next window end + size (for future windows) + // Actually, we should keep all data that could be in any future window + // The latest window that could contain a data point is: next.End + size + cutoffTime := next.End.Add(sw.size) + newData := make([]types.Row, 0) + for _, item := range sw.data { + // Keep data that could be in future windows (before cutoffTime) + if item.Timestamp.Before(cutoffTime) { + newData = append(newData, item) + } } // Update window data sw.data = newData sw.currentSlot = next - // Non-blocking send to output channel and update statistics (within lock) + // Get callback reference before releasing lock + callback := sw.callback + + // Release lock before calling callback and sending to channel to avoid blocking + sw.mu.Unlock() + + // Execute callback function if set (outside of lock to avoid blocking) + if callback != nil { + callback(resultData) + } + + // Non-blocking send to output channel and update statistics + var sent bool select { case sw.outputChan <- resultData: - // Successfully sent, update statistics (within lock) - sw.sentCount++ + // Successfully sent + sent = true default: - // Channel full, drop result and update statistics (within lock) + // Channel full, drop result + sent = false + } + + // Re-acquire lock to update statistics + sw.mu.Lock() + if sent { + sw.sentCount++ + } else { sw.droppedCount++ } + sw.mu.Unlock() } // GetStats returns window performance statistics @@ -252,10 +289,10 @@ func (sw *SlidingWindow) GetStats() map[string]int64 { defer sw.mu.RUnlock() return map[string]int64{ - "sent_count": sw.sentCount, - "dropped_count": sw.droppedCount, - "buffer_size": int64(cap(sw.outputChan)), - "buffer_used": int64(len(sw.outputChan)), + "sentCount": sw.sentCount, + "droppedCount": sw.droppedCount, + "bufferSize": int64(cap(sw.outputChan)), + "bufferUsed": int64(len(sw.outputChan)), } } diff --git a/window/sliding_window_test.go b/window/sliding_window_test.go index 05abb36..9b203e4 100644 --- a/window/sliding_window_test.go +++ b/window/sliding_window_test.go @@ -22,10 +22,7 @@ func TestSlidingWindow(t *testing.T) { defer cancel() sw, _ := NewSlidingWindow(types.WindowConfig{ - Params: map[string]interface{}{ - "size": "2s", - "slide": "1s", - }, + Params: []interface{}{2 * time.Second, time.Second}, TsProp: "Ts", TimeUnit: time.Second, }) diff --git a/window/tumbling_window.go b/window/tumbling_window.go index 2b86560..6536221 100644 --- a/window/tumbling_window.go +++ b/window/tumbling_window.go @@ -66,17 +66,22 @@ type TumblingWindow struct { func NewTumblingWindow(config types.WindowConfig) (*TumblingWindow, error) { // Create a cancellable context ctx, cancel := context.WithCancel(context.Background()) - size, err := cast.ToDurationE(config.Params["size"]) + + // Get size parameter from params array + if len(config.Params) == 0 { + return nil, fmt.Errorf("tumbling window requires 'size' parameter") + } + + sizeVal := config.Params[0] + size, err := cast.ToDurationE(sizeVal) if err != nil { return nil, fmt.Errorf("invalid size for tumbling window: %v", err) } // Use unified performance config to get window output buffer size bufferSize := 1000 // Default value - if perfConfig, exists := config.Params["performanceConfig"]; exists { - if pc, ok := perfConfig.(types.PerformanceConfig); ok { - bufferSize = pc.BufferConfig.WindowOutputSize - } + if (config.PerformanceConfig != types.PerformanceConfig{}) { + bufferSize = config.PerformanceConfig.BufferConfig.WindowOutputSize } return &TumblingWindow{ @@ -196,9 +201,9 @@ func (tw *TumblingWindow) Start() { func (tw *TumblingWindow) Trigger() { // Lock to ensure thread safety tw.mu.Lock() - defer tw.mu.Unlock() if !tw.initialized { + tw.mu.Unlock() return } // Calculate next window slot @@ -223,26 +228,41 @@ func (tw *TumblingWindow) Trigger() { } } - // Execute callback function if set - if tw.callback != nil { - tw.callback(resultData) - } - // Update window data tw.data = newData tw.currentSlot = next + // Get callback reference before releasing lock + callback := tw.callback + + // Release lock before calling callback and sending to channel to avoid blocking + tw.mu.Unlock() + + if callback != nil { + callback(resultData) + } + // Non-blocking send to output channel and update statistics + var sent bool select { case tw.outputChan <- resultData: - // Successfully sent, update statistics (within lock) - tw.sentCount++ + // Successfully sent + sent = true default: - // Channel full, drop result and update statistics (within lock) - tw.droppedCount++ + // Channel full, drop result + sent = false + } + // Re-acquire lock to update statistics + tw.mu.Lock() + if sent { + tw.sentCount++ + } else { + tw.droppedCount++ // Optional: add logging here - } // log.Printf("Window output channel full, dropped result with %d rows", len(resultData)) + // log.Printf("Window output channel full, dropped result with %d rows", len(resultData)) + } + tw.mu.Unlock() } // Reset resets tumbling window data @@ -290,10 +310,10 @@ func (tw *TumblingWindow) GetStats() map[string]int64 { defer tw.mu.RUnlock() return map[string]int64{ - "sent_count": tw.sentCount, - "dropped_count": tw.droppedCount, - "buffer_size": int64(cap(tw.outputChan)), - "buffer_used": int64(len(tw.outputChan)), + "sentCount": tw.sentCount, + "droppedCount": tw.droppedCount, + "bufferSize": int64(cap(tw.outputChan)), + "bufferUsed": int64(len(tw.outputChan)), } } diff --git a/window/tumbling_window_test.go b/window/tumbling_window_test.go index bf097ac..757e9e2 100644 --- a/window/tumbling_window_test.go +++ b/window/tumbling_window_test.go @@ -16,7 +16,7 @@ func TestTumblingWindow(t *testing.T) { tw, _ := NewTumblingWindow(types.WindowConfig{ Type: "TumblingWindow", - Params: map[string]interface{}{"size": "2s"}, + Params: []interface{}{2 * time.Second}, TsProp: "Ts", }) tw.SetCallback(func(results []types.Row) { diff --git a/window/window_test.go b/window/window_test.go index 32759b9..68c98a8 100644 --- a/window/window_test.go +++ b/window/window_test.go @@ -2,6 +2,7 @@ package window import ( "reflect" + "strconv" "sync" "testing" "time" @@ -23,9 +24,7 @@ func getTypeString(obj interface{}) string { func TestWindowEdgeCases(t *testing.T) { t.Run("tumbling window with zero duration", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": time.Duration(0), - }, + Params: []interface{}{time.Duration(0)}, } _, err := NewTumblingWindow(config) // 零持续时间可能是有效的,取决于实现 @@ -34,9 +33,7 @@ func TestWindowEdgeCases(t *testing.T) { t.Run("tumbling window with negative duration", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": -time.Second, - }, + Params: []interface{}{-time.Second}, } _, err := NewTumblingWindow(config) // 负持续时间可能是有效的,取决于实现 @@ -45,10 +42,7 @@ func TestWindowEdgeCases(t *testing.T) { t.Run("sliding window with zero window size", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": time.Duration(0), - "slide": time.Second, - }, + Params: []interface{}{time.Duration(0), time.Second}, } _, err := NewSlidingWindow(config) // 零滑动间隔可能是有效的,取决于实现 @@ -57,10 +51,7 @@ func TestWindowEdgeCases(t *testing.T) { t.Run("sliding window with zero slide interval", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": time.Minute, - "slide": time.Duration(0), - }, + Params: []interface{}{time.Minute, time.Duration(0)}, } _, err := NewSlidingWindow(config) // 零滑动间隔可能是有效的,取决于实现 @@ -70,10 +61,7 @@ func TestWindowEdgeCases(t *testing.T) { t.Run("sliding window with slide larger than window", func(t *testing.T) { // 这种情况可能是有效的,取决于实现 config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": time.Second, - "slide": time.Minute, - }, + Params: []interface{}{time.Second, time.Minute}, } window, err := NewSlidingWindow(config) _ = window @@ -82,9 +70,7 @@ func TestWindowEdgeCases(t *testing.T) { t.Run("counting window with zero count", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "count": 0, - }, + Params: []interface{}{0}, } _, err := NewCountingWindow(config) require.NotNil(t, err) @@ -92,9 +78,7 @@ func TestWindowEdgeCases(t *testing.T) { t.Run("counting window with negative count", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "count": -10, - }, + Params: []interface{}{-10}, } _, err := NewCountingWindow(config) require.NotNil(t, err) @@ -102,9 +86,7 @@ func TestWindowEdgeCases(t *testing.T) { t.Run("session window with zero timeout", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "timeout": time.Duration(0), - }, + Params: []interface{}{time.Duration(0)}, } _, err := NewSessionWindow(config) // 零超时可能是有效的,取决于实现 @@ -113,9 +95,7 @@ func TestWindowEdgeCases(t *testing.T) { t.Run("session window with negative timeout", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "timeout": -time.Second, - }, + Params: []interface{}{-time.Second}, } _, err := NewSessionWindow(config) // 负超时可能是有效的,取决于实现 @@ -127,9 +107,7 @@ func TestWindowEdgeCases(t *testing.T) { func TestWindowWithNilCallback(t *testing.T) { t.Run("tumbling window with nil callback", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": time.Second, - }, + Params: []interface{}{time.Second}, } window, err := NewTumblingWindow(config) if err == nil { @@ -147,10 +125,7 @@ func TestWindowWithNilCallback(t *testing.T) { t.Run("sliding window with nil callback", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": time.Minute, - "slide": time.Second, - }, + Params: []interface{}{time.Minute, time.Second}, } window, err := NewSlidingWindow(config) if err == nil { @@ -167,9 +142,7 @@ func TestWindowWithNilCallback(t *testing.T) { t.Run("counting window with nil callback", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "count": 10, - }, + Params: []interface{}{10}, } window, err := NewCountingWindow(config) if err == nil { @@ -186,9 +159,7 @@ func TestWindowWithNilCallback(t *testing.T) { t.Run("session window with nil callback", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "timeout": time.Minute, - }, + Params: []interface{}{time.Minute}, } window, err := NewSessionWindow(config) if err == nil { @@ -217,9 +188,7 @@ func TestWindowConcurrency(t *testing.T) { } config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": time.Millisecond * 100, - }, + Params: []interface{}{time.Millisecond * 100}, } window, err := NewTumblingWindow(config) if err == nil { @@ -258,9 +227,7 @@ func TestWindowConcurrency(t *testing.T) { t.Run("concurrent start stop", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": time.Second, - }, + Params: []interface{}{time.Second}, } window, err := NewTumblingWindow(config) if err == nil { @@ -286,9 +253,7 @@ func TestWindowConcurrency(t *testing.T) { t.Run("concurrent add and stop", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": time.Second, - }, + Params: []interface{}{time.Second}, } window, err := NewTumblingWindow(config) if err == nil { @@ -335,9 +300,7 @@ func TestWindowMemoryManagement(t *testing.T) { } config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": time.Millisecond * 50, - }, + Params: []interface{}{time.Millisecond * 50}, } window, err := NewTumblingWindow(config) if err == nil { @@ -379,9 +342,7 @@ func TestWindowMemoryManagement(t *testing.T) { } config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": time.Millisecond * 10, - }, + Params: []interface{}{time.Millisecond * 10}, } window, err := NewTumblingWindow(config) if err == nil { @@ -409,9 +370,7 @@ func TestWindowMemoryManagement(t *testing.T) { func TestWindowErrorConditions(t *testing.T) { t.Run("add to stopped window", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": time.Second, - }, + Params: []interface{}{time.Second}, } window, err := NewTumblingWindow(config) if err == nil { @@ -431,9 +390,7 @@ func TestWindowErrorConditions(t *testing.T) { t.Run("add invalid data types", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": time.Second, - }, + Params: []interface{}{time.Second}, } window, err := NewTumblingWindow(config) if err == nil { @@ -457,9 +414,7 @@ func TestWindowErrorConditions(t *testing.T) { t.Run("add row with zero timestamp", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": time.Second, - }, + Params: []interface{}{time.Second}, } window, err := NewTumblingWindow(config) if err == nil { @@ -479,9 +434,7 @@ func TestWindowErrorConditions(t *testing.T) { t.Run("add row with future timestamp", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": time.Second, - }, + Params: []interface{}{time.Second}, } window, err := NewTumblingWindow(config) if err == nil { @@ -501,9 +454,7 @@ func TestWindowErrorConditions(t *testing.T) { t.Run("add row with very old timestamp", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": time.Second, - }, + Params: []interface{}{time.Second}, } window, err := NewTumblingWindow(config) if err == nil { @@ -526,9 +477,7 @@ func TestWindowErrorConditions(t *testing.T) { func TestWindowStatsAndMetrics(t *testing.T) { t.Run("get stats from tumbling window", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": time.Second, - }, + Params: []interface{}{time.Second}, } window, err := NewTumblingWindow(config) if err == nil { @@ -543,9 +492,7 @@ func TestWindowStatsAndMetrics(t *testing.T) { t.Run("reset stats", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": time.Second, - }, + Params: []interface{}{time.Second}, } window, err := NewTumblingWindow(config) if err == nil { @@ -568,9 +515,7 @@ func TestWindowStatsAndMetrics(t *testing.T) { t.Run("get output channel", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": time.Second, - }, + Params: []interface{}{time.Second}, } window, err := NewTumblingWindow(config) if err == nil { @@ -585,9 +530,7 @@ func TestWindowStatsAndMetrics(t *testing.T) { t.Run("set callback", func(t *testing.T) { config := types.WindowConfig{ - Params: map[string]interface{}{ - "size": time.Second, - }, + Params: []interface{}{time.Second}, } window, err := NewTumblingWindow(config) if err == nil { @@ -655,16 +598,48 @@ func TestWindowWithPerformanceConfig(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - config := types.WindowConfig{ - Type: tt.windowType, - Params: make(map[string]interface{}), + // Convert extraParams to array format + var params []interface{} + if tt.windowType == TypeCounting { + if count, ok := tt.extraParams["count"].(int); ok { + params = []interface{}{count} + } else if countStr, ok := tt.extraParams["count"].(string); ok { + if count, err := strconv.Atoi(countStr); err == nil { + params = []interface{}{count} + } + } + } else if tt.windowType == TypeSession { + if timeout, ok := tt.extraParams["timeout"].(string); ok { + if dur, err := time.ParseDuration(timeout); err == nil { + params = []interface{}{dur} + } + } + } else if tt.windowType == TypeSliding { + var size, slide time.Duration + if sizeStr, ok := tt.extraParams["size"].(string); ok { + if dur, err := time.ParseDuration(sizeStr); err == nil { + size = dur + } + } + if slideStr, ok := tt.extraParams["slide"].(string); ok { + if dur, err := time.ParseDuration(slideStr); err == nil { + slide = dur + } + } + params = []interface{}{size, slide} + } else { + if sizeStr, ok := tt.extraParams["size"].(string); ok { + if dur, err := time.ParseDuration(sizeStr); err == nil { + params = []interface{}{dur} + } + } } - - // 合并参数 - for k, v := range tt.extraParams { - config.Params[k] = v + + config := types.WindowConfig{ + Type: tt.windowType, + Params: params, + PerformanceConfig: tt.performanceConfig, } - config.Params["performanceConfig"] = tt.performanceConfig var window Window var err error @@ -692,9 +667,7 @@ func TestWindowWithPerformanceConfig(t *testing.T) { t.Run("无性能配置-使用默认值", func(t *testing.T) { config := types.WindowConfig{ Type: TypeTumbling, - Params: map[string]interface{}{ - "size": "3s", - }, + Params: []interface{}{3 * time.Second}, } tw, err := NewTumblingWindow(config) @@ -780,10 +753,8 @@ func TestGetTimestampEdgeCases(t *testing.T) { func TestSessionWindowSessionKey(t *testing.T) { config := types.WindowConfig{ Type: TypeSession, - Params: map[string]interface{}{ - "timeout": "5s", - }, - GroupByKey: "user_id", + Params: []interface{}{5 * time.Second}, + GroupByKeys: []string{"user_id"}, } sw, err := NewSessionWindow(config) @@ -834,31 +805,28 @@ func TestWindowStopBeforeStart(t *testing.T) { name: "滚动窗口", config: types.WindowConfig{ Type: TypeTumbling, - Params: map[string]interface{}{"size": "1s"}, + Params: []interface{}{time.Second}, }, }, { name: "滑动窗口", config: types.WindowConfig{ Type: TypeSliding, - Params: map[string]interface{}{ - "size": "2s", - "slide": "1s", - }, + Params: []interface{}{2 * time.Second, time.Second}, }, }, { name: "计数窗口", config: types.WindowConfig{ Type: TypeCounting, - Params: map[string]interface{}{"count": 10}, + Params: []interface{}{10}, }, }, { name: "会话窗口", config: types.WindowConfig{ Type: TypeSession, - Params: map[string]interface{}{"timeout": "5s"}, + Params: []interface{}{5 * time.Second}, }, }, } @@ -889,7 +857,7 @@ func TestWindowStopBeforeStart(t *testing.T) { func TestWindowMultipleStops(t *testing.T) { config := types.WindowConfig{ Type: TypeTumbling, - Params: map[string]interface{}{"size": "1s"}, + Params: []interface{}{time.Second}, } tw, err := NewTumblingWindow(config) @@ -909,7 +877,7 @@ func TestWindowMultipleStops(t *testing.T) { func TestWindowAddAfterStop(t *testing.T) { config := types.WindowConfig{ Type: TypeTumbling, - Params: map[string]interface{}{"size": "1s"}, + Params: []interface{}{time.Second}, } tw, err := NewTumblingWindow(config) @@ -935,11 +903,9 @@ func TestCountingWindowWithCallback(t *testing.T) { } config := types.WindowConfig{ - Type: TypeCounting, - Params: map[string]interface{}{ - "count": 2, - "callback": callback, - }, + Type: TypeCounting, + Params: []interface{}{2}, + Callback: callback, } cw, err := NewCountingWindow(config) @@ -967,20 +933,15 @@ func TestCountingWindowWithCallback(t *testing.T) { func TestSlidingWindowInvalidParams(t *testing.T) { tests := []struct { name string - params map[string]interface{} + params []interface{} }{ { - name: "无效的slide参数", - params: map[string]interface{}{ - "size": "10s", - "slide": "invalid", - }, + name: "无效的slide参数", + params: []interface{}{10 * time.Second, "invalid"}, }, { - name: "缺少slide参数", - params: map[string]interface{}{ - "size": "10s", - }, + name: "缺少slide参数", + params: []interface{}{10 * time.Second}, }, } @@ -1002,11 +963,9 @@ func TestWindowUnifiedConfigIntegration(t *testing.T) { performanceConfig := types.HighPerformanceConfig() windowConfig := types.WindowConfig{ - Type: TypeTumbling, - Params: map[string]interface{}{ - "size": "1s", - "performanceConfig": performanceConfig, - }, + Type: TypeTumbling, + Params: []interface{}{time.Second}, + PerformanceConfig: performanceConfig, } tw, err := NewTumblingWindow(windowConfig) @@ -1049,11 +1008,9 @@ func TestWindowUnifiedConfigIntegration(t *testing.T) { } config := types.WindowConfig{ - Type: TypeTumbling, - Params: map[string]interface{}{ - "size": "100ms", - "performanceConfig": smallBufferConfig, - }, + Type: TypeTumbling, + Params: []interface{}{100 * time.Millisecond}, + PerformanceConfig: smallBufferConfig, } tw, err := NewTumblingWindow(config) @@ -1072,8 +1029,8 @@ func TestWindowUnifiedConfigIntegration(t *testing.T) { // 检查统计信息 stats := tw.GetStats() - assert.Contains(t, stats, "dropped_count") - assert.Contains(t, stats, "sent_count") + assert.Contains(t, stats, "droppedCount") + assert.Contains(t, stats, "sentCount") }) } @@ -1088,10 +1045,8 @@ func TestCreateWindow(t *testing.T) { { name: "创建滚动窗口", config: types.WindowConfig{ - Type: TypeTumbling, - Params: map[string]interface{}{ - "size": "5s", - }, + Type: TypeTumbling, + Params: []interface{}{5 * time.Second}, }, expectError: false, expectedType: "*window.TumblingWindow", @@ -1099,11 +1054,8 @@ func TestCreateWindow(t *testing.T) { { name: "创建滑动窗口", config: types.WindowConfig{ - Type: TypeSliding, - Params: map[string]interface{}{ - "size": "10s", - "slide": "5s", - }, + Type: TypeSliding, + Params: []interface{}{10 * time.Second, 5 * time.Second}, }, expectError: false, expectedType: "*window.SlidingWindow", @@ -1111,10 +1063,8 @@ func TestCreateWindow(t *testing.T) { { name: "创建计数窗口", config: types.WindowConfig{ - Type: TypeCounting, - Params: map[string]interface{}{ - "count": 100, - }, + Type: TypeCounting, + Params: []interface{}{100}, }, expectError: false, expectedType: "*window.CountingWindow", @@ -1122,10 +1072,8 @@ func TestCreateWindow(t *testing.T) { { name: "创建会话窗口", config: types.WindowConfig{ - Type: TypeSession, - Params: map[string]interface{}{ - "timeout": "30s", - }, + Type: TypeSession, + Params: []interface{}{30 * time.Second}, }, expectError: false, expectedType: "*window.SessionWindow", @@ -1134,12 +1082,10 @@ func TestCreateWindow(t *testing.T) { name: "窗口工厂与统一配置集成", config: types.WindowConfig{ Type: TypeTumbling, - Params: map[string]interface{}{ - "size": "5s", - "performanceConfig": types.PerformanceConfig{ + Params: []interface{}{5 * time.Second}, + PerformanceConfig: types.PerformanceConfig{ BufferConfig: types.BufferConfig{ WindowOutputSize: 1500, - }, }, }, }, @@ -1149,10 +1095,8 @@ func TestCreateWindow(t *testing.T) { { name: "无效的窗口类型", config: types.WindowConfig{ - Type: "invalid", - Params: map[string]interface{}{ - "size": "5s", - }, + Type: "invalid", + Params: []interface{}{5 * time.Second}, }, expectError: true, expectedType: "", @@ -1253,10 +1197,8 @@ func TestGetTimestampCoverage(t *testing.T) { func TestWindowErrorHandling(t *testing.T) { t.Run("滚动窗口无效大小", func(t *testing.T) { config := types.WindowConfig{ - Type: TypeTumbling, - Params: map[string]interface{}{ - "size": "invalid", - }, + Type: TypeTumbling, + Params: []interface{}{"invalid"}, } _, err := NewTumblingWindow(config) assert.Error(t, err) @@ -1264,11 +1206,8 @@ func TestWindowErrorHandling(t *testing.T) { t.Run("滑动窗口无效参数", func(t *testing.T) { config := types.WindowConfig{ - Type: TypeSliding, - Params: map[string]interface{}{ - "size": "invalid", - "slide": "5s", - }, + Type: TypeSliding, + Params: []interface{}{"invalid", 5 * time.Second}, } _, err := NewSlidingWindow(config) assert.Error(t, err) @@ -1276,10 +1215,8 @@ func TestWindowErrorHandling(t *testing.T) { t.Run("计数窗口无效计数", func(t *testing.T) { config := types.WindowConfig{ - Type: TypeCounting, - Params: map[string]interface{}{ - "count": 0, - }, + Type: TypeCounting, + Params: []interface{}{0}, } _, err := NewCountingWindow(config) assert.Error(t, err) @@ -1287,10 +1224,8 @@ func TestWindowErrorHandling(t *testing.T) { t.Run("会话窗口无效超时", func(t *testing.T) { config := types.WindowConfig{ - Type: TypeSession, - Params: map[string]interface{}{ - "timeout": "invalid", - }, + Type: TypeSession, + Params: []interface{}{"invalid"}, } _, err := NewSessionWindow(config) assert.Error(t, err) @@ -1301,10 +1236,8 @@ func TestWindowErrorHandling(t *testing.T) { func TestSessionWindowAdvanced(t *testing.T) { config := types.WindowConfig{ Type: TypeSession, - Params: map[string]interface{}{ - "timeout": "1s", - }, - GroupByKey: "user_id", + Params: []interface{}{time.Second}, + GroupByKeys: []string{"user_id"}, } sw, err := NewSessionWindow(config) @@ -1352,11 +1285,8 @@ func TestSessionWindowAdvanced(t *testing.T) { // TestSlidingWindowAdvanced 测试滑动窗口的高级功能 func TestSlidingWindowAdvanced(t *testing.T) { config := types.WindowConfig{ - Type: TypeSliding, - Params: map[string]interface{}{ - "size": "2s", - "slide": "1s", - }, + Type: TypeSliding, + Params: []interface{}{2 * time.Second, time.Second}, TsProp: "timestamp", TimeUnit: time.Second, } @@ -1379,10 +1309,8 @@ func TestSlidingWindowAdvanced(t *testing.T) { // TestCountingWindowAdvanced 测试计数窗口的高级功能 func TestCountingWindowAdvanced(t *testing.T) { config := types.WindowConfig{ - Type: TypeCounting, - Params: map[string]interface{}{ - "count": 3, - }, + Type: TypeCounting, + Params: []interface{}{3}, TsProp: "timestamp", TimeUnit: time.Second, } @@ -1430,9 +1358,7 @@ func TestCountingWindowAdvanced(t *testing.T) { func TestTumblingWindowAdvanced(t *testing.T) { config := types.WindowConfig{ Type: TypeTumbling, - Params: map[string]interface{}{ - "size": "1s", - }, + Params: []interface{}{time.Second}, TsProp: "timestamp", TimeUnit: time.Second, } @@ -1443,8 +1369,8 @@ func TestTumblingWindowAdvanced(t *testing.T) { // 检查统计信息 stats := tw.GetStats() - assert.Contains(t, stats, "sent_count") - assert.Contains(t, stats, "dropped_count") + assert.Contains(t, stats, "sentCount") + assert.Contains(t, stats, "droppedCount") // 测试重置统计信息 tw.ResetStats()