diff --git a/README.md b/README.md index ceb734e..7c2a769 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,7 @@ SELECT unnest(@book_ids::bigint[]), unnest(@tag_ids::bigint[]); ``` The generator will: + - On **PostgreSQL**: delegate `AddBookTags` directly to the underlying sqlc implementation - On **SQLite/MySQL**: generate a loop that calls `AddBookTag` once per element diff --git a/example/pkg/database/database.go b/example/pkg/database/database.go index 2b6495a..41899a2 100644 --- a/example/pkg/database/database.go +++ b/example/pkg/database/database.go @@ -10,9 +10,9 @@ import ( "github.com/kalbasit/sqlc-multi-db/example/pkg/database/postgresdb" "github.com/kalbasit/sqlc-multi-db/example/pkg/database/sqlitedb" - _ "github.com/go-sql-driver/mysql" // MySQL driver - _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver - _ "github.com/mattn/go-sqlite3" // SQLite driver + _ "github.com/go-sql-driver/mysql" // MySQL driver + _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver + _ "github.com/mattn/go-sqlite3" // SQLite driver ) // Open opens a database connection and returns a Querier. diff --git a/example/pkg/database/errors.go b/example/pkg/database/errors.go index 7ab6320..42bde8f 100644 --- a/example/pkg/database/errors.go +++ b/example/pkg/database/errors.go @@ -8,10 +8,8 @@ import ( "github.com/mattn/go-sqlite3" ) -var ( - // ErrUnsupportedDriver is returned when the database driver is not recognized. - ErrUnsupportedDriver = errors.New("unsupported database driver") -) +// ErrUnsupportedDriver is returned when the database driver is not recognized. +var ErrUnsupportedDriver = errors.New("unsupported database driver") // IsDeadlockError checks if the error is a deadlock or "database busy" error. func IsDeadlockError(err error) bool { diff --git a/example/pkg/database/postgresdb/querier.go b/example/pkg/database/postgresdb/querier.go index 4536ae2..be00bae 100644 --- a/example/pkg/database/postgresdb/querier.go +++ b/example/pkg/database/postgresdb/querier.go @@ -16,7 +16,6 @@ type Querier interface { // @bulk-for AddBookTag AddBookTags(ctx context.Context, arg AddBookTagsParams) error - // CreateBook creates a new book. // INSERT INTO books ("title", "author", "description") VALUES ($1, $2, $3) RETURNING "id", "title", "author", "description", "created_at", "updated_at" CreateBook(ctx context.Context, arg CreateBookParams) (Book, error) diff --git a/generator/constants.go b/generator/constants.go index 00599ba..c56e17c 100644 --- a/generator/constants.go +++ b/generator/constants.go @@ -5,6 +5,8 @@ const ( typeAny = "interface{}" typeBool = "bool" typeString = "string" + typeInt = "int" + typeBytes = "[]byte" zeroNil = "nil" typeInt16 = "int16" typeInt32 = "int32" diff --git a/generator/generator_test.go b/generator/generator_test.go index dd8cb45..007e666 100644 --- a/generator/generator_test.go +++ b/generator/generator_test.go @@ -553,3 +553,80 @@ func TestGenerateFieldConversion(t *testing.T) { }) } } + +// TestJoinParamsCallFieldMapping tests that JoinParamsCall correctly maps +// struct fields even when field names differ between source and target. +// This is a regression test for the MySQL LIMIT parameter issue where +// sqlc generates different field names (e.g., BatchSize vs Limit). +func TestJoinParamsCallFieldMapping(t *testing.T) { + t.Parallel() + + // Source structs (domain) - what the wrapper API uses + sourceStructs := map[string]generator.StructInfo{ + "GetStuckNarFilesParams": { + Name: "GetStuckNarFilesParams", + Fields: []generator.FieldInfo{ + {Name: "CutoffTime", Type: "time.Time"}, + {Name: "BatchSize", Type: "int32"}, + }, + }, + } + + // Target structs (adapter) - what the database engine generates + // MySQL generates different names: CreatedAt instead of CutoffTime, Limit instead of BatchSize + targetStructs := map[string]generator.StructInfo{ + "GetStuckNarFilesParams": { + Name: "GetStuckNarFilesParams", + Fields: []generator.FieldInfo{ + {Name: "CreatedAt", Type: "time.Time"}, + {Name: "Limit", Type: "int32"}, + }, + }, + } + + // Target method info + targetMethod := generator.MethodInfo{ + Name: "GetStuckNarFiles", + Params: []generator.Param{ + {Name: "ctx", Type: "context.Context"}, + {Name: "arg", Type: "GetStuckNarFilesParams"}, + }, + } + + tests := []struct { + name string + params []generator.Param + engPkg string + want string + wantErr bool + }{ + { + name: "Field mapping with different names", + params: []generator.Param{ + {Name: "ctx", Type: "context.Context"}, + {Name: "arg", Type: "GetStuckNarFilesParams"}, + }, + engPkg: "mysqldb", + // Expected: both fields should be mapped even though names differ + // The target struct uses its own field names (CreatedAt, Limit), not source names + want: "ctx, mysqldb.GetStuckNarFilesParams{\nCreatedAt: arg.CutoffTime,\nLimit: arg.BatchSize,\n}", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := generator.JoinParamsCall(tt.params, tt.engPkg, targetMethod, targetStructs, sourceStructs) + if (err != nil) != tt.wantErr { + t.Errorf("JoinParamsCall() error = %v, wantErr %v", err, tt.wantErr) + + return + } + + if got != tt.want { + t.Errorf("JoinParamsCall() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/generator/helpers.go b/generator/helpers.go index 84582a6..59dc104 100644 --- a/generator/helpers.go +++ b/generator/helpers.go @@ -134,6 +134,106 @@ func JoinParamsCall( return joinParamsCall(params, engPkg, targetMethod, targetStructs, sourceStructs) } +// findSourceField finds a matching field in available source fields using multiple strategies: +// 1. Exact name match +// 2. Case-insensitive match +// 3. Snake_case match +// 4. Position-based match (fallback when structs have same field count). +// The availableSourceFields map is modified to remove matched fields. +func findSourceField( + targetField FieldInfo, + targetIdx int, + targetStruct StructInfo, + sourceStruct StructInfo, + availableSourceFields map[string]FieldInfo, +) (FieldInfo, bool) { + // Strategy 1: Exact name match + if sf, ok := availableSourceFields[targetField.Name]; ok { + return sf, true + } + + // Strategy 2: Case-insensitive match + for _, sf := range availableSourceFields { + if strings.EqualFold(sf.Name, targetField.Name) { + return sf, true + } + } + + // Strategy 3: Snake_case match + targetSnake := toSnakeCase(targetField.Name) + for _, sf := range availableSourceFields { + if toSnakeCase(sf.Name) == targetSnake { + return sf, true + } + } + + // Strategy 4: Position-based match (fallback when structs have same field count) + // Only use position matching if the structs have the same number of fields + if len(sourceStruct.Fields) != len(targetStruct.Fields) || len(sourceStruct.Fields) == 0 { + return FieldInfo{}, false + } + // Match by position - use the field at the same index in source + if targetIdx >= len(sourceStruct.Fields) { + return FieldInfo{}, false + } + + originalSourceField := sourceStruct.Fields[targetIdx] + // Check if it's still available + sf, ok := availableSourceFields[originalSourceField.Name] + if !ok { + return FieldInfo{}, false + } + // Verify types are compatible + if fieldsCompatible(sf.Type, targetField.Type) { + return sf, true + } + + return FieldInfo{}, false +} + +// fieldsCompatible checks if two field types are compatible for mapping. +func fieldsCompatible(sourceType, targetType string) bool { + // Normalize types for comparison + sourceBase := normalizeType(sourceType) + targetBase := normalizeType(targetType) + + return sourceBase == targetBase +} + +// normalizeType normalizes a type string for comparison. +func normalizeType(t string) string { + // Remove common prefixes/suffixes + t = strings.TrimPrefix(t, "[]") + t = strings.TrimPrefix(t, "*") + + // Handle time types + if strings.Contains(t, "time.Time") || strings.Contains(t, "NullTime") { + return "time" + } + + // Handle numeric types + switch t { + case typeInt, "int8", "int16", "int32", "int64", + "uint", "uint8", "uint16", "uint32", "uint64", + sqlNullInt32, sqlNullInt64: + return typeInt + case "float32", "float64", sqlNullFloat64: + return "float" + case typeString, sqlNullString, typeBytes: + return typeString + case typeBool, sqlNullBool: + return typeBool + } + + // Remove package prefix if present + parts := strings.Split(t, ".") + if len(parts) > 1 { + return parts[len(parts)-1] + } + + return t +} + func joinDomainStructParam( param Param, i int, @@ -153,23 +253,28 @@ func joinDomainStructParam( if targetParamType != "" { sourceStruct := sourceStructs[param.Type] - targetStruct := targetStructs[targetParamType] - - var fields []string + // Target struct keys may include the package prefix (e.g., "mysqldb.GetStuckNarFilesParams") + // Try with prefix first, then without + targetStructKey := targetParamType + if engPkg != "" { + if _, ok := targetStructs[engPkg+"."+targetParamType]; ok { + targetStructKey = engPkg + "." + targetParamType + } + // Otherwise keep using targetParamType (no prefix) + } - for _, targetField := range targetStruct.Fields { - var sourceField FieldInfo + targetStruct := targetStructs[targetStructKey] - found := false + // Create a map of available source fields to track which fields have been mapped. + availableSourceFields := make(map[string]FieldInfo, len(sourceStruct.Fields)) + for _, sf := range sourceStruct.Fields { + availableSourceFields[sf.Name] = sf + } - for _, sf := range sourceStruct.Fields { - if sf.Name == targetField.Name { - sourceField = sf - found = true + var fields []string - break - } - } + for targetIdx, targetField := range targetStruct.Fields { + sourceField, found := findSourceField(targetField, targetIdx, targetStruct, sourceStruct, availableSourceFields) if found { conversion := generateFieldConversion( @@ -179,6 +284,8 @@ func joinDomainStructParam( fmt.Sprintf("%s.%s", param.Name, sourceField.Name), ) fields = append(fields, conversion) + // Remove the mapped field so it can't be used again. + delete(availableSourceFields, sourceField.Name) } }