Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pkg/formatter/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ func FormatStatement(s ast.Statement, opts ast.FormatOptions) string {
return renderShow(v, opts)
case *ast.DescribeStatement:
return renderDescribe(v, opts)
case *ast.UnsupportedStatement:
return renderUnsupported(v, opts)
default:
// Fallback to SQL() for unrecognized statement types
return stmtSQL(s)
Expand Down
10 changes: 10 additions & 0 deletions pkg/formatter/render_ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,16 @@ func renderDescribe(s *ast.DescribeStatement, opts ast.FormatOptions) string {
return sb.String()
}

// renderUnsupported renders an UnsupportedStatement as a SQL comment
// preserving the original SQL fragment. This prevents silently producing
// corrupt SQL for statement types the formatter cannot handle.
func renderUnsupported(s *ast.UnsupportedStatement, _ ast.FormatOptions) string {
if s.RawSQL != "" {
return "-- UNSUPPORTED: " + s.RawSQL
}
return "-- UNSUPPORTED: " + s.Kind
}

// writeSequenceOptionsFormatted appends formatted sequence options to the builder.
// It mirrors the logic in ast/sql.go writeSequenceOptions but uses the nodeFormatter
// for keyword casing.
Expand Down
66 changes: 5 additions & 61 deletions pkg/gosqlx/extract.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,68 +43,12 @@
//
// For large ASTs (1000+ nodes), expect extraction times <100μs on modern hardware.
//
// # Parser Limitations
// # Supported Expression Types
//
// The extraction functions in this package are subject to the following parser limitations.
// These limitations represent SQL features that are partially supported or not yet fully
// implemented in the GoSQLX parser. As the parser evolves, these limitations may be
// addressed in future releases.
//
// ## Known Limitations
//
// 1. CASE Expressions:
// CASE expressions (simple and searched CASE) are not fully supported in the parser.
// Column references within CASE WHEN conditions and result expressions may not be
// extracted correctly.
//
// Example (not fully supported):
// SELECT CASE status WHEN 'active' THEN name ELSE 'N/A' END FROM users
//
// 2. CAST Expressions:
// CAST expressions for type conversion are not fully supported. Column references
// within CAST expressions may not be extracted.
//
// Example (not fully supported):
// SELECT CAST(price AS DECIMAL(10,2)) FROM products
//
// 3. IN Expressions:
// IN expressions with subqueries or complex value lists in WHERE clauses are not
// fully supported. Column references in IN lists may not be extracted correctly.
//
// Example (not fully supported):
// SELECT * FROM users WHERE status IN ('active', 'pending')
// SELECT * FROM orders WHERE user_id IN (SELECT id FROM users)
//
// 4. BETWEEN Expressions:
// BETWEEN expressions for range comparisons are not fully supported. Column references
// in BETWEEN bounds may not be extracted correctly.
//
// Example (not fully supported):
// SELECT * FROM products WHERE price BETWEEN min_price AND max_price
//
// 5. Complex Recursive CTEs:
// Recursive Common Table Expressions (CTEs) with complex JOIN syntax are not fully
// supported. Simple recursive CTEs work, but complex variations may fail to parse.
//
// Example (not fully supported):
// WITH RECURSIVE org_chart AS (
// SELECT id, name, manager_id, 1 as level FROM employees WHERE manager_id IS NULL
// UNION ALL
// SELECT e.id, e.name, e.manager_id, o.level + 1
// FROM employees e
// INNER JOIN org_chart o ON e.manager_id = o.id
// )
// SELECT * FROM org_chart
//
// ## Workarounds
//
// For queries using these unsupported features:
// - Simplify complex expressions where possible
// - Use alternative SQL syntax that is supported
// - Extract metadata manually from the original SQL string
// - Consider contributing parser enhancements to the GoSQLX project
//
// ## Reporting Issues
// The extraction functions correctly traverse all standard expression types
// including: CASE (simple and searched), CAST, IN, BETWEEN, EXTRACT,
// SUBSTRING, POSITION, subqueries, recursive CTEs with JOINs, window
// functions, and all arithmetic/logical operators.
//
// If you encounter parsing issues with SQL queries that should be supported,
// please report them at: https://github.com/ajitpratap0/GoSQLX/issues
Expand Down
100 changes: 85 additions & 15 deletions pkg/gosqlx/extract_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,27 @@ func TestExtractTables_WithCTE(t *testing.T) {
}

func TestExtractTables_WithRecursiveCTE(t *testing.T) {
// Skipping - Recursive CTE with complex JOIN syntax not fully supported yet
t.Skip("Recursive CTE with complex syntax not fully supported")
sql := `WITH RECURSIVE org_chart AS (
SELECT id, name, manager_id, 1 as level FROM employees WHERE manager_id IS NULL
UNION ALL
SELECT e.id, e.name, e.manager_id, o.level + 1
FROM employees e
INNER JOIN org_chart o ON e.manager_id = o.id
)
SELECT * FROM org_chart`

astNode, err := Parse(sql)
if err != nil {
t.Fatalf("Failed to parse recursive CTE: %v", err)
}

tables := ExtractTables(astNode)
if !contains(tables, "employees") {
t.Errorf("Expected to find 'employees' table, got: %v", tables)
}
if !contains(tables, "org_chart") {
t.Errorf("Expected to find 'org_chart' CTE reference, got: %v", tables)
}
}

func TestExtractTables_Insert(t *testing.T) {
Expand Down Expand Up @@ -647,23 +666,58 @@ func TestExtractMetadata_EmptyQuery(t *testing.T) {
}

func TestExtractColumns_WithCaseExpression(t *testing.T) {
// Skipping - CASE expressions not fully supported in parser yet
t.Skip("CASE expressions not fully supported in parser yet")
sql := "SELECT CASE status WHEN 'active' THEN name ELSE 'N/A' END FROM users"
astNode, err := Parse(sql)
if err != nil {
t.Fatalf("Failed to parse SQL: %v", err)
}

cols := ExtractColumns(astNode)
if !contains(cols, "status") {
t.Errorf("Expected 'status' in columns, got: %v", cols)
}
if !contains(cols, "name") {
t.Errorf("Expected 'name' in columns, got: %v", cols)
}
}

func TestExtractColumns_WithInExpression(t *testing.T) {
// Skipping - IN expressions in WHERE clause not fully supported yet
t.Skip("IN expressions in WHERE clause not fully supported yet")
sql := "SELECT * FROM users WHERE status IN ('active', 'pending')"
astNode, err := Parse(sql)
if err != nil {
t.Fatalf("Failed to parse SQL: %v", err)
}

cols := ExtractColumns(astNode)
if !contains(cols, "status") {
t.Errorf("Expected 'status' in columns, got: %v", cols)
}
}

func TestExtractColumns_WithBetweenExpression(t *testing.T) {
// Skipping - BETWEEN expressions not fully supported yet
t.Skip("BETWEEN expressions not fully supported yet")
sql := "SELECT * FROM products WHERE price BETWEEN 10 AND 100"
astNode, err := Parse(sql)
if err != nil {
t.Fatalf("Failed to parse SQL: %v", err)
}

cols := ExtractColumns(astNode)
if !contains(cols, "price") {
t.Errorf("Expected 'price' in columns, got: %v", cols)
}
}

func TestExtractFunctions_InCaseExpression(t *testing.T) {
// Skipping - CASE expressions not fully supported yet
t.Skip("CASE expressions not fully supported yet")
sql := "SELECT CASE WHEN COUNT(*) > 0 THEN 'yes' ELSE 'no' END FROM users"
astNode, err := Parse(sql)
if err != nil {
t.Fatalf("Failed to parse SQL: %v", err)
}

funcs := ExtractFunctions(astNode)
if !contains(funcs, "COUNT") {
t.Errorf("Expected 'COUNT' in functions, got: %v", funcs)
}
}

func TestExtractTables_WithSetOperations(t *testing.T) {
Expand Down Expand Up @@ -854,11 +908,27 @@ func TestExtractFunctions_NoFunctions(t *testing.T) {
}

func TestExtractColumns_WithCastExpression(t *testing.T) {
// Skipping - CAST expressions not fully supported yet
t.Skip("CAST expressions not fully supported yet")
sql := "SELECT CAST(price AS DECIMAL) FROM products"
astNode, err := Parse(sql)
if err != nil {
t.Fatalf("Failed to parse SQL: %v", err)
}

cols := ExtractColumns(astNode)
if !contains(cols, "price") {
t.Errorf("Expected 'price' in columns, got: %v", cols)
}
}

func TestExtractFunctions_ExtractExpression(t *testing.T) {
// Skipping - EXTRACT expressions not fully supported yet
t.Skip("EXTRACT expressions not fully supported yet")
func TestExtractColumns_WithExtractExpression(t *testing.T) {
sql := "SELECT EXTRACT(YEAR FROM created_at) FROM orders"
astNode, err := Parse(sql)
if err != nil {
t.Fatalf("Failed to parse SQL: %v", err)
}

cols := ExtractColumns(astNode)
if !contains(cols, "created_at") {
t.Errorf("Expected 'created_at' in columns, got: %v", cols)
}
}
23 changes: 20 additions & 3 deletions pkg/gosqlx/gosqlx.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,15 +256,32 @@ func Validate(sql string) error {

// ParseBytes is like Parse but accepts a byte slice.
//
// This is useful when you already have SQL as bytes (e.g., from file I/O)
// and want to avoid the string → []byte conversion overhead.
// This avoids the string-to-byte conversion that Parse performs internally,
// making it more efficient when you already have SQL as bytes (e.g., from
// file I/O or network reads).
//
// Example:
//
// sqlBytes := []byte("SELECT * FROM users")
// astNode, err := gosqlx.ParseBytes(sqlBytes)
func ParseBytes(sql []byte) (*ast.AST, error) {
return Parse(string(sql))
tkz := tokenizer.GetTokenizer()
defer tokenizer.PutTokenizer(tkz)

tokens, err := tkz.Tokenize(sql)
if err != nil {
return nil, fmt.Errorf("tokenization failed: %w", err)
}

p := parser.GetParser()
defer parser.PutParser(p)

astNode, err := p.ParseFromModelTokens(tokens)
if err != nil {
return nil, fmt.Errorf("parsing failed: %w", err)
}

return astNode, nil
}

// MustParse is like Parse but panics on error.
Expand Down
29 changes: 29 additions & 0 deletions pkg/sql/ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -1892,6 +1892,18 @@ func (a AST) Children() []Node {
return children
}

// HasUnsupportedStatements returns true if the AST contains any
// UnsupportedStatement nodes — statements the parser consumed but
// could not fully model.
func (a AST) HasUnsupportedStatements() bool {
for _, stmt := range a.Statements {
if _, ok := stmt.(*UnsupportedStatement); ok {
return true
}
}
return false
}

// PragmaStatement represents a SQLite PRAGMA statement.
// Examples: PRAGMA table_info(users), PRAGMA journal_mode = WAL, PRAGMA integrity_check
type PragmaStatement struct {
Expand Down Expand Up @@ -1924,6 +1936,23 @@ func (d *DescribeStatement) statementNode() {}
func (d DescribeStatement) TokenLiteral() string { return "DESCRIBE" }
func (d DescribeStatement) Children() []Node { return nil }

// UnsupportedStatement represents a SQL statement that was parsed but not
// fully modeled in the AST. The parser consumed and validated the tokens
// but no dedicated AST node exists yet for this statement kind.
//
// Consumers should use Kind to identify the operation (e.g., "USE", "COPY",
// "CREATE STAGE") and RawSQL for the original text. Tools that do
// switch stmt.(type) should handle this case explicitly rather than
// falling through to a default that assumes the statement is well-structured.
type UnsupportedStatement struct {
Kind string // Operation kind: "USE", "COPY", "PUT", "GET", "LIST", "REMOVE", "CREATE STAGE", etc.
RawSQL string // Original SQL fragment for round-trip fidelity
}

func (u *UnsupportedStatement) statementNode() {}
func (u UnsupportedStatement) TokenLiteral() string { return u.Kind }
func (u UnsupportedStatement) Children() []Node { return nil }

// ReplaceStatement represents MySQL REPLACE INTO statement
type ReplaceStatement struct {
TableName string
Expand Down
25 changes: 25 additions & 0 deletions pkg/sql/ast/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ var (
},
}

unsupportedStmtPool = sync.Pool{
New: func() interface{} {
return &UnsupportedStatement{}
},
}

replaceStmtPool = sync.Pool{
New: func() interface{} {
return &ReplaceStatement{
Expand Down Expand Up @@ -535,6 +541,8 @@ func releaseStatement(stmt Statement) {
PutShowStatement(s)
case *DescribeStatement:
PutDescribeStatement(s)
case *UnsupportedStatement:
PutUnsupportedStatement(s)
case *ReplaceStatement:
PutReplaceStatement(s)
case *AlterStatement:
Expand Down Expand Up @@ -1756,6 +1764,23 @@ func PutDescribeStatement(stmt *DescribeStatement) {
describeStmtPool.Put(stmt)
}

// GetUnsupportedStatement gets an UnsupportedStatement from the pool.
func GetUnsupportedStatement() *UnsupportedStatement {
return unsupportedStmtPool.Get().(*UnsupportedStatement)
}

// PutUnsupportedStatement returns an UnsupportedStatement to the pool.
func PutUnsupportedStatement(stmt *UnsupportedStatement) {
if stmt == nil {
return
}

stmt.Kind = ""
stmt.RawSQL = ""

unsupportedStmtPool.Put(stmt)
}

// GetReplaceStatement gets a ReplaceStatement from the pool.
func GetReplaceStatement() *ReplaceStatement {
stmt := replaceStmtPool.Get().(*ReplaceStatement)
Expand Down
Loading
Loading