Skip to content

Commit 4206712

Browse files
committed
fix(migrations): robustly strip psql meta commands without breaking SQL
Replace naive PostgreSQL schema preprocessing with a single-pass state machine that distinguishes top-level psql meta-commands from valid SQL backslashes, literals, identifiers, comments, and dollar-quoted bodies. The previous implementation could leave pg_dump/client backslash directives in schema-loading paths or strip too aggressively, breaking valid SQL containing: - Backslashes in string literals, including `E'...'` escapes and simple `standard_conforming_strings` variants - Meta-command text in comments or documentation - Dollar-quoted function bodies, including Unicode-tagged bodies - Double-quoted identifiers and identifiers containing `$` Changes: - Add engine-aware `PreprocessSchema()` and `PreprocessSchemaForApply()` helpers so rollback removal always applies while PostgreSQL psql stripping is mode-aware. - Replace line-based PostgreSQL filtering with a single-pass lexer that tracks single quotes, double quotes, dollar quotes, line comments, nested block comments, and statement boundaries. - Handle escape-string prefixes, simple `standard_conforming_strings` changes, Unicode dollar-quote tags, identifier-boundary checks, documented psql meta-commands, and broader unknown top-level backslash directives. - Preserve SQL after a valid inline `\\` separator that follows a meta-command, including glued and one-sided-whitespace forms observed in psql 13.22 / 14.19 / 15.14 / 16.10 / 17.6 / 17.10; preserve invalid leading `\\` input instead of normalizing it into SQL. - Strip semantic psql commands such as `\connect`, includes, `\copy`, `\gexec`, `\q`, `\quit`, and `\r` with warnings in parse/codegen paths, but reject them in schema-application paths where sqlc cannot reproduce their effects safely. - Reject psql conditionals (`\if`, `\elif`, `\else`, `\endif`) instead of flattening branches and changing SQL semantics. - Remove `\copy ... from stdin` payload rows through an exact `\.` terminator in parse mode, and reject unterminated copy data. - Treat `standard_conforming_strings` and transaction-scoped script behavior as best-effort parsing aids rather than full psql emulation; report approximation warnings in parse mode while suppressing that parse-only warning for live apply mode. - Wire preprocessing and warning propagation into compiler parsing, generate processing, `createdb`, `verify`, managed `vet`, and PostgreSQL sqltest seeding paths. - Add regression coverage for documented meta-commands, unknown directives, literals, comments, dollar quotes, inline separators, semantic warnings, apply-mode rejections, copy data, line endings, and managed/PostgreSQL preprocessing rollout. Performance improvements: - Pre-allocate output buffers with `strings.Builder.Grow()`. - Keep parsing single-pass rather than rescanning line slices. - Reuse engine-aware preprocessing helpers across schema-loading paths. Testing: - `go test ./internal/migrations ./internal/compiler ./internal/schemautil ./internal/cmd ./internal/sqltest/...`
1 parent 4011ae5 commit 4206712

15 files changed

Lines changed: 2130 additions & 37 deletions

internal/cmd/createdb.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,14 @@ func CreateDB(ctx context.Context, dir, filename, querySetName string, o *Option
8484
if err != nil {
8585
return fmt.Errorf("read file: %w", err)
8686
}
87-
ddl = append(ddl, migrations.RemoveRollbackStatements(string(contents)))
87+
ddlText, warnings, err := migrations.PreprocessSchemaForApply(string(contents), string(queryset.Engine))
88+
if err != nil {
89+
return err
90+
}
91+
for _, warning := range warnings {
92+
fmt.Fprintln(o.Stderr, warning)
93+
}
94+
ddl = append(ddl, ddlText)
8895
}
8996

9097
now := time.Now().UTC().UnixNano()

