Skip to content

Commit fae6af3

Browse files
committed
feat: add fallback field mapping for domain struct params
This adds multiple fallback strategies to map source struct fields to target struct fields when field names differ between the wrapper API and the database adapter. The strategies include: exact match, case-insensitive match, snake_case match, and position-based match when structs have the same field count. This is needed because sqlc generates different field names for unnamed query parameters (e.g., LIMIT ? generates 'Limit' instead of 'BatchSize'). The fix enables the generator to correctly map all parameters regardless of naming differences.
1 parent 8981bbd commit fae6af3

7 files changed

Lines changed: 190 additions & 22 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ SELECT unnest(@book_ids::bigint[]), unnest(@tag_ids::bigint[]);
9090
```
9191

9292
The generator will:
93+
9394
- On **PostgreSQL**: delegate `AddBookTags` directly to the underlying sqlc implementation
9495
- On **SQLite/MySQL**: generate a loop that calls `AddBookTag` once per element
9596

example/pkg/database/database.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ import (
1010
"github.com/kalbasit/sqlc-multi-db/example/pkg/database/postgresdb"
1111
"github.com/kalbasit/sqlc-multi-db/example/pkg/database/sqlitedb"
1212

13-
_ "github.com/go-sql-driver/mysql" // MySQL driver
14-
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
15-
_ "github.com/mattn/go-sqlite3" // SQLite driver
13+
_ "github.com/go-sql-driver/mysql" // MySQL driver
14+
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
15+
_ "github.com/mattn/go-sqlite3" // SQLite driver
1616
)
1717

1818
// Open opens a database connection and returns a Querier.

example/pkg/database/errors.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,8 @@ import (
88
"github.com/mattn/go-sqlite3"
99
)
1010

11-
var (
12-
// ErrUnsupportedDriver is returned when the database driver is not recognized.
13-
ErrUnsupportedDriver = errors.New("unsupported database driver")
14-
)
11+
// ErrUnsupportedDriver is returned when the database driver is not recognized.
12+
var ErrUnsupportedDriver = errors.New("unsupported database driver")
1513

1614
// IsDeadlockError checks if the error is a deadlock or "database busy" error.
1715
func IsDeadlockError(err error) bool {

example/pkg/database/postgresdb/querier.go

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

generator/constants.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ const (
55
typeAny = "interface{}"
66
typeBool = "bool"
77
typeString = "string"
8+
typeInt = "int"
9+
typeBytes = "[]byte"
810
zeroNil = "nil"
911
typeInt16 = "int16"
1012
typeInt32 = "int32"

generator/generator_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,3 +553,80 @@ func TestGenerateFieldConversion(t *testing.T) {
553553
})
554554
}
555555
}
556+
557+
// TestJoinParamsCallFieldMapping tests that JoinParamsCall correctly maps
558+
// struct fields even when field names differ between source and target.
559+
// This is a regression test for the MySQL LIMIT parameter issue where
560+
// sqlc generates different field names (e.g., BatchSize vs Limit).
561+
func TestJoinParamsCallFieldMapping(t *testing.T) {
562+
t.Parallel()
563+
564+
// Source structs (domain) - what the wrapper API uses
565+
sourceStructs := map[string]generator.StructInfo{
566+
"GetStuckNarFilesParams": {
567+
Name: "GetStuckNarFilesParams",
568+
Fields: []generator.FieldInfo{
569+
{Name: "CutoffTime", Type: "time.Time"},
570+
{Name: "BatchSize", Type: "int32"},
571+
},
572+
},
573+
}
574+
575+
// Target structs (adapter) - what the database engine generates
576+
// MySQL generates different names: CreatedAt instead of CutoffTime, Limit instead of BatchSize
577+
targetStructs := map[string]generator.StructInfo{
578+
"GetStuckNarFilesParams": {
579+
Name: "GetStuckNarFilesParams",
580+
Fields: []generator.FieldInfo{
581+
{Name: "CreatedAt", Type: "time.Time"},
582+
{Name: "Limit", Type: "int32"},
583+
},
584+
},
585+
}
586+
587+
// Target method info
588+
targetMethod := generator.MethodInfo{
589+
Name: "GetStuckNarFiles",
590+
Params: []generator.Param{
591+
{Name: "ctx", Type: "context.Context"},
592+
{Name: "arg", Type: "GetStuckNarFilesParams"},
593+
},
594+
}
595+
596+
tests := []struct {
597+
name string
598+
params []generator.Param
599+
engPkg string
600+
want string
601+
wantErr bool
602+
}{
603+
{
604+
name: "Field mapping with different names",
605+
params: []generator.Param{
606+
{Name: "ctx", Type: "context.Context"},
607+
{Name: "arg", Type: "GetStuckNarFilesParams"},
608+
},
609+
engPkg: "mysqldb",
610+
// Expected: both fields should be mapped even though names differ
611+
// The target struct uses its own field names (CreatedAt, Limit), not source names
612+
want: "ctx, mysqldb.GetStuckNarFilesParams{\nCreatedAt: arg.CutoffTime,\nLimit: arg.BatchSize,\n}",
613+
},
614+
}
615+
616+
for _, tt := range tests {
617+
t.Run(tt.name, func(t *testing.T) {
618+
t.Parallel()
619+
620+
got, err := generator.JoinParamsCall(tt.params, tt.engPkg, targetMethod, targetStructs, sourceStructs)
621+
if (err != nil) != tt.wantErr {
622+
t.Errorf("JoinParamsCall() error = %v, wantErr %v", err, tt.wantErr)
623+
624+
return
625+
}
626+
627+
if got != tt.want {
628+
t.Errorf("JoinParamsCall() = %v, want %v", got, tt.want)
629+
}
630+
})
631+
}
632+
}

generator/helpers.go

Lines changed: 105 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,98 @@ func JoinParamsCall(
134134
return joinParamsCall(params, engPkg, targetMethod, targetStructs, sourceStructs)
135135
}
136136

137+
// findSourceField finds a matching field in source struct using multiple strategies:
138+
// 1. Exact name match
139+
// 2. Case-insensitive match
140+
// 3. Snake_case match
141+
// 4. Position-based match (fallback when structs have same field count).
142+
func findSourceField(
143+
targetField FieldInfo,
144+
targetIdx int,
145+
targetStruct StructInfo,
146+
sourceStruct StructInfo,
147+
) (FieldInfo, bool) {
148+
// Strategy 1: Exact name match
149+
for _, sf := range sourceStruct.Fields {
150+
if sf.Name == targetField.Name {
151+
return sf, true
152+
}
153+
}
154+
155+
// Strategy 2: Case-insensitive match
156+
for _, sf := range sourceStruct.Fields {
157+
if strings.EqualFold(sf.Name, targetField.Name) {
158+
return sf, true
159+
}
160+
}
161+
162+
// Strategy 3: Snake_case match
163+
targetSnake := toSnakeCase(targetField.Name)
164+
for _, sf := range sourceStruct.Fields {
165+
if toSnakeCase(sf.Name) == targetSnake {
166+
return sf, true
167+
}
168+
}
169+
170+
// Strategy 4: Position-based match (fallback when structs have same field count)
171+
// Only use position matching if the structs have the same number of fields
172+
if len(sourceStruct.Fields) == len(targetStruct.Fields) && len(sourceStruct.Fields) > 0 {
173+
// Match by position - use the field at the same index in source
174+
if targetIdx < len(sourceStruct.Fields) {
175+
sf := sourceStruct.Fields[targetIdx]
176+
// Verify types are compatible
177+
if fieldsCompatible(sf.Type, targetField.Type) {
178+
return sf, true
179+
}
180+
}
181+
}
182+
183+
return FieldInfo{}, false
184+
}
185+
186+
// fieldsCompatible checks if two field types are compatible for mapping.
187+
func fieldsCompatible(sourceType, targetType string) bool {
188+
// Normalize types for comparison
189+
sourceBase := normalizeType(sourceType)
190+
targetBase := normalizeType(targetType)
191+
192+
return sourceBase == targetBase
193+
}
194+
195+
// normalizeType normalizes a type string for comparison.
196+
func normalizeType(t string) string {
197+
// Remove common prefixes/suffixes
198+
t = strings.TrimPrefix(t, "[]")
199+
t = strings.TrimPrefix(t, "*")
200+
201+
// Handle time types
202+
if strings.Contains(t, "time.Time") || strings.Contains(t, "NullTime") {
203+
return "time"
204+
}
205+
206+
// Handle numeric types
207+
switch t {
208+
case typeInt, "int8", "int16", "int32", "int64",
209+
"uint", "uint8", "uint16", "uint32", "uint64",
210+
sqlNullInt32, sqlNullInt64:
211+
return typeInt
212+
case "float32", "float64", sqlNullFloat64:
213+
return "float"
214+
case typeString, sqlNullString, typeBytes:
215+
return typeString
216+
case typeBool, sqlNullBool:
217+
return typeBool
218+
}
219+
220+
// Remove package prefix if present
221+
parts := strings.Split(t, ".")
222+
if len(parts) > 1 {
223+
return parts[len(parts)-1]
224+
}
225+
226+
return t
227+
}
228+
137229
func joinDomainStructParam(
138230
param Param,
139231
i int,
@@ -153,23 +245,22 @@ func joinDomainStructParam(
153245

154246
if targetParamType != "" {
155247
sourceStruct := sourceStructs[param.Type]
156-
targetStruct := targetStructs[targetParamType]
157-
158-
var fields []string
159-
160-
for _, targetField := range targetStruct.Fields {
161-
var sourceField FieldInfo
248+
// Target struct keys may include the package prefix (e.g., "mysqldb.GetStuckNarFilesParams")
249+
// Try with prefix first, then without
250+
targetStructKey := targetParamType
251+
if engPkg != "" {
252+
if _, ok := targetStructs[engPkg+"."+targetParamType]; ok {
253+
targetStructKey = engPkg + "." + targetParamType
254+
}
255+
// Otherwise keep using targetParamType (no prefix)
256+
}
162257

163-
found := false
258+
targetStruct := targetStructs[targetStructKey]
164259

165-
for _, sf := range sourceStruct.Fields {
166-
if sf.Name == targetField.Name {
167-
sourceField = sf
168-
found = true
260+
var fields []string
169261

170-
break
171-
}
172-
}
262+
for targetIdx, targetField := range targetStruct.Fields {
263+
sourceField, found := findSourceField(targetField, targetIdx, targetStruct, sourceStruct)
173264

174265
if found {
175266
conversion := generateFieldConversion(

0 commit comments

Comments
 (0)