Skip to content
Merged

Dev #40

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ name: CI

on:
push:
branches: [ main, master, develop ]
branches: [ main, dev ]
pull_request:
branches: [ main, master, develop ]
branches: [ main, dev ]

jobs:
test:
Expand Down
183 changes: 121 additions & 62 deletions rsql/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/rulego/streamsql/functions"
"github.com/rulego/streamsql/types"
"github.com/rulego/streamsql/utils/cast"
"github.com/rulego/streamsql/window"

"github.com/rulego/streamsql/aggregator"
Expand Down Expand Up @@ -58,9 +59,16 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) {
windowType = window.TypeSession
}

params, err := parseWindowParamsWithType(s.Window.Params, windowType)
if err != nil {
return nil, "", fmt.Errorf("failed to parse window parameters: %w", err)
// Parse window parameters - now returns array directly
params := s.Window.Params

// Validate and convert parameters based on window type
if len(params) > 0 {
var err error
params, err = validateWindowParams(params, windowType)
if err != nil {
return nil, "", fmt.Errorf("failed to validate window parameters: %w", err)
}
}

// Check if window processing is needed
Expand All @@ -80,16 +88,7 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) {
if !needWindow && hasAggregation {
needWindow = true
windowType = window.TypeTumbling
params = map[string]interface{}{
"size": 10 * time.Second, // Default 10-second window
}
}

// Handle special configuration for SessionWindow
var groupByKey string
if windowType == window.TypeSession && len(s.GroupBy) > 0 {
// For session window, use the first GROUP BY field as session key
groupByKey = s.GroupBy[0]
params = []interface{}{10 * time.Second} // Default 10-second window
}

// If no aggregation functions, collect simple fields
Expand All @@ -105,10 +104,10 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) {
simpleFields = append(simpleFields, fieldName+":"+field.Alias)
} else {
// For fields without alias, check if it's a string literal
_, n, _, _, err := ParseAggregateTypeWithExpression(fieldName)
if err != nil {
return nil, "", err
}
_, n, _, _, err := ParseAggregateTypeWithExpression(fieldName)
if err != nil {
return nil, "", err
}
if n != "" {
// If string literal, use parsed field name (remove quotes)
simpleFields = append(simpleFields, n)
Expand Down Expand Up @@ -137,11 +136,11 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) {
// Build Stream configuration
config := types.Config{
WindowConfig: types.WindowConfig{
Type: windowType,
Params: params,
TsProp: s.Window.TsProp,
TimeUnit: s.Window.TimeUnit,
GroupByKey: groupByKey,
Type: windowType,
Params: params,
TsProp: s.Window.TsProp,
TimeUnit: s.Window.TimeUnit,
GroupByKeys: extractGroupFields(s),
},
GroupFields: extractGroupFields(s),
SelectFields: aggs,
Expand Down Expand Up @@ -245,9 +244,9 @@ func buildSelectFields(fields []Field) (aggMap map[string]aggregator.AggregateTy
for _, f := range fields {
if alias := f.Alias; alias != "" {
t, n, _, _, parseErr := ParseAggregateTypeWithExpression(f.Expression)
if parseErr != nil {
return nil, nil, parseErr
}
if parseErr != nil {
return nil, nil, parseErr
}
if t != "" {
// Use alias as key for aggregator, not field name
selectFields[alias] = t
Expand Down Expand Up @@ -287,11 +286,11 @@ func detectNestedAggregationRecursive(expr string, inAggregation bool) error {
// 使用正则表达式匹配函数调用模式
pattern := regexp.MustCompile(`(?i)([a-z_]+)\s*\(`)
matches := pattern.FindAllStringSubmatchIndex(expr, -1)

for _, match := range matches {
funcStart := match[0]
funcName := strings.ToLower(expr[match[2]:match[3]])

// 检查函数是否为聚合函数
if fn, exists := functions.Get(funcName); exists {
switch fn.GetType() {
Expand All @@ -300,14 +299,14 @@ func detectNestedAggregationRecursive(expr string, inAggregation bool) error {
if inAggregation {
return fmt.Errorf("aggregate function calls cannot be nested")
}

// 找到该函数的参数部分
funcEnd := findMatchingParenInternal(expr, funcStart+len(funcName))
if funcEnd > funcStart {
// 提取函数参数
paramStart := funcStart + len(funcName) + 1
params := expr[paramStart:funcEnd]

// 在聚合函数参数内部递归检查
if err := detectNestedAggregationRecursive(params, true); err != nil {
return err
Expand All @@ -316,7 +315,7 @@ func detectNestedAggregationRecursive(expr string, inAggregation bool) error {
}
}
}

return nil
}

Expand Down Expand Up @@ -697,43 +696,103 @@ func extractSimpleField(fieldExpr string) string {
return fieldExpr
}

func parseWindowParams(params []interface{}) (map[string]interface{}, error) {
return parseWindowParamsWithType(params, "")
}
// validateWindowParams validates and converts window parameters based on window type
// Returns validated parameters array with proper types
func validateWindowParams(params []interface{}, windowType string) ([]interface{}, error) {
if len(params) == 0 {
return params, nil
}

func parseWindowParamsWithType(params []interface{}, windowType string) (map[string]interface{}, error) {
result := make(map[string]interface{})
var key string
validated := make([]interface{}, 0, len(params))

if windowType == window.TypeCounting {
// CountingWindow expects integer count as first parameter
if len(params) == 0 {
return nil, fmt.Errorf("counting window requires at least one parameter")
}

// Convert first parameter to int using cast utility
count, err := cast.ToIntE(params[0])
if err != nil {
return nil, fmt.Errorf("invalid count parameter: %w", err)
}

if count <= 0 {
return nil, fmt.Errorf("counting window count must be positive, got: %d", count)
}

validated = append(validated, count)

// Add any additional parameters
if len(params) > 1 {
validated = append(validated, params[1:]...)
}

return validated, nil
}

// Helper function to convert a value to time.Duration
// For numeric types, treats them as seconds
// For strings, uses time.ParseDuration
convertToDuration := func(val interface{}) (time.Duration, error) {
switch v := val.(type) {
case time.Duration:
return v, nil
case string:
// Use ToDurationE which handles string parsing
return cast.ToDurationE(v)
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
// Treat numeric integers as seconds
return time.Duration(cast.ToInt(v)) * time.Second, nil
case float32, float64:
// Treat numeric floats as seconds
return time.Duration(int(cast.ToFloat64(v))) * time.Second, nil
default:
// Try ToDurationE as fallback
return cast.ToDurationE(v)
}
}

if windowType == window.TypeSession {
// SessionWindow expects timeout duration as first parameter
if len(params) == 0 {
return nil, fmt.Errorf("session window requires at least one parameter")
}

timeout, err := convertToDuration(params[0])
if err != nil {
return nil, fmt.Errorf("invalid timeout duration: %w", err)
}

if timeout <= 0 {
return nil, fmt.Errorf("session window timeout must be positive, got: %v", timeout)
}

validated = append(validated, timeout)

// Add any additional parameters
if len(params) > 1 {
validated = append(validated, params[1:]...)
}

return validated, nil
}

// For TumblingWindow and SlidingWindow, convert parameters to time.Duration
for index, v := range params {
if windowType == window.TypeSession {
// First parameter for SessionWindow is timeout
if index == 0 {
key = "timeout"
} else {
key = fmt.Sprintf("param%d", index)
}
} else {
// Parameters for other window types
if index == 0 {
key = "size"
} else if index == 1 {
key = "slide"
} else {
key = "offset"
}
dur, err := convertToDuration(v)
if err != nil {
return nil, fmt.Errorf("invalid duration parameter at index %d: %w", index, err)
}
if s, ok := v.(string); ok {
dur, err := time.ParseDuration(s)
if err != nil {
return nil, fmt.Errorf("invalid %s duration: %w", s, err)
}
result[key] = dur
} else {
return nil, fmt.Errorf("%s parameter must be string format (like '5s')", s)

if dur <= 0 {
return nil, fmt.Errorf("duration parameter at index %d must be positive, got: %v", index, dur)
}

validated = append(validated, dur)
}

return result, nil
return validated, nil
}

func parseAggregateExpression(expr string) string {
Expand Down Expand Up @@ -958,7 +1017,7 @@ func parseComplexAggExpressionInternal(expr string) ([]types.AggregationFieldInf
if err := detectNestedAggregation(expr); err != nil {
return nil, "", err
}

// 使用改进的递归解析方法
aggFields, exprTemplate := parseNestedFunctionsInternal(expr, make([]types.AggregationFieldInfo, 0))
return aggFields, exprTemplate, nil
Expand Down
4 changes: 2 additions & 2 deletions rsql/ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,8 @@ func TestSelectStatementEdgeCases(t *testing.T) {
if config2.WindowConfig.Type != window.TypeSession {
t.Errorf("Expected session window, got %v", config2.WindowConfig.Type)
}
if config2.WindowConfig.GroupByKey != "user_id" {
t.Errorf("Expected GroupByKey to be 'user_id', got %s", config2.WindowConfig.GroupByKey)
if len(config2.WindowConfig.GroupByKeys) == 0 || config2.WindowConfig.GroupByKeys[0] != "user_id" {
t.Errorf("Expected GroupByKeys to contain 'user_id', got %v", config2.WindowConfig.GroupByKeys)
}
}

Expand Down
28 changes: 22 additions & 6 deletions rsql/coverage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/rulego/streamsql/aggregator"
"github.com/rulego/streamsql/types"
"github.com/rulego/streamsql/window"
)

// TestParseSmartParameters 测试智能参数解析函数
Expand Down Expand Up @@ -202,6 +203,12 @@ func TestParseWindowParams(t *testing.T) {
windowType: "SLIDINGWINDOW",
expectError: false,
},
{
name: "计数窗口参数",
params: []interface{}{100},
windowType: "COUNTINGWINDOW",
expectError: false,
},
{
name: "无效持续时间",
params: []interface{}{"invalid"},
Expand All @@ -212,7 +219,7 @@ func TestParseWindowParams(t *testing.T) {
name: "非字符串参数",
params: []interface{}{123},
windowType: "TUMBLINGWINDOW",
expectError: true,
expectError: false, // 整数参数会被视为秒数,这是有效的
},
{
name: "空参数",
Expand All @@ -224,15 +231,24 @@ func TestParseWindowParams(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var result map[string]interface{}
var result []interface{}
var err error

if tt.windowType == "SESSIONWINDOW" {
result, err = parseWindowParamsWithType(tt.params, "SESSIONWINDOW")
} else {
result, err = parseWindowParams(tt.params)
// Convert window type to internal format
windowType := ""
switch tt.windowType {
case "SESSIONWINDOW":
windowType = window.TypeSession
case "TUMBLINGWINDOW":
windowType = window.TypeTumbling
case "SLIDINGWINDOW":
windowType = window.TypeSliding
case "COUNTINGWINDOW":
windowType = window.TypeCounting
}

result, err = validateWindowParams(tt.params, windowType)

if tt.expectError {
if err == nil {
t.Errorf("Expected error but got none")
Expand Down
Loading
Loading