Skip to content

Commit 8975bab

Browse files
committed
feat: improve schema.sql output
1 parent 3f43a66 commit 8975bab

1 file changed

Lines changed: 161 additions & 43 deletions

File tree

cmd/substreams-sink-sql/from_proto_generate_csv.go

Lines changed: 161 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -311,67 +311,185 @@ func fromProtoGenerateCsvE(cmd *cobra.Command, args []string) error {
311311

312312
// exportSQLSchema writes the SQL DDL for sqlSchema to outputPath as a single
// annotated script.
//
// In the post-change version (the + lines) the script is assembled in a
// strings.Builder in this order: a comment header (driver, schema name, UTC
// generation time, schema hash), the static system-table SQL returned by
// getStaticSQL, the CREATE TABLE statements of the dialect's BaseDialect
// sorted by table name and passed through prettyFormatCreateTable, the
// primary-key / unique / foreign-key statements when useConstraints is true,
// and, for the "postgres" and "risingwave" drivers, an INSERT that seeds
// "_sink_info_" with the schema hash.
//
// It returns an error when dialect is not one of the Postgres, Risingwave or
// ClickHouse dialect types, or when os.WriteFile fails.
313313
func exportSQLSchema(dialect sql.Dialect, sqlSchema *schema.Schema, useConstraints bool, driver, outputPath string) error {
314-
var statements []string
314+
// Extract BaseDialect from the dialect
315+
var baseDialect *sql.BaseDialect
316+
switch d := dialect.(type) {
317+
case *postgres.DialectPostgres:
318+
baseDialect = d.BaseDialect
319+
case *risingwave.DialectRisingwave:
320+
baseDialect = d.BaseDialect
321+
case *clickhouse.DialectClickHouse:
322+
baseDialect = d.BaseDialect
323+
default:
324+
return fmt.Errorf("unsupported dialect type %T", dialect)
325+
}
315326

316-
// Add static SQL (system tables)
317-
staticSQL := getStaticSQL(driver, sqlSchema.Name)
318-
if staticSQL != "" {
319-
statements = append(statements, staticSQL)
320-
}
327+
var b strings.Builder
321328

322-
// Extract BaseDialect from the dialect
323-
var baseDialect *sql.BaseDialect
324-
switch d := dialect.(type) {
325-
case *postgres.DialectPostgres:
326-
baseDialect = d.BaseDialect
327-
case *risingwave.DialectRisingwave:
328-
baseDialect = d.BaseDialect
329-
case *clickhouse.DialectClickHouse:
330-
baseDialect = d.BaseDialect
331-
default:
332-
return fmt.Errorf("unsupported dialect type %T", dialect)
333-
}
329+
// Header
330+
b.WriteString("-- Generated by substreams-sink-sql from-proto-generate-csv\n")
331+
b.WriteString(fmt.Sprintf("-- Dialect: %s\n", driver))
332+
b.WriteString(fmt.Sprintf("-- Schema: %s\n", sqlSchema.Name))
333+
// NOTE(review): the wall-clock timestamp makes the file content differ between
// runs even when the schema is unchanged.
b.WriteString(fmt.Sprintf("-- Generated at: %s\n", time.Now().UTC().Format(time.RFC3339)))
334+
b.WriteString(fmt.Sprintf("-- Schema hash: %s\n\n", dialect.SchemaHash()))
334335

335-
// Sort table names for consistent ordering
336-
var tableNames []string
337-
for name := range baseDialect.CreateTableSql {
338-
tableNames = append(tableNames, name)
339-
}
340-
sort.Strings(tableNames)
336+
// System tables
337+
// The section is skipped entirely when getStaticSQL yields only whitespace.
if staticSQL := strings.TrimSpace(getStaticSQL(driver, sqlSchema.Name)); staticSQL != "" {
338+
b.WriteString("-- System Tables\n")
339+
b.WriteString(staticSQL)
340+
b.WriteString("\n\n")
341+
}
341342

342-
for _, tableName := range tableNames {
343-
createSQL := baseDialect.CreateTableSql[tableName]
344-
statements = append(statements, createSQL)
343+
// User tables
344+
// Sort table names for consistent ordering
345+
var tableNames []string
346+
// CreateTableSql is keyed by table name; ranging over it yields keys in no
// fixed order, hence the explicit sort below.
for name := range baseDialect.CreateTableSql {
347+
tableNames = append(tableNames, name)
348+
}
349+
sort.Strings(tableNames)
350+
351+
if len(tableNames) > 0 {
352+
b.WriteString("-- User Tables\n")
353+
for i, tableName := range tableNames {
354+
createSQL := strings.TrimSpace(baseDialect.CreateTableSql[tableName])
355+
b.WriteString(prettyFormatCreateTable(createSQL))
356+
// NOTE(review): both branches below write the same "\n\n", so the if/else
// (and the index i) could be replaced by a single WriteString.
if i < len(tableNames)-1 {
357+
b.WriteString("\n\n")
358+
} else {
359+
b.WriteString("\n\n")
360+
}
361+
}
345362
}
346363

347-
// Add constraints if enabled
364+
// Constraints (optional)
348365
if useConstraints {
349-
// Primary keys
350-
for _, constraint := range baseDialect.PrimaryKeySql {
351-
statements = append(statements, constraint.Sql)
366+
// The section title is written only when at least one constraint statement exists.
if len(baseDialect.PrimaryKeySql) > 0 || len(baseDialect.UniqueConstraintSql) > 0 || len(baseDialect.ForeignKeySql) > 0 {
367+
b.WriteString("-- Constraints\n")
352368
}
353-
354-
// Unique constraints
355-
for _, constraint := range baseDialect.UniqueConstraintSql {
356-
statements = append(statements, constraint.Sql)
369+
// Emission order: primary keys, then unique constraints, then foreign keys.
for _, c := range baseDialect.PrimaryKeySql {
370+
b.WriteString(strings.TrimSpace(c.Sql))
371+
b.WriteString("\n")
357372
}
358-
359-
// Foreign keys
360-
for _, constraint := range baseDialect.ForeignKeySql {
361-
statements = append(statements, constraint.Sql)
373+
for _, c := range baseDialect.UniqueConstraintSql {
374+
b.WriteString(strings.TrimSpace(c.Sql))
375+
b.WriteString("\n")
376+
}
377+
for _, c := range baseDialect.ForeignKeySql {
378+
b.WriteString(strings.TrimSpace(c.Sql))
379+
b.WriteString("\n")
380+
}
381+
// Same condition as the title check above: add a blank separator line only
// when the section is non-empty.
if len(baseDialect.PrimaryKeySql) > 0 || len(baseDialect.UniqueConstraintSql) > 0 || len(baseDialect.ForeignKeySql) > 0 {
382+
b.WriteString("\n")
362383
}
363384
}
364385

365386
// Seed sink info so from-proto can start without error
366387
switch driver {
367388
case "postgres", "risingwave":
368389
// NOTE(review): the schema name and hash are interpolated directly into the
// SQL text; presumably both are generated internally — confirm.
seed := fmt.Sprintf("INSERT INTO \"%s\".\"_sink_info_\" (schema_hash) VALUES ('%s') ON CONFLICT (schema_hash) DO NOTHING;", sqlSchema.Name, dialect.SchemaHash())
369-
statements = append(statements, seed)
390+
b.WriteString("-- Seed\n")
391+
b.WriteString(seed)
392+
b.WriteString("\n")
370393
}
371394

372-
// Write to file
373-
content := strings.Join(statements, "\n\n") + "\n"
374-
return os.WriteFile(outputPath, []byte(content), 0644)
395+
// Write to file
396+
// Mode 0644: read/write for the owner, read-only for group and others.
return os.WriteFile(outputPath, []byte(b.String()), 0644)
397+
}
398+
399+
// prettyFormatCreateTable reformats a single CREATE TABLE statement so that
// every top-level entry of its column list (columns, indexes, table
// constraints) sits on its own indented line.
//
// The column list is the first parenthesised group that starts outside any
// quoted span. Entries are split on commas directly inside that group only,
// so commas in type parameters such as NUMERIC(78,0) are left alone. Text
// inside double-quoted identifiers, single-quoted string literals and
// backtick-quoted identifiers is skipped entirely: parentheses and commas
// there never affect nesting or splitting. A doubled quote character inside
// a quoted span (the SQL escape for the quote itself) closes and immediately
// reopens the span, which keeps the scan in step; backslash escapes are not
// interpreted.
//
// Everything before the opening parenthesis and everything from the matching
// closing parenthesis onwards (e.g. "ENGINE = ... ORDER BY (...)" or ";") is
// copied through unchanged. Each entry is trimmed of surrounding whitespace
// and empty entries are dropped.
//
// stmt is returned unchanged when it contains no unquoted "(" or when the
// column list is never closed.
func prettyFormatCreateTable(stmt string) string {
	var (
		parts    []string
		quote    byte // active quote character; 0 while outside any quoted span
		depth    int  // parenthesis nesting; 1 means directly inside the column list
		openIdx  = -1 // index of the "(" that opens the column list
		closeIdx = -1 // index of the matching ")"
		start    int  // index where the current entry begins
	)

scan:
	for i := 0; i < len(stmt); i++ {
		ch := stmt[i]

		// Inside a quoted span only the matching quote character matters.
		if quote != 0 {
			if ch == quote {
				quote = 0
			}
			continue
		}

		switch ch {
		case '"', '\'', '`':
			quote = ch
		case '(':
			if depth == 0 {
				openIdx = i
				start = i + 1
			}
			depth++
		case ')':
			if depth == 0 {
				// Stray ")" ahead of the column list: not ours to match.
				continue
			}
			depth--
			if depth == 0 {
				closeIdx = i
				break scan
			}
		case ',':
			if depth == 1 {
				if part := strings.TrimSpace(stmt[start:i]); part != "" {
					parts = append(parts, part)
				}
				start = i + 1
			}
		}
	}

	if openIdx == -1 || closeIdx == -1 {
		return stmt
	}
	if last := strings.TrimSpace(stmt[start:closeIdx]); last != "" {
		parts = append(parts, last)
	}

	// Rebuild: prefix up to and including "(", one entry per line, then the
	// untouched remainder starting at the matching ")".
	var b strings.Builder
	b.Grow(len(stmt) + 3*len(parts) + 1)
	b.WriteString(stmt[:openIdx+1])
	if len(parts) > 0 {
		b.WriteString("\n")
		for i, p := range parts {
			b.WriteString(" ")
			b.WriteString(p)
			if i < len(parts)-1 {
				b.WriteString(",")
			}
			b.WriteString("\n")
		}
	}
	b.WriteString(stmt[closeIdx:])
	return b.String()
}
376494

377495
// getStaticSQL returns the static SQL for system tables

0 commit comments

Comments
 (0)