internal/cmd/generate.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,9 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C
265265
return nil, true
266266
}
267267
if err := c.ParseCatalog(sql.Schema); err != nil {
268+
for _, warning := range c.Warnings() {
269+
fmt.Fprintln(stderr, warning)
270+
}
268271
fmt.Fprintf(stderr, "# package %s\n", name)
269272
if parserErr, ok := err.(*multierr.Error); ok {
270273
for _, fileErr := range parserErr.Errs() {
@@ -275,6 +278,9 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C
275278
}
276279
return nil, true
277280
}
281+
for _, warning := range c.Warnings() {
282+
fmt.Fprintln(stderr, warning)
283+
}
278284
if debugDumpCatalog.Value() == "1" {
279285
debug.Dump(c.Catalog())
280286
}

internal/cmd/process.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,12 @@ func processQuerySets(ctx context.Context, rp ResultProcessor, conf *config.Conf
119119
if err := grp.Wait(); err != nil {
120120
return err
121121
}
122-
if errored {
123-
for i, _ := range stderrs {
124-
if _, err := io.Copy(stderr, &stderrs[i]); err != nil {
125-
return err
126-
}
122+
for i := range stderrs {
123+
if _, err := io.Copy(stderr, &stderrs[i]); err != nil {
124+
return err
127125
}
126+
}
127+
if errored {
128128
return fmt.Errorf("errored")
129129
}
130130
return nil

internal/cmd/verify.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,14 @@ func Verify(ctx context.Context, dir, filename string, opts *Options) error {
102102
if err != nil {
103103
return fmt.Errorf("read file: %w", err)
104104
}
105-
ddl = append(ddl, migrations.RemoveRollbackStatements(string(contents)))
105+
ddlText, warnings, err := migrations.PreprocessSchemaForApply(string(contents), string(current.Engine))
106+
if err != nil {
107+
return err
108+
}
109+
for _, warning := range warnings {
110+
fmt.Fprintln(stderr, warning)
111+
}
112+
ddl = append(ddl, ddlText)
106113
}
107114

108115
var codegen plugin.GenerateRequest

internal/cmd/vet.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,14 @@ func (c *checker) fetchDatabaseUri(ctx context.Context, s config.SQL) (string, f
439439
if err != nil {
440440
return "", cleanup, fmt.Errorf("read file: %w", err)
441441
}
442-
ddl = append(ddl, migrations.RemoveRollbackStatements(string(contents)))
442+
ddlText, warnings, err := migrations.PreprocessSchemaForApply(string(contents), string(s.Engine))
443+
if err != nil {
444+
return "", cleanup, err
445+
}
446+
for _, warning := range warnings {
447+
fmt.Fprintln(c.Stderr, warning)
448+
}
449+
ddl = append(ddl, ddlText)
443450
}
444451

445452
resp, err := c.Client.CreateDatabase(ctx, &dbmanager.CreateDatabaseRequest{
@@ -554,7 +561,13 @@ func (c *checker) checkSQL(ctx context.Context, s config.SQL) error {
554561
if err != nil {
555562
return fmt.Errorf("read schema file: %w", err)
556563
}
557-
ddl := migrations.RemoveRollbackStatements(string(contents))
564+
ddl, warnings, err := migrations.PreprocessSchemaForApply(string(contents), string(s.Engine))
565+
if err != nil {
566+
return err
567+
}
568+
for _, warning := range warnings {
569+
fmt.Fprintln(c.Stderr, warning)
570+
}
558571
if _, err := db.ExecContext(ctx, ddl); err != nil {
559572
return fmt.Errorf("apply schema %s: %w", schema, err)
560573
}

internal/compiler/compile.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,24 @@ func (c *Compiler) parseCatalog(schemas []string) error {
3838
merr.Add(filename, "", 0, err)
3939
continue
4040
}
41-
contents := migrations.RemoveRollbackStatements(string(blob))
42-
contents = migrations.RemovePsqlMetaCommands(contents)
41+
contents, warnings, err := migrations.PreprocessSchema(string(blob), string(c.conf.Engine))
42+
if err != nil {
43+
merr.Add(filename, string(blob), 0, err)
44+
continue
45+
}
46+
var applyContents string
47+
if c.usesManagedAnalyzer() {
48+
applyContents, _, err = migrations.PreprocessSchemaForApply(string(blob), string(c.conf.Engine))
49+
if err != nil {
50+
merr.Add(filename, string(blob), 0, err)
51+
continue
52+
}
53+
}
54+
c.warns = append(c.warns, warnings...)
4355
c.schema = append(c.schema, contents)
56+
if c.usesManagedAnalyzer() {
57+
c.applySchema = append(c.applySchema, applyContents)
58+
}
4459

4560
// In database-only mode, we parse the schema to validate syntax
4661
// but don't update the catalog - the database will be the source of truth
@@ -73,7 +88,7 @@ func (c *Compiler) parseQueries(o opts.Parser) (*Result, error) {
7388

7489
// In database-only mode, initialize the database connection before parsing queries
7590
if c.databaseOnlyMode && c.analyzer != nil {
76-
if err := c.analyzer.EnsureConn(ctx, c.schema); err != nil {
91+
if err := c.analyzer.EnsureConn(ctx, c.analyzerMigrations()); err != nil {
7792
return nil, fmt.Errorf("failed to initialize database connection: %w", err)
7893
}
7994
}

internal/compiler/compile_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package compiler
2+
3+
import (
4+
"os"
5+
"path/filepath"
6+
"strings"
7+
"testing"
8+
9+
"github.com/sqlc-dev/sqlc/internal/config"
10+
"github.com/sqlc-dev/sqlc/internal/multierr"
11+
"github.com/sqlc-dev/sqlc/internal/opts"
12+
)
13+
14+
func TestParseCatalogManagedAnalyzerRejectsSemanticPsqlCommandsForApply(t *testing.T) {
15+
dir := t.TempDir()
16+
schema := filepath.Join(dir, "schema.sql")
17+
if err := os.WriteFile(schema, []byte("\\include extra.sql\nCREATE TABLE foo (id int);\n"), 0600); err != nil {
18+
t.Fatal(err)
19+
}
20+
21+
c, err := NewCompiler(config.SQL{
22+
Engine: config.EnginePostgreSQL,
23+
Schema: []string{schema},
24+
Database: &config.Database{
25+
Managed: true,
26+
},
27+
}, config.CombinedSettings{}, opts.Parser{})
28+
if err != nil {
29+
t.Fatal(err)
30+
}
31+
32+
err = c.ParseCatalog([]string{schema})
33+
if err == nil {
34+
t.Fatal("expected managed analyzer schema preprocessing to reject semantic psql command")
35+
}
36+
merr, ok := err.(*multierr.Error)
37+
if !ok || len(merr.Errs()) != 1 {
38+
t.Fatalf("expected one schema error, got %T: %v", err, err)
39+
}
40+
if !strings.Contains(merr.Errs()[0].Err.Error(), `psql meta-command \include is not supported`) {
41+
t.Fatalf("unexpected error: %v", err)
42+
}
43+
}

internal/compiler/engine.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ type Compiler struct {
2727
client dbmanager.Client
2828
selector selector
2929

30-
schema []string
30+
schema []string
31+
applySchema []string
32+
warns []string
3133

3234
// databaseOnlyMode indicates that the compiler should use database-only analysis
3335
// and skip building the internal catalog from schema files (analyzer.database: only)
@@ -125,6 +127,17 @@ func (c *Compiler) ParseCatalog(schema []string) error {
125127
return c.parseCatalog(schema)
126128
}
127129

130+
func (c *Compiler) usesManagedAnalyzer() bool {
131+
return c.analyzer != nil && c.conf.Database != nil && c.conf.Database.Managed
132+
}
133+
134+
func (c *Compiler) analyzerMigrations() []string {
135+
if c.usesManagedAnalyzer() {
136+
return c.applySchema
137+
}
138+
return c.schema
139+
}
140+
128141
func (c *Compiler) ParseQueries(queries []string, o opts.Parser) error {
129142
r, err := c.parseQueries(o)
130143
if err != nil {
@@ -138,6 +151,12 @@ func (c *Compiler) Result() *Result {
138151
return c.result
139152
}
140153

154+
// Warnings returns a copy of any non-fatal schema preprocessing warnings
155+
// collected while parsing the catalog.
156+
func (c *Compiler) Warnings() []string {
157+
return append([]string(nil), c.warns...)
158+
}
159+
141160
func (c *Compiler) Close(ctx context.Context) {
142161
if c.analyzer != nil {
143162
c.analyzer.Close(ctx)

internal/compiler/parse.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
9393
expandedRaw := expandedStmts[0].Raw
9494

9595
// Use the analyzer to get type information from the database
96-
result, err := c.analyzer.Analyze(ctx, expandedRaw, expandedQuery, c.schema, nil)
96+
result, err := c.analyzer.Analyze(ctx, expandedRaw, expandedQuery, c.analyzerMigrations(), nil)
9797
if err != nil {
9898
return nil, err
9999
}
@@ -132,7 +132,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
132132
inference.Query = rawSQL
133133
}
134134

135-
result, err := c.analyzer.Analyze(ctx, raw, inference.Query, c.schema, inference.Named)
135+
result, err := c.analyzer.Analyze(ctx, raw, inference.Query, c.analyzerMigrations(), inference.Named)
136136
if err != nil {
137137
return nil, err
138138
}

0 commit comments

Comments
 (0)