Skip to content

Commit 20ba54b

Browse files
authored
Merge pull request #58 from shamaton/embedded
Add embedded struct support with optimized fast path
2 parents f71f5e2 + 984f35f commit 20ba54b

6 files changed

Lines changed: 1607 additions & 181 deletions

File tree

internal/common/common.go

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,160 @@ import (
99
type Common struct {
1010
}
1111

12+
// FieldInfo holds information about a struct field including its path for embedded structs
13+
type FieldInfo struct {
14+
Path []int // path to reach this field (indices for embedded structs)
15+
Name string // field name or tag
16+
Omit bool // omitempty flag
17+
Tagged bool // tag name explicitly set
18+
OmitPaths [][]int // paths to embedded fields with omitempty
19+
}
20+
21+
// CollectFields collects all fields from a struct, expanding embedded structs
22+
// following the same rules as encoding/json
23+
func (c *Common) CollectFields(t reflect.Type, path []int) []FieldInfo {
24+
return c.collectFields(t, path, nil)
25+
}
26+
27+
func (c *Common) collectFields(t reflect.Type, path []int, omitPaths [][]int) []FieldInfo {
28+
var fields []FieldInfo
29+
var embedded []FieldInfo // embedded fields to process later (lower priority)
30+
31+
for i := 0; i < t.NumField(); i++ {
32+
field := t.Field(i)
33+
34+
// Check field visibility and get omitempty flag
35+
public, omit, name := c.CheckField(field)
36+
if !public {
37+
continue
38+
}
39+
40+
// Get tag to check if embedded
41+
tag := field.Tag.Get("msgpack")
42+
// Extract just the name part (before comma if any)
43+
tagName := tag
44+
for j, ch := range tag {
45+
if ch == ',' {
46+
tagName = tag[:j]
47+
break
48+
}
49+
}
50+
51+
// Check if this is an embedded struct
52+
isEmbedded := field.Anonymous && (tag == "" || tagName == "")
53+
tagged := tagName != ""
54+
55+
if isEmbedded {
56+
// Get the actual type (dereference pointer if needed)
57+
fieldType := field.Type
58+
if fieldType.Kind() == reflect.Ptr {
59+
fieldType = fieldType.Elem()
60+
}
61+
62+
// If it's a struct, expand its fields
63+
if fieldType.Kind() == reflect.Struct {
64+
newPath := append(append([]int{}, path...), i)
65+
nextOmitPaths := omitPaths
66+
if omit {
67+
nextOmitPaths = appendOmitPath(omitPaths, newPath)
68+
}
69+
embeddedFields := c.collectFields(fieldType, newPath, nextOmitPaths)
70+
embedded = append(embedded, embeddedFields...)
71+
continue
72+
}
73+
}
74+
75+
// Regular field or embedded non-struct
76+
newPath := append(append([]int{}, path...), i)
77+
fields = append(fields, FieldInfo{
78+
Path: newPath,
79+
Name: name,
80+
Omit: omit,
81+
Tagged: tagged,
82+
OmitPaths: omitPaths,
83+
})
84+
}
85+
86+
// Add embedded fields after regular fields (they have lower priority)
87+
fields = append(fields, embedded...)
88+
89+
// Remove duplicates and handle ambiguous fields
90+
return c.deduplicateFields(fields)
91+
}
92+
93+
func appendOmitPath(paths [][]int, path []int) [][]int {
94+
if len(paths) == 0 {
95+
return [][]int{path}
96+
}
97+
newPaths := make([][]int, len(paths)+1)
98+
copy(newPaths, paths)
99+
newPaths[len(paths)] = path
100+
return newPaths
101+
}
102+
103+
// deduplicateFields removes duplicate fields and handles ambiguous fields
104+
// following encoding/json behavior
105+
func (c *Common) deduplicateFields(fields []FieldInfo) []FieldInfo {
106+
// Group fields by name and depth, preserving order
107+
type fieldAtDepth struct {
108+
field FieldInfo
109+
depth int
110+
}
111+
fieldsByName := make(map[string][]fieldAtDepth)
112+
var seenNames []string // To preserve order
113+
114+
for _, f := range fields {
115+
if _, seen := fieldsByName[f.Name]; !seen {
116+
seenNames = append(seenNames, f.Name)
117+
}
118+
fieldsByName[f.Name] = append(fieldsByName[f.Name], fieldAtDepth{
119+
field: f,
120+
depth: len(f.Path),
121+
})
122+
}
123+
124+
var result []FieldInfo
125+
for _, name := range seenNames {
126+
fieldsWithDepth := fieldsByName[name]
127+
128+
// Find minimum depth
129+
minDepth := fieldsWithDepth[0].depth
130+
for _, fd := range fieldsWithDepth {
131+
if fd.depth < minDepth {
132+
minDepth = fd.depth
133+
}
134+
}
135+
136+
// Count fields at minimum depth
137+
var fieldsAtMinDepth []FieldInfo
138+
for _, fd := range fieldsWithDepth {
139+
if fd.depth == minDepth {
140+
fieldsAtMinDepth = append(fieldsAtMinDepth, fd.field)
141+
}
142+
}
143+
144+
// If there's exactly one field at minimum depth, use it
145+
if len(fieldsAtMinDepth) == 1 {
146+
result = append(result, fieldsAtMinDepth[0])
147+
continue
148+
}
149+
150+
// Prefer the tagged field if exactly one is tagged at minimum depth
151+
var taggedFields []FieldInfo
152+
for _, f := range fieldsAtMinDepth {
153+
if f.Tagged {
154+
taggedFields = append(taggedFields, f)
155+
}
156+
}
157+
if len(taggedFields) == 1 {
158+
result = append(result, taggedFields[0])
159+
}
160+
// else: ambiguous field, skip it (following encoding/json behavior)
161+
}
162+
163+
return result
164+
}
165+
12166
// CheckField returns flag whether should encode/decode or not and field name
13167
func (c *Common) CheckField(field reflect.StructField) (public, omit bool, name string) {
14168
// A to Z

0 commit comments

Comments
 (0)