Skip to content

Commit 1e1282a

Browse files
committed
feat:json_extract支持map 和 array入参
1 parent 52653b2 commit 1e1282a

4 files changed

Lines changed: 258 additions & 32 deletions

File tree

docs/FUNCTIONS_USAGE_GUIDE.md

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,8 +530,30 @@ JSON函数用于处理JSON数据。
530530
**描述**: 从JSON字符串解析值。
531531

532532
### JSON_EXTRACT - JSON提取函数
533-
**语法**: `json_extract(json_str, path)`
534-
**描述**: 从JSON字符串中提取指定路径的值。
533+
**语法**: `json_extract(json_source, path)`
534+
**描述**: 从JSON字符串、Map或Array中提取指定路径的值。支持嵌套对象和数组索引。
535+
536+
**参数**:
537+
- `json_source`: 输入数据,可以是JSON格式字符串,也可以是Map或Array类型对象
538+
- `path`: 提取路径,支持 `.` 访问字段,`[]` 访问数组索引或Map Key
539+
540+
**示例**:
541+
```sql
542+
-- 提取基本字段
543+
json_extract('{"name": "Alice"}', 'name') -- 返回 "Alice"
544+
json_extract('{"name": "Alice"}', '$.name') -- 返回 "Alice"
545+
546+
-- 提取嵌套字段
547+
json_extract('{"user": {"address": {"city": "New York"}}}', 'user.address.city') -- 返回 "New York"
548+
json_extract('{"user": {"address": {"city": "New York"}}}', '$.user.address.city') -- 返回 "New York"
549+
550+
-- 提取数组元素
551+
json_extract('[10, 20, 30]', '[1]') -- 返回 20
552+
json_extract('[10, 20, 30]', '$[1]') -- 返回 20
553+
554+
-- 复杂嵌套提取
555+
json_extract('{"users": [{"name": "Alice"}, {"name": "Bob"}]}', 'users[1].name') -- 返回 "Bob"
556+
```
535557

536558
### JSON_VALID - JSON验证函数
537559
**语法**: `json_valid(json_str)`

functions/functions_json.go

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"encoding/json"
55
"fmt"
66
"strings"
7+
8+
"github.com/rulego/streamsql/utils/fieldpath"
79
)
810

911
// ToJsonFunction converts value to JSON string
@@ -75,31 +77,51 @@ func (f *JsonExtractFunction) Validate(args []interface{}) error {
7577
}
7678

7779
func (f *JsonExtractFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
78-
jsonStr, ok := args[0].(string)
79-
if !ok {
80-
return nil, fmt.Errorf("json_extract requires string input")
80+
var data interface{}
81+
var err error
82+
83+
// Support string (JSON), map, and slice input
84+
switch v := args[0].(type) {
85+
case string:
86+
err = json.Unmarshal([]byte(v), &data)
87+
if err != nil {
88+
return nil, fmt.Errorf("failed to parse JSON: %v", err)
89+
}
90+
case map[string]interface{}:
91+
data = v
92+
case []interface{}:
93+
data = v
94+
default:
95+
return nil, fmt.Errorf("json_extract requires string, map, or array input")
8196
}
8297

8398
path, ok := args[1].(string)
8499
if !ok {
85100
return nil, fmt.Errorf("json_extract path must be string")
86101
}
87102

88-
var data interface{}
89-
err := json.Unmarshal([]byte(jsonStr), &data)
90-
if err != nil {
91-
return nil, fmt.Errorf("failed to parse JSON: %v", err)
103+
// Handle JSON Path format
104+
// If path starts with $, strip it
105+
if strings.HasPrefix(path, "$") {
106+
path = path[1:]
107+
}
108+
// If path starts with ., strip it (unless it's empty, though GetNestedField handles empty path)
109+
if strings.HasPrefix(path, ".") {
110+
path = path[1:]
92111
}
93112

94-
// Simple path extraction, supports $.field format
95-
if strings.HasPrefix(path, "$.") {
96-
field := path[2:]
97-
if dataMap, ok := data.(map[string]interface{}); ok {
98-
return dataMap[field], nil
99-
}
113+
// Use fieldpath utility to extract value
114+
val, found := fieldpath.GetNestedField(data, path)
115+
if !found {
116+
// Try to see if it's a simple key access that fieldpath might have missed or if data structure is simple
117+
// But fieldpath covers most cases. If not found, it means the path doesn't exist.
118+
// However, for compatibility with previous simple implementation:
119+
// Previous implementation returned nil if key not found (implicit in map lookup)
120+
// So returning nil, nil is correct behavior when field is missing
121+
return nil, nil
100122
}
101123

102-
return nil, fmt.Errorf("invalid JSON path or data structure")
124+
return val, nil
103125
}
104126

105127
// JsonValidFunction 验证JSON格式是否有效

functions/functions_json_test.go

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,13 @@ func TestJsonFunctions(t *testing.T) {
8989
name: "json_extract invalid path",
9090
funcName: "json_extract",
9191
args: []interface{}{`{"name":"test"}`, "invalid_path"},
92-
wantErr: true,
92+
expected: nil,
9393
},
9494
{
9595
name: "json_extract non-object",
9696
funcName: "json_extract",
9797
args: []interface{}{`[1,2,3]`, "$.name"},
98-
wantErr: true,
98+
expected: nil,
9999
},
100100
{
101101
name: "json_valid true",
@@ -205,6 +205,42 @@ func TestJsonFunctions(t *testing.T) {
205205
args: []interface{}{`"hello"`},
206206
wantErr: true,
207207
},
208+
{
209+
name: "json_extract map input",
210+
funcName: "json_extract",
211+
args: []interface{}{map[string]interface{}{"name": "test", "value": 123}, "$.name"},
212+
expected: "test",
213+
},
214+
{
215+
name: "json_extract map input number",
216+
funcName: "json_extract",
217+
args: []interface{}{map[string]interface{}{"name": "test", "value": 123}, "$.value"},
218+
expected: 123,
219+
},
220+
{
221+
name: "json_extract array input",
222+
funcName: "json_extract",
223+
args: []interface{}{[]interface{}{10, 20, 30}, "$[1]"},
224+
expected: 20,
225+
},
226+
{
227+
name: "json_extract nested map",
228+
funcName: "json_extract",
229+
args: []interface{}{`{"a": {"b": 100}}`, "$.a.b"},
230+
expected: float64(100),
231+
},
232+
{
233+
name: "json_extract nested array",
234+
funcName: "json_extract",
235+
args: []interface{}{`{"list": [1, 2, 3]}`, "$.list[2]"},
236+
expected: float64(3),
237+
},
238+
{
239+
name: "json_extract complex path",
240+
funcName: "json_extract",
241+
args: []interface{}{map[string]interface{}{"users": []interface{}{map[string]interface{}{"id": 1}, map[string]interface{}{"id": 2}}}, "$.users[1].id"},
242+
expected: 2,
243+
},
208244
}
209245

210246
for _, tt := range tests {
@@ -309,38 +345,38 @@ func TestJsonFunctionValidation(t *testing.T) {
309345
// TestJsonFunctionCreation 测试JSON函数创建
310346
func TestJsonFunctionCreation(t *testing.T) {
311347
tests := []struct {
312-
name string
313-
constructor func() Function
348+
name string
349+
constructor func() Function
314350
expectedName string
315351
}{
316352
{
317-
name: "ToJsonFunction",
318-
constructor: func() Function { return NewToJsonFunction() },
353+
name: "ToJsonFunction",
354+
constructor: func() Function { return NewToJsonFunction() },
319355
expectedName: "to_json",
320356
},
321357
{
322-
name: "FromJsonFunction",
323-
constructor: func() Function { return NewFromJsonFunction() },
358+
name: "FromJsonFunction",
359+
constructor: func() Function { return NewFromJsonFunction() },
324360
expectedName: "from_json",
325361
},
326362
{
327-
name: "JsonExtractFunction",
328-
constructor: func() Function { return NewJsonExtractFunction() },
363+
name: "JsonExtractFunction",
364+
constructor: func() Function { return NewJsonExtractFunction() },
329365
expectedName: "json_extract",
330366
},
331367
{
332-
name: "JsonValidFunction",
333-
constructor: func() Function { return NewJsonValidFunction() },
368+
name: "JsonValidFunction",
369+
constructor: func() Function { return NewJsonValidFunction() },
334370
expectedName: "json_valid",
335371
},
336372
{
337-
name: "JsonTypeFunction",
338-
constructor: func() Function { return NewJsonTypeFunction() },
373+
name: "JsonTypeFunction",
374+
constructor: func() Function { return NewJsonTypeFunction() },
339375
expectedName: "json_type",
340376
},
341377
{
342-
name: "JsonLengthFunction",
343-
constructor: func() Function { return NewJsonLengthFunction() },
378+
name: "JsonLengthFunction",
379+
constructor: func() Function { return NewJsonLengthFunction() },
344380
expectedName: "json_length",
345381
},
346382
}

streamsql_function_integration_test.go

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,152 @@ func TestFunctionIntegrationNonAggregation(t *testing.T) {
225225
t.Fatal("测试超时,未收到结果")
226226
}
227227
})
228+
229+
t.Run("JSONExtractMapSupport", func(t *testing.T) {
230+
streamsql := New()
231+
defer streamsql.Stop()
232+
233+
// Test json_extract with map input
234+
rsql := "SELECT device, json_extract(properties, '$.color') as device_color FROM stream"
235+
err := streamsql.Execute(rsql)
236+
assert.Nil(t, err)
237+
238+
strm := streamsql.stream
239+
resultChan := make(chan interface{}, 10)
240+
strm.AddSink(func(result []map[string]interface{}) {
241+
resultChan <- result
242+
})
243+
244+
// Add test data with map
245+
testData := map[string]interface{}{
246+
"device": "test-device-map",
247+
"properties": map[string]interface{}{
248+
"color": "red",
249+
"weight": 10,
250+
},
251+
}
252+
strm.Emit(testData)
253+
254+
// Wait for result
255+
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
256+
defer cancel()
257+
258+
select {
259+
case result := <-resultChan:
260+
resultSlice, ok := result.([]map[string]interface{})
261+
require.True(t, ok)
262+
require.Len(t, resultSlice, 1)
263+
264+
item := resultSlice[0]
265+
assert.Equal(t, "test-device-map", item["device"])
266+
assert.Equal(t, "red", item["device_color"])
267+
case <-ctx.Done():
268+
t.Fatal("测试超时,未收到结果")
269+
}
270+
})
271+
272+
t.Run("JSONExtractArrayAndNested", func(t *testing.T) {
273+
streamsql := New()
274+
defer streamsql.Stop()
275+
276+
// Test json_extract with array and nested structures
277+
rsql := "SELECT device, json_extract(tags, '$[0]') as first_tag, json_extract(data, '$.users[0].name') as first_user_name FROM stream"
278+
err := streamsql.Execute(rsql)
279+
assert.Nil(t, err)
280+
281+
strm := streamsql.stream
282+
resultChan := make(chan interface{}, 10)
283+
strm.AddSink(func(result []map[string]interface{}) {
284+
resultChan <- result
285+
})
286+
287+
// Add test data with complex structures
288+
testData := map[string]interface{}{
289+
"device": "complex-device",
290+
"tags": []interface{}{"tag1", "tag2"},
291+
"data": map[string]interface{}{
292+
"users": []interface{}{
293+
map[string]interface{}{"name": "Alice", "age": 30},
294+
map[string]interface{}{"name": "Bob", "age": 25},
295+
},
296+
},
297+
}
298+
strm.Emit(testData)
299+
300+
// Wait for result
301+
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
302+
defer cancel()
303+
304+
select {
305+
case result := <-resultChan:
306+
resultSlice, ok := result.([]map[string]interface{})
307+
require.True(t, ok)
308+
require.Len(t, resultSlice, 1)
309+
310+
item := resultSlice[0]
311+
assert.Equal(t, "complex-device", item["device"])
312+
assert.Equal(t, "tag1", item["first_tag"])
313+
assert.Equal(t, "Alice", item["first_user_name"])
314+
case <-ctx.Done():
315+
t.Fatal("测试超时,未收到结果")
316+
}
317+
})
318+
319+
t.Run("JSONExtractWithAggregation", func(t *testing.T) {
320+
streamsql := New()
321+
defer streamsql.Stop()
322+
323+
// Test json_extract nested in aggregation function
324+
// json_extract returns interface{}, usually need cast to number for aggregation like sum/avg
325+
// specific logic depends on whether aggregator handles string/interface conversion
326+
// Here we assume json_extract returns float64 for numbers (from Unmarshal) or use cast
327+
rsql := "SELECT count(json_extract(tags, '$[0]')) as tag_count, sum(cast(json_extract(data, '$.value'), 'float')) as total_value FROM stream GROUP BY device, TumblingWindow('1s')"
328+
err := streamsql.Execute(rsql)
329+
assert.Nil(t, err)
330+
331+
strm := streamsql.stream
332+
resultChan := make(chan interface{}, 10)
333+
strm.AddSink(func(result []map[string]interface{}) {
334+
resultChan <- result
335+
})
336+
337+
testData := []map[string]interface{}{
338+
{
339+
"device": "device1",
340+
"tags": []interface{}{"tag1", "tag2"},
341+
"data": map[string]interface{}{"value": 10},
342+
},
343+
{
344+
"device": "device1",
345+
"tags": []interface{}{"tag3"},
346+
"data": map[string]interface{}{"value": 20},
347+
},
348+
}
349+
350+
for _, data := range testData {
351+
strm.Emit(data)
352+
}
353+
354+
time.Sleep(1 * time.Second)
355+
strm.Window.Trigger()
356+
357+
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
358+
defer cancel()
359+
360+
select {
361+
case result := <-resultChan:
362+
resultSlice, ok := result.([]map[string]interface{})
363+
require.True(t, ok)
364+
require.Len(t, resultSlice, 1)
365+
366+
item := resultSlice[0]
367+
assert.Equal(t, "device1", item["device"])
368+
assert.Equal(t, float64(2), item["tag_count"])
369+
assert.Equal(t, float64(30), item["total_value"])
370+
case <-ctx.Done():
371+
t.Fatal("测试超时,未收到结果")
372+
}
373+
})
228374
}
229375

230376
// TestFunctionIntegrationAggregation 测试聚合函数在SQL中的集成

0 commit comments

Comments
 (0)