From dc5df7cc1a013c3a40f7692f57b640b974ddb306 Mon Sep 17 00:00:00 2001 From: Wael Nasreddine Date: Sun, 15 Mar 2026 16:00:01 -0700 Subject: [PATCH 1/2] feat: add fallback field mapping for domain struct params This adds multiple fallback strategies to map source struct fields to target struct fields when field names differ between the wrapper API and the database adapter. The strategies include: exact match, case-insensitive match, snake_case match, and position-based match when structs have the same field count. This is needed because sqlc generates different field names for unnamed query parameters (e.g., LIMIT ? generates 'Limit' instead of 'BatchSize'). The fix enables the generator to correctly map all parameters regardless of naming differences. --- README.md | 1 + example/pkg/database/database.go | 6 +- example/pkg/database/errors.go | 6 +- example/pkg/database/postgresdb/querier.go | 1 - generator/constants.go | 2 + generator/generator_test.go | 77 +++++++++++++ generator/helpers.go | 119 ++++++++++++++++++--- 7 files changed, 190 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index ceb734e..7c2a769 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,7 @@ SELECT unnest(@book_ids::bigint[]), unnest(@tag_ids::bigint[]); ``` The generator will: + - On **PostgreSQL**: delegate `AddBookTags` directly to the underlying sqlc implementation - On **SQLite/MySQL**: generate a loop that calls `AddBookTag` once per element diff --git a/example/pkg/database/database.go b/example/pkg/database/database.go index 2b6495a..41899a2 100644 --- a/example/pkg/database/database.go +++ b/example/pkg/database/database.go @@ -10,9 +10,9 @@ import ( "github.com/kalbasit/sqlc-multi-db/example/pkg/database/postgresdb" "github.com/kalbasit/sqlc-multi-db/example/pkg/database/sqlitedb" - _ "github.com/go-sql-driver/mysql" // MySQL driver - _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver - _ "github.com/mattn/go-sqlite3" // SQLite driver + _ "github.com/go-sql-driver/mysql" // MySQL driver + _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver + _ "github.com/mattn/go-sqlite3" // SQLite driver ) // Open opens a database connection and returns a Querier. diff --git a/example/pkg/database/errors.go b/example/pkg/database/errors.go index 7ab6320..42bde8f 100644 --- a/example/pkg/database/errors.go +++ b/example/pkg/database/errors.go @@ -8,10 +8,8 @@ import ( "github.com/mattn/go-sqlite3" ) -var ( - // ErrUnsupportedDriver is returned when the database driver is not recognized. - ErrUnsupportedDriver = errors.New("unsupported database driver") -) +// ErrUnsupportedDriver is returned when the database driver is not recognized. +var ErrUnsupportedDriver = errors.New("unsupported database driver") // IsDeadlockError checks if the error is a deadlock or "database busy" error. func IsDeadlockError(err error) bool { diff --git a/example/pkg/database/postgresdb/querier.go b/example/pkg/database/postgresdb/querier.go index 4536ae2..be00bae 100644 --- a/example/pkg/database/postgresdb/querier.go +++ b/example/pkg/database/postgresdb/querier.go @@ -16,7 +16,6 @@ type Querier interface { // @bulk-for AddBookTag AddBookTags(ctx context.Context, arg AddBookTagsParams) error - // CreateBook creates a new book. // INSERT INTO books ("title", "author", "description") VALUES ($1, $2, $3) RETURNING "id", "title", "author", "description", "created_at", "updated_at" CreateBook(ctx context.Context, arg CreateBookParams) (Book, error) diff --git a/generator/constants.go b/generator/constants.go index 00599ba..c56e17c 100644 --- a/generator/constants.go +++ b/generator/constants.go @@ -5,6 +5,8 @@ const ( typeAny = "interface{}" typeBool = "bool" typeString = "string" + typeInt = "int" + typeBytes = "[]byte" zeroNil = "nil" typeInt16 = "int16" typeInt32 = "int32" diff --git a/generator/generator_test.go b/generator/generator_test.go index dd8cb45..007e666 100644 --- a/generator/generator_test.go +++ b/generator/generator_test.go @@ -553,3 +553,80 @@ func TestGenerateFieldConversion(t *testing.T) { }) } } + +// TestJoinParamsCallFieldMapping tests that JoinParamsCall correctly maps +// struct fields even when field names differ between source and target. +// This is a regression test for the MySQL LIMIT parameter issue where +// sqlc generates different field names (e.g., BatchSize vs Limit). +func TestJoinParamsCallFieldMapping(t *testing.T) { + t.Parallel() + + // Source structs (domain) - what the wrapper API uses + sourceStructs := map[string]generator.StructInfo{ + "GetStuckNarFilesParams": { + Name: "GetStuckNarFilesParams", + Fields: []generator.FieldInfo{ + {Name: "CutoffTime", Type: "time.Time"}, + {Name: "BatchSize", Type: "int32"}, + }, + }, + } + + // Target structs (adapter) - what the database engine generates + // MySQL generates different names: CreatedAt instead of CutoffTime, Limit instead of BatchSize + targetStructs := map[string]generator.StructInfo{ + "GetStuckNarFilesParams": { + Name: "GetStuckNarFilesParams", + Fields: []generator.FieldInfo{ + {Name: "CreatedAt", Type: "time.Time"}, + {Name: "Limit", Type: "int32"}, + }, + }, + } + + // Target method info + targetMethod := generator.MethodInfo{ + Name: "GetStuckNarFiles", + Params: []generator.Param{ + {Name: "ctx", Type: "context.Context"}, + {Name: "arg", Type: "GetStuckNarFilesParams"}, + }, + } + + tests := []struct { + name string + params []generator.Param + engPkg string + want string + wantErr bool + }{ + { + name: "Field mapping with different names", + params: []generator.Param{ + {Name: "ctx", Type: "context.Context"}, + {Name: "arg", Type: "GetStuckNarFilesParams"}, + }, + engPkg: "mysqldb", + // Expected: both fields should be mapped even though names differ + // The target struct uses its own field names (CreatedAt, Limit), not source names + want: "ctx, mysqldb.GetStuckNarFilesParams{\nCreatedAt: arg.CutoffTime,\nLimit: arg.BatchSize,\n}", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := generator.JoinParamsCall(tt.params, tt.engPkg, targetMethod, targetStructs, sourceStructs) + if (err != nil) != tt.wantErr { + t.Errorf("JoinParamsCall() error = %v, wantErr %v", err, tt.wantErr) + + return + } + + if got != tt.want { + t.Errorf("JoinParamsCall() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/generator/helpers.go b/generator/helpers.go index 84582a6..18b7a14 100644 --- a/generator/helpers.go +++ b/generator/helpers.go @@ -134,6 +134,98 @@ func JoinParamsCall( return joinParamsCall(params, engPkg, targetMethod, targetStructs, sourceStructs) } +// findSourceField finds a matching field in source struct using multiple strategies: +// 1. Exact name match +// 2. Case-insensitive match +// 3. Snake_case match +// 4. Position-based match (fallback when structs have same field count). +func findSourceField( + targetField FieldInfo, + targetIdx int, + targetStruct StructInfo, + sourceStruct StructInfo, +) (FieldInfo, bool) { + // Strategy 1: Exact name match + for _, sf := range sourceStruct.Fields { + if sf.Name == targetField.Name { + return sf, true + } + } + + // Strategy 2: Case-insensitive match + for _, sf := range sourceStruct.Fields { + if strings.EqualFold(sf.Name, targetField.Name) { + return sf, true + } + } + + // Strategy 3: Snake_case match + targetSnake := toSnakeCase(targetField.Name) + for _, sf := range sourceStruct.Fields { + if toSnakeCase(sf.Name) == targetSnake { + return sf, true + } + } + + // Strategy 4: Position-based match (fallback when structs have same field count) + // Only use position matching if the structs have the same number of fields + if len(sourceStruct.Fields) == len(targetStruct.Fields) && len(sourceStruct.Fields) > 0 { + // Match by position - use the field at the same index in source + if targetIdx < len(sourceStruct.Fields) { + sf := sourceStruct.Fields[targetIdx] + // Verify types are compatible + if fieldsCompatible(sf.Type, targetField.Type) { + return sf, true + } + } + } + + return FieldInfo{}, false +} + +// fieldsCompatible checks if two field types are compatible for mapping. +func fieldsCompatible(sourceType, targetType string) bool { + // Normalize types for comparison + sourceBase := normalizeType(sourceType) + targetBase := normalizeType(targetType) + + return sourceBase == targetBase +} + +// normalizeType normalizes a type string for comparison. +func normalizeType(t string) string { + // Remove common prefixes/suffixes + t = strings.TrimPrefix(t, "[]") + t = strings.TrimPrefix(t, "*") + + // Handle time types + if strings.Contains(t, "time.Time") || strings.Contains(t, "NullTime") { + return "time" + } + + // Handle numeric types + switch t { + case typeInt, "int8", "int16", "int32", "int64", + "uint", "uint8", "uint16", "uint32", "uint64", + sqlNullInt32, sqlNullInt64: + return typeInt + case "float32", "float64", sqlNullFloat64: + return "float" + case typeString, sqlNullString, typeBytes: + return typeString + case typeBool, sqlNullBool: + return typeBool + } + + // Remove package prefix if present + parts := strings.Split(t, ".") + if len(parts) > 1 { + return parts[len(parts)-1] + } + + return t +} + func joinDomainStructParam( param Param, i int, @@ -153,23 +245,22 @@ func joinDomainStructParam( if targetParamType != "" { sourceStruct := sourceStructs[param.Type] - targetStruct := targetStructs[targetParamType] - - var fields []string - - for _, targetField := range targetStruct.Fields { - var sourceField FieldInfo + // Target struct keys may include the package prefix (e.g., "mysqldb.GetStuckNarFilesParams") + // Try with prefix first, then without + targetStructKey := targetParamType + if engPkg != "" { + if _, ok := targetStructs[engPkg+"."+targetParamType]; ok { + targetStructKey = engPkg + "." + targetParamType + } + // Otherwise keep using targetParamType (no prefix) + } - found := false + targetStruct := targetStructs[targetStructKey] - for _, sf := range sourceStruct.Fields { - if sf.Name == targetField.Name { - sourceField = sf - found = true + var fields []string - break - } - } + for targetIdx, targetField := range targetStruct.Fields { + sourceField, found := findSourceField(targetField, targetIdx, targetStruct, sourceStruct) if found { conversion := generateFieldConversion( From 8beb0865f8b1075f3ffa549cde990d3563a49da4 Mon Sep 17 00:00:00 2001 From: Wael Nasreddine Date: Sun, 15 Mar 2026 16:06:43 -0700 Subject: [PATCH 2/2] fix: make findSourceField stateful to prevent duplicate field mapping The mapping process now tracks available source fields to prevent one source field from being mapped to multiple target fields. After a field is matched, it's removed from the available pool. Co-Authored-By: Claude Opus 4.6 --- generator/helpers.go | 50 +++++++++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/generator/helpers.go b/generator/helpers.go index 18b7a14..59dc104 100644 --- a/generator/helpers.go +++ b/generator/helpers.go @@ -134,26 +134,26 @@ func JoinParamsCall( return joinParamsCall(params, engPkg, targetMethod, targetStructs, sourceStructs) } -// findSourceField finds a matching field in source struct using multiple strategies: +// findSourceField finds a matching field in available source fields using multiple strategies: // 1. Exact name match // 2. Case-insensitive match // 3. Snake_case match // 4. Position-based match (fallback when structs have same field count). +// The availableSourceFields map is modified to remove matched fields. func findSourceField( targetField FieldInfo, targetIdx int, targetStruct StructInfo, sourceStruct StructInfo, + availableSourceFields map[string]FieldInfo, ) (FieldInfo, bool) { // Strategy 1: Exact name match - for _, sf := range sourceStruct.Fields { - if sf.Name == targetField.Name { - return sf, true - } + if sf, ok := availableSourceFields[targetField.Name]; ok { + return sf, true } // Strategy 2: Case-insensitive match - for _, sf := range sourceStruct.Fields { + for _, sf := range availableSourceFields { if strings.EqualFold(sf.Name, targetField.Name) { return sf, true } @@ -161,7 +161,7 @@ func findSourceField( // Strategy 3: Snake_case match targetSnake := toSnakeCase(targetField.Name) - for _, sf := range sourceStruct.Fields { + for _, sf := range availableSourceFields { if toSnakeCase(sf.Name) == targetSnake { return sf, true } @@ -169,15 +169,23 @@ func findSourceField( // Strategy 4: Position-based match (fallback when structs have same field count) // Only use position matching if the structs have the same number of fields - if len(sourceStruct.Fields) == len(targetStruct.Fields) && len(sourceStruct.Fields) > 0 { - // Match by position - use the field at the same index in source - if targetIdx < len(sourceStruct.Fields) { - sf := sourceStruct.Fields[targetIdx] - // Verify types are compatible - if fieldsCompatible(sf.Type, targetField.Type) { - return sf, true - } - } + if len(sourceStruct.Fields) != len(targetStruct.Fields) || len(sourceStruct.Fields) == 0 { + return FieldInfo{}, false + } + // Match by position - use the field at the same index in source + if targetIdx >= len(sourceStruct.Fields) { + return FieldInfo{}, false + } + + originalSourceField := sourceStruct.Fields[targetIdx] + // Check if it's still available + sf, ok := availableSourceFields[originalSourceField.Name] + if !ok { + return FieldInfo{}, false + } + // Verify types are compatible + if fieldsCompatible(sf.Type, targetField.Type) { + return sf, true } return FieldInfo{}, false @@ -257,10 +265,16 @@ func joinDomainStructParam( targetStruct := targetStructs[targetStructKey] + // Create a map of available source fields to track which fields have been mapped. + availableSourceFields := make(map[string]FieldInfo, len(sourceStruct.Fields)) + for _, sf := range sourceStruct.Fields { + availableSourceFields[sf.Name] = sf + } + var fields []string for targetIdx, targetField := range targetStruct.Fields { - sourceField, found := findSourceField(targetField, targetIdx, targetStruct, sourceStruct) + sourceField, found := findSourceField(targetField, targetIdx, targetStruct, sourceStruct, availableSourceFields) if found { conversion := generateFieldConversion( @@ -270,6 +284,8 @@ func joinDomainStructParam( fmt.Sprintf("%s.%s", param.Name, sourceField.Name), ) fields = append(fields, conversion) + // Remove the mapped field so it can't be used again. + delete(availableSourceFields, sourceField.Name) } }