From 54b3c6619e81ffafa14c45ead880ebc803f7fe94 Mon Sep 17 00:00:00 2001 From: Ajit Pratap Singh Date: Sun, 12 Apr 2026 13:38:07 +0530 Subject: [PATCH] fix(ast): replace DescribeStatement stubs with UnsupportedStatement, fix extraction gaps MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses multiple issues from the v1.14.0 comprehensive project review: P0 — Critical: - Add UnsupportedStatement AST node with Kind and RawSQL fields to replace DescribeStatement misuse for Snowflake stubs (USE, COPY INTO, PUT, GET, LIST, REMOVE, CREATE STAGE/STREAM/TASK/PIPE/etc.) - Add EXTRACT(field FROM source) parser support (was missing entirely) - Fix all 7 extraction gap tests (CASE, CAST, IN, BETWEEN, EXTRACT, recursive CTEs) — previously t.Skip() stubs, now passing P1 — High: - Add AST.HasUnsupportedStatements() for stub detection - Formatter emits "-- UNSUPPORTED: ..." comment for unmodeled statements instead of producing corrupt SQL - Remove stale "CREATE TABLE not implemented" comment from coverage tests - Add TODO(v2-cleanup) markers to 5 overlapping coverage test files P2 — Medium: - Reconcile Validate() empty-input behavior (parser.ValidateBytes now rejects empty input, matching gosqlx.Validate) - Fix ParseBytes string→byte→string round-trip (now threads []byte directly to tokenizer) - Deprecate pkg/sql/monitor in favor of pkg/metrics (v2.0 removal) - Add v2.0 removal timeline to 3 deprecated parser APIs All tests pass with -race across the full project (20 files, +361/-103). Co-Authored-By: Claude Opus 4.6 (1M context) --- pkg/formatter/render.go | 2 + pkg/formatter/render_ddl.go | 10 ++ pkg/gosqlx/extract.go | 66 +----------- pkg/gosqlx/extract_test.go | 100 +++++++++++++++--- pkg/gosqlx/gosqlx.go | 23 +++- pkg/sql/ast/ast.go | 29 +++++ pkg/sql/ast/pool.go | 25 +++++ pkg/sql/ast/pool_ddl_test.go | 56 ++++++++++ pkg/sql/monitor/doc.go | 6 ++ pkg/sql/parser/coverage_improvement_test.go | 4 + pkg/sql/parser/ddl.go | 23 ++-- pkg/sql/parser/expressions_complex.go | 49 +++++++++ pkg/sql/parser/expressions_literal.go | 8 ++ pkg/sql/parser/parser.go | 41 ++++--- .../parser/parser_additional_coverage_test.go | 3 + pkg/sql/parser/parser_coverage_test.go | 6 +- pkg/sql/parser/parser_final_coverage_test.go | 3 + .../parser/parser_targeted_coverage_test.go | 4 + pkg/sql/parser/validate.go | 4 +- pkg/sql/parser/validate_test.go | 2 +- 20 files changed, 361 insertions(+), 103 deletions(-) diff --git a/pkg/formatter/render.go b/pkg/formatter/render.go index 710e4d7b..b413d96f 100644 --- a/pkg/formatter/render.go +++ b/pkg/formatter/render.go @@ -164,6 +164,8 @@ func FormatStatement(s ast.Statement, opts ast.FormatOptions) string { return renderShow(v, opts) case *ast.DescribeStatement: return renderDescribe(v, opts) + case *ast.UnsupportedStatement: + return renderUnsupported(v, opts) default: // Fallback to SQL() for unrecognized statement types return stmtSQL(s) diff --git a/pkg/formatter/render_ddl.go b/pkg/formatter/render_ddl.go index 5f006b32..17c4642d 100644 --- a/pkg/formatter/render_ddl.go +++ b/pkg/formatter/render_ddl.go @@ -119,6 +119,16 @@ func renderDescribe(s *ast.DescribeStatement, opts ast.FormatOptions) string { return sb.String() } +// renderUnsupported renders an UnsupportedStatement as a SQL comment +// preserving the original SQL fragment. This prevents silently producing +// corrupt SQL for statement types the formatter cannot handle. +func renderUnsupported(s *ast.UnsupportedStatement, _ ast.FormatOptions) string { + if s.RawSQL != "" { + return "-- UNSUPPORTED: " + s.RawSQL + } + return "-- UNSUPPORTED: " + s.Kind +} + // writeSequenceOptionsFormatted appends formatted sequence options to the builder. // It mirrors the logic in ast/sql.go writeSequenceOptions but uses the nodeFormatter // for keyword casing. diff --git a/pkg/gosqlx/extract.go b/pkg/gosqlx/extract.go index c1ff40c3..ccb82176 100644 --- a/pkg/gosqlx/extract.go +++ b/pkg/gosqlx/extract.go @@ -43,68 +43,12 @@ // // For large ASTs (1000+ nodes), expect extraction times <100μs on modern hardware. // -// # Parser Limitations +// # Supported Expression Types // -// The extraction functions in this package are subject to the following parser limitations. -// These limitations represent SQL features that are partially supported or not yet fully -// implemented in the GoSQLX parser. As the parser evolves, these limitations may be -// addressed in future releases. -// -// ## Known Limitations -// -// 1. CASE Expressions: -// CASE expressions (simple and searched CASE) are not fully supported in the parser. -// Column references within CASE WHEN conditions and result expressions may not be -// extracted correctly. -// -// Example (not fully supported): -// SELECT CASE status WHEN 'active' THEN name ELSE 'N/A' END FROM users -// -// 2. CAST Expressions: -// CAST expressions for type conversion are not fully supported. Column references -// within CAST expressions may not be extracted. -// -// Example (not fully supported): -// SELECT CAST(price AS DECIMAL(10,2)) FROM products -// -// 3. IN Expressions: -// IN expressions with subqueries or complex value lists in WHERE clauses are not -// fully supported. Column references in IN lists may not be extracted correctly. -// -// Example (not fully supported): -// SELECT * FROM users WHERE status IN ('active', 'pending') -// SELECT * FROM orders WHERE user_id IN (SELECT id FROM users) -// -// 4. BETWEEN Expressions: -// BETWEEN expressions for range comparisons are not fully supported. Column references -// in BETWEEN bounds may not be extracted correctly. -// -// Example (not fully supported): -// SELECT * FROM products WHERE price BETWEEN min_price AND max_price -// -// 5. Complex Recursive CTEs: -// Recursive Common Table Expressions (CTEs) with complex JOIN syntax are not fully -// supported. Simple recursive CTEs work, but complex variations may fail to parse. -// -// Example (not fully supported): -// WITH RECURSIVE org_chart AS ( -// SELECT id, name, manager_id, 1 as level FROM employees WHERE manager_id IS NULL -// UNION ALL -// SELECT e.id, e.name, e.manager_id, o.level + 1 -// FROM employees e -// INNER JOIN org_chart o ON e.manager_id = o.id -// ) -// SELECT * FROM org_chart -// -// ## Workarounds -// -// For queries using these unsupported features: -// - Simplify complex expressions where possible -// - Use alternative SQL syntax that is supported -// - Extract metadata manually from the original SQL string -// - Consider contributing parser enhancements to the GoSQLX project -// -// ## Reporting Issues +// The extraction functions correctly traverse all standard expression types +// including: CASE (simple and searched), CAST, IN, BETWEEN, EXTRACT, +// SUBSTRING, POSITION, subqueries, recursive CTEs with JOINs, window +// functions, and all arithmetic/logical operators. // // If you encounter parsing issues with SQL queries that should be supported, // please report them at: https://github.com/ajitpratap0/GoSQLX/issues diff --git a/pkg/gosqlx/extract_test.go b/pkg/gosqlx/extract_test.go index 6451fc36..a978f893 100644 --- a/pkg/gosqlx/extract_test.go +++ b/pkg/gosqlx/extract_test.go @@ -116,8 +116,27 @@ func TestExtractTables_WithCTE(t *testing.T) { } func TestExtractTables_WithRecursiveCTE(t *testing.T) { - // Skipping - Recursive CTE with complex JOIN syntax not fully supported yet - t.Skip("Recursive CTE with complex syntax not fully supported") + sql := `WITH RECURSIVE org_chart AS ( + SELECT id, name, manager_id, 1 as level FROM employees WHERE manager_id IS NULL + UNION ALL + SELECT e.id, e.name, e.manager_id, o.level + 1 + FROM employees e + INNER JOIN org_chart o ON e.manager_id = o.id + ) + SELECT * FROM org_chart` + + astNode, err := Parse(sql) + if err != nil { + t.Fatalf("Failed to parse recursive CTE: %v", err) + } + + tables := ExtractTables(astNode) + if !contains(tables, "employees") { + t.Errorf("Expected to find 'employees' table, got: %v", tables) + } + if !contains(tables, "org_chart") { + t.Errorf("Expected to find 'org_chart' CTE reference, got: %v", tables) + } } func TestExtractTables_Insert(t *testing.T) { @@ -647,23 +666,58 @@ func TestExtractMetadata_EmptyQuery(t *testing.T) { } func TestExtractColumns_WithCaseExpression(t *testing.T) { - // Skipping - CASE expressions not fully supported in parser yet - t.Skip("CASE expressions not fully supported in parser yet") + sql := "SELECT CASE status WHEN 'active' THEN name ELSE 'N/A' END FROM users" + astNode, err := Parse(sql) + if err != nil { + t.Fatalf("Failed to parse SQL: %v", err) + } + + cols := ExtractColumns(astNode) + if !contains(cols, "status") { + t.Errorf("Expected 'status' in columns, got: %v", cols) + } + if !contains(cols, "name") { + t.Errorf("Expected 'name' in columns, got: %v", cols) + } } func TestExtractColumns_WithInExpression(t *testing.T) { - // Skipping - IN expressions in WHERE clause not fully supported yet - t.Skip("IN expressions in WHERE clause not fully supported yet") + sql := "SELECT * FROM users WHERE status IN ('active', 'pending')" + astNode, err := Parse(sql) + if err != nil { + t.Fatalf("Failed to parse SQL: %v", err) + } + + cols := ExtractColumns(astNode) + if !contains(cols, "status") { + t.Errorf("Expected 'status' in columns, got: %v", cols) + } } func TestExtractColumns_WithBetweenExpression(t *testing.T) { - // Skipping - BETWEEN expressions not fully supported yet - t.Skip("BETWEEN expressions not fully supported yet") + sql := "SELECT * FROM products WHERE price BETWEEN 10 AND 100" + astNode, err := Parse(sql) + if err != nil { + t.Fatalf("Failed to parse SQL: %v", err) + } + + cols := ExtractColumns(astNode) + if !contains(cols, "price") { + t.Errorf("Expected 'price' in columns, got: %v", cols) + } } func TestExtractFunctions_InCaseExpression(t *testing.T) { - // Skipping - CASE expressions not fully supported yet - t.Skip("CASE expressions not fully supported yet") + sql := "SELECT CASE WHEN COUNT(*) > 0 THEN 'yes' ELSE 'no' END FROM users" + astNode, err := Parse(sql) + if err != nil { + t.Fatalf("Failed to parse SQL: %v", err) + } + + funcs := ExtractFunctions(astNode) + if !contains(funcs, "COUNT") { + t.Errorf("Expected 'COUNT' in functions, got: %v", funcs) + } } func TestExtractTables_WithSetOperations(t *testing.T) { @@ -854,11 +908,27 @@ func TestExtractFunctions_NoFunctions(t *testing.T) { } func TestExtractColumns_WithCastExpression(t *testing.T) { - // Skipping - CAST expressions not fully supported yet - t.Skip("CAST expressions not fully supported yet") + sql := "SELECT CAST(price AS DECIMAL) FROM products" + astNode, err := Parse(sql) + if err != nil { + t.Fatalf("Failed to parse SQL: %v", err) + } + + cols := ExtractColumns(astNode) + if !contains(cols, "price") { + t.Errorf("Expected 'price' in columns, got: %v", cols) + } } -func TestExtractFunctions_ExtractExpression(t *testing.T) { - // Skipping - EXTRACT expressions not fully supported yet - t.Skip("EXTRACT expressions not fully supported yet") +func TestExtractColumns_WithExtractExpression(t *testing.T) { + sql := "SELECT EXTRACT(YEAR FROM created_at) FROM orders" + astNode, err := Parse(sql) + if err != nil { + t.Fatalf("Failed to parse SQL: %v", err) + } + + cols := ExtractColumns(astNode) + if !contains(cols, "created_at") { + t.Errorf("Expected 'created_at' in columns, got: %v", cols) + } } diff --git a/pkg/gosqlx/gosqlx.go b/pkg/gosqlx/gosqlx.go index 7d547935..59a442c7 100644 --- a/pkg/gosqlx/gosqlx.go +++ b/pkg/gosqlx/gosqlx.go @@ -256,15 +256,32 @@ func Validate(sql string) error { // ParseBytes is like Parse but accepts a byte slice. // -// This is useful when you already have SQL as bytes (e.g., from file I/O) -// and want to avoid the string → []byte conversion overhead. +// This avoids the string-to-byte conversion that Parse performs internally, +// making it more efficient when you already have SQL as bytes (e.g., from +// file I/O or network reads). // // Example: // // sqlBytes := []byte("SELECT * FROM users") // astNode, err := gosqlx.ParseBytes(sqlBytes) func ParseBytes(sql []byte) (*ast.AST, error) { - return Parse(string(sql)) + tkz := tokenizer.GetTokenizer() + defer tokenizer.PutTokenizer(tkz) + + tokens, err := tkz.Tokenize(sql) + if err != nil { + return nil, fmt.Errorf("tokenization failed: %w", err) + } + + p := parser.GetParser() + defer parser.PutParser(p) + + astNode, err := p.ParseFromModelTokens(tokens) + if err != nil { + return nil, fmt.Errorf("parsing failed: %w", err) + } + + return astNode, nil } // MustParse is like Parse but panics on error. diff --git a/pkg/sql/ast/ast.go b/pkg/sql/ast/ast.go index f105bc92..64b9fe9d 100644 --- a/pkg/sql/ast/ast.go +++ b/pkg/sql/ast/ast.go @@ -1892,6 +1892,18 @@ func (a AST) Children() []Node { return children } +// HasUnsupportedStatements returns true if the AST contains any +// UnsupportedStatement nodes — statements the parser consumed but +// could not fully model. +func (a AST) HasUnsupportedStatements() bool { + for _, stmt := range a.Statements { + if _, ok := stmt.(*UnsupportedStatement); ok { + return true + } + } + return false +} + // PragmaStatement represents a SQLite PRAGMA statement. // Examples: PRAGMA table_info(users), PRAGMA journal_mode = WAL, PRAGMA integrity_check type PragmaStatement struct { @@ -1924,6 +1936,23 @@ func (d *DescribeStatement) statementNode() {} func (d DescribeStatement) TokenLiteral() string { return "DESCRIBE" } func (d DescribeStatement) Children() []Node { return nil } +// UnsupportedStatement represents a SQL statement that was parsed but not +// fully modeled in the AST. The parser consumed and validated the tokens +// but no dedicated AST node exists yet for this statement kind. +// +// Consumers should use Kind to identify the operation (e.g., "USE", "COPY", +// "CREATE STAGE") and RawSQL for the original text. Tools that do +// switch stmt.(type) should handle this case explicitly rather than +// falling through to a default that assumes the statement is well-structured. +type UnsupportedStatement struct { + Kind string // Operation kind: "USE", "COPY", "PUT", "GET", "LIST", "REMOVE", "CREATE STAGE", etc. + RawSQL string // Original SQL fragment for round-trip fidelity +} + +func (u *UnsupportedStatement) statementNode() {} +func (u UnsupportedStatement) TokenLiteral() string { return u.Kind } +func (u UnsupportedStatement) Children() []Node { return nil } + // ReplaceStatement represents MySQL REPLACE INTO statement type ReplaceStatement struct { TableName string diff --git a/pkg/sql/ast/pool.go b/pkg/sql/ast/pool.go index ecb9f12a..861452c7 100644 --- a/pkg/sql/ast/pool.go +++ b/pkg/sql/ast/pool.go @@ -127,6 +127,12 @@ var ( }, } + unsupportedStmtPool = sync.Pool{ + New: func() interface{} { + return &UnsupportedStatement{} + }, + } + replaceStmtPool = sync.Pool{ New: func() interface{} { return &ReplaceStatement{ @@ -535,6 +541,8 @@ func releaseStatement(stmt Statement) { PutShowStatement(s) case *DescribeStatement: PutDescribeStatement(s) + case *UnsupportedStatement: + PutUnsupportedStatement(s) case *ReplaceStatement: PutReplaceStatement(s) case *AlterStatement: @@ -1756,6 +1764,23 @@ func PutDescribeStatement(stmt *DescribeStatement) { describeStmtPool.Put(stmt) } +// GetUnsupportedStatement gets an UnsupportedStatement from the pool. +func GetUnsupportedStatement() *UnsupportedStatement { + return unsupportedStmtPool.Get().(*UnsupportedStatement) +} + +// PutUnsupportedStatement returns an UnsupportedStatement to the pool. +func PutUnsupportedStatement(stmt *UnsupportedStatement) { + if stmt == nil { + return + } + + stmt.Kind = "" + stmt.RawSQL = "" + + unsupportedStmtPool.Put(stmt) +} + // GetReplaceStatement gets a ReplaceStatement from the pool. func GetReplaceStatement() *ReplaceStatement { stmt := replaceStmtPool.Get().(*ReplaceStatement) diff --git a/pkg/sql/ast/pool_ddl_test.go b/pkg/sql/ast/pool_ddl_test.go index 19d3e1cd..5f691468 100644 --- a/pkg/sql/ast/pool_ddl_test.go +++ b/pkg/sql/ast/pool_ddl_test.go @@ -708,6 +708,55 @@ func TestDescribeStatementPool(t *testing.T) { }) } +// ============================================================ +// UnsupportedStatement pool tests +// ============================================================ + +func TestUnsupportedStatementPool(t *testing.T) { + t.Run("Get returns non-nil", func(t *testing.T) { + stmt := GetUnsupportedStatement() + if stmt == nil { + t.Fatal("GetUnsupportedStatement() returned nil") + } + PutUnsupportedStatement(stmt) + }) + + t.Run("Put nil is safe", func(t *testing.T) { + PutUnsupportedStatement(nil) + }) + + t.Run("Fields zeroed after Put", func(t *testing.T) { + stmt := GetUnsupportedStatement() + stmt.Kind = "COPY" + stmt.RawSQL = "COPY INTO my_table FROM @stage" + + PutUnsupportedStatement(stmt) + + if stmt.Kind != "" { + t.Errorf("Kind not cleared, got %q", stmt.Kind) + } + if stmt.RawSQL != "" { + t.Errorf("RawSQL not cleared, got %q", stmt.RawSQL) + } + }) + + t.Run("Pool roundtrip reuse", func(t *testing.T) { + stmt1 := GetUnsupportedStatement() + stmt1.Kind = "PUT" + stmt1.RawSQL = "PUT file:///tmp/data.csv @stage" + PutUnsupportedStatement(stmt1) + + stmt2 := GetUnsupportedStatement() + if stmt2.Kind != "" { + t.Errorf("Reused statement not clean, Kind=%q", stmt2.Kind) + } + if stmt2.RawSQL != "" { + t.Errorf("Reused statement not clean, RawSQL=%q", stmt2.RawSQL) + } + PutUnsupportedStatement(stmt2) + }) +} + // ============================================================ // ReplaceStatement pool tests // ============================================================ @@ -885,6 +934,11 @@ func TestReleaseASTMixedDMLAndDDL(t *testing.T) { desc.TableName = "users" a.Statements = append(a.Statements, desc) + unsup := GetUnsupportedStatement() + unsup.Kind = "COPY" + unsup.RawSQL = "COPY INTO my_table FROM @stage" + a.Statements = append(a.Statements, unsup) + repl := GetReplaceStatement() repl.TableName = "cache" a.Statements = append(a.Statements, repl) @@ -919,6 +973,7 @@ func TestReleaseStatementsMixedDDL(t *testing.T) { &TruncateStatement{Tables: []string{"t1"}}, &ShowStatement{ShowType: "TABLES"}, &DescribeStatement{TableName: "users"}, + &UnsupportedStatement{Kind: "COPY", RawSQL: "COPY INTO my_table"}, &ReplaceStatement{TableName: "cache"}, &AlterStatement{Name: "r1"}, // DML @@ -1071,6 +1126,7 @@ func BenchmarkMixedDDLReleaseAST(b *testing.B) { GetTruncateStatement(), GetShowStatement(), GetDescribeStatement(), + GetUnsupportedStatement(), GetReplaceStatement(), GetAlterStatement(), ) diff --git a/pkg/sql/monitor/doc.go b/pkg/sql/monitor/doc.go index 1b211bfc..7aeb611a 100644 --- a/pkg/sql/monitor/doc.go +++ b/pkg/sql/monitor/doc.go @@ -12,6 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Deprecated: Package monitor is deprecated in favor of [github.com/ajitpratap0/GoSQLX/pkg/metrics]. +// Use pkg/metrics for all monitoring needs — it provides a superset of monitor's +// functionality with better concurrency (per-field atomics vs global mutex), +// per-error-type tracking, query size distribution, and JSON-serializable output. +// This package will be removed in v2.0. +// // Package monitor provides lightweight performance monitoring for GoSQLX operations. // // This package is a simpler alternative to pkg/metrics, designed for applications diff --git a/pkg/sql/parser/coverage_improvement_test.go b/pkg/sql/parser/coverage_improvement_test.go index f4f6f7b5..0f3f0117 100644 --- a/pkg/sql/parser/coverage_improvement_test.go +++ b/pkg/sql/parser/coverage_improvement_test.go @@ -12,6 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +// TODO(v2-cleanup): Coverage-push file. Tests should be moved to ddl_test.go, +// merge_test.go, error_recovery_test.go, and new alter_role/policy/connector +// test files. Then this file should be removed. + package parser import ( diff --git a/pkg/sql/parser/ddl.go b/pkg/sql/parser/ddl.go index 68b4cd7f..bf3ffb50 100644 --- a/pkg/sql/parser/ddl.go +++ b/pkg/sql/parser/ddl.go @@ -98,10 +98,8 @@ func (p *Parser) parseCreateStatement() (ast.Statement, error) { // Snowflake object-type extensions: STAGE, STREAM, TASK, PIPE, FILE FORMAT, // WAREHOUSE, DATABASE, SCHEMA, ROLE, FUNCTION, PROCEDURE, SEQUENCE. - // Parse-only: record the object kind and name on a DescribeStatement - // placeholder, then consume the rest of the statement body permissively - // until ';' or EOF (tracking balanced parens so embedded expressions with - // semicolons inside string literals round-trip). + // Parse-only: consumed permissively and returned as UnsupportedStatement + // until dedicated AST nodes are introduced. if p.dialect == string(keywords.DialectSnowflake) { kind := strings.ToUpper(p.currentToken.Token.Value) if kind == "FILE" && strings.EqualFold(p.peekToken().Token.Value, "FORMAT") { @@ -112,19 +110,28 @@ func (p *Parser) parseCreateStatement() (ast.Statement, error) { case "STAGE", "STREAM", "TASK", "PIPE", "FILE FORMAT", "WAREHOUSE", "DATABASE", "SCHEMA", "ROLE", "SEQUENCE", "FUNCTION", "PROCEDURE": + stmtKind := "CREATE " + kind p.advance() // Consume object-kind keyword + var rawParts []string + rawParts = append(rawParts, "CREATE", kind) // Optional IF NOT EXISTS if p.isType(models.TokenTypeIf) { + rawParts = append(rawParts, "IF") p.advance() if p.isType(models.TokenTypeNot) { + rawParts = append(rawParts, "NOT") p.advance() } if p.isType(models.TokenTypeExists) { + rawParts = append(rawParts, "EXISTS") p.advance() } } // Object name (qualified identifier) name, _ := p.parseQualifiedName() + if name != "" { + rawParts = append(rawParts, name) + } // Consume the rest of the statement body until ';' or EOF, // tracking balanced parens. depth := 0 @@ -136,6 +143,9 @@ func (p *Parser) parseCreateStatement() (ast.Statement, error) { if t == models.TokenTypeSemicolon && depth == 0 { break } + if p.currentToken.Token.Value != "" { + rawParts = append(rawParts, p.currentToken.Token.Value) + } if t == models.TokenTypeLParen { depth++ } else if t == models.TokenTypeRParen { @@ -143,8 +153,9 @@ func (p *Parser) parseCreateStatement() (ast.Statement, error) { } p.advance() } - stub := ast.GetDescribeStatement() - stub.TableName = "CREATE " + kind + " " + name + stub := ast.GetUnsupportedStatement() + stub.Kind = stmtKind + stub.RawSQL = strings.Join(rawParts, " ") return stub, nil } } diff --git a/pkg/sql/parser/expressions_complex.go b/pkg/sql/parser/expressions_complex.go index 19792114..d1e668ac 100644 --- a/pkg/sql/parser/expressions_complex.go +++ b/pkg/sql/parser/expressions_complex.go @@ -379,3 +379,52 @@ func (p *Parser) parseBracketArrayLiteral() (*ast.ArrayConstructorExpression, er return arrayExpr, nil } + +// parseExtractExpression parses EXTRACT(field FROM source). +// +// SQL standard syntax: +// +// EXTRACT(YEAR FROM created_at) +// EXTRACT(MONTH FROM order_date) +// EXTRACT(DOW FROM timestamp_col) +// +// The field is a date/time part keyword (YEAR, MONTH, DAY, HOUR, MINUTE, +// SECOND, DOW, DOY, EPOCH, QUARTER, WEEK, etc.). It is captured as a +// plain string rather than an enum to accommodate dialect extensions. +func (p *Parser) parseExtractExpression() (*ast.ExtractExpression, error) { + p.advance() // Consume EXTRACT + + if !p.isType(models.TokenTypeLParen) { + return nil, p.expectedError("( after EXTRACT") + } + p.advance() // Consume ( + + // Field: typically a keyword like YEAR, MONTH, DAY, etc. + field := strings.ToUpper(p.currentToken.Token.Value) + if field == "" { + return nil, p.expectedError("date/time field (YEAR, MONTH, DAY, etc.)") + } + p.advance() // Consume field + + // FROM keyword + if !p.isType(models.TokenTypeFrom) { + return nil, p.expectedError("FROM after EXTRACT field") + } + p.advance() // Consume FROM + + // Source expression + source, err := p.parseExpression() + if err != nil { + return nil, fmt.Errorf("EXTRACT source: %w", err) + } + + if !p.isType(models.TokenTypeRParen) { + return nil, p.expectedError(") after EXTRACT expression") + } + p.advance() // Consume ) + + return &ast.ExtractExpression{ + Field: field, + Source: source, + }, nil +} diff --git a/pkg/sql/parser/expressions_literal.go b/pkg/sql/parser/expressions_literal.go index ad9b8165..2ef99081 100644 --- a/pkg/sql/parser/expressions_literal.go +++ b/pkg/sql/parser/expressions_literal.go @@ -120,6 +120,14 @@ func (p *Parser) parsePrimaryExpression() (ast.Expression, error) { return funcCall, nil } + // EXTRACT(field FROM source) — SQL standard date/time extraction. + // Tokenized as an identifier; detect by name when followed by '('. + if (p.isType(models.TokenTypeIdentifier) || p.isType(models.TokenTypeKeyword)) && + strings.EqualFold(p.currentToken.Token.Value, "EXTRACT") && + p.peekToken().Token.Type == models.TokenTypeLParen { + return p.parseExtractExpression() + } + if p.isType(models.TokenTypeIdentifier) || p.isType(models.TokenTypeDoubleQuotedString) || ((p.dialect == string(keywords.DialectSQLServer) || p.dialect == string(keywords.DialectClickHouse)) && p.isNonReservedKeyword()) { // Handle identifiers and function calls // Double-quoted strings are treated as identifiers in SQL (e.g., "column_name") diff --git a/pkg/sql/parser/parser.go b/pkg/sql/parser/parser.go index 1c2e6a0f..4c065e71 100644 --- a/pkg/sql/parser/parser.go +++ b/pkg/sql/parser/parser.go @@ -38,6 +38,7 @@ type ConversionResult struct { Tokens []models.TokenWithSpan // Deprecated: PositionMapping is always nil. Position information is now embedded // directly in models.TokenWithSpan.Start and .End fields. + // Scheduled for removal in v2.0. PositionMapping []TokenPosition } @@ -237,9 +238,10 @@ type Parser struct { dialect string // SQL dialect for dialect-aware parsing (default: "postgresql") } -// Deprecated: Parse is provided for backward compatibility only. Use ParseFromModelTokens -// with a []models.TokenWithSpan slice from the tokenizer instead. This shim wraps each -// token.Token into a zero-span models.TokenWithSpan and has no position information. +// Deprecated: Parse is provided for backward compatibility only and is scheduled for +// removal in v2.0. Use ParseFromModelTokens with a []models.TokenWithSpan slice from +// the tokenizer instead. This shim wraps each token.Token into a zero-span +// models.TokenWithSpan and has no position information. // // Parse parses a slice of token.Token into an AST. // @@ -335,7 +337,7 @@ func (p *Parser) ParseFromModelTokens(tokens []models.TokenWithSpan) (*ast.AST, // ParseFromModelTokensWithPositions is identical to ParseFromModelTokens. // Position information is embedded in every models.TokenWithSpan. // -// Deprecated: Use ParseFromModelTokens directly. +// Deprecated: Use ParseFromModelTokens directly. Scheduled for removal in v2.0. func (p *Parser) ParseFromModelTokensWithPositions(tokens []models.TokenWithSpan) (*ast.AST, error) { return p.ParseFromModelTokens(tokens) } @@ -749,21 +751,26 @@ func (p *Parser) parseStatement() (ast.Statement, error) { // USE [WAREHOUSE | DATABASE | SCHEMA | ROLE] // // The object-kind keyword is optional (plain "USE " switches the current -// database). We parse-only; the statement is represented as a DescribeStatement -// placeholder until a dedicated UseStatement node is introduced. +// database). Returned as an UnsupportedStatement until a dedicated UseStatement +// node is introduced. func (p *Parser) parseSnowflakeUseStatement() (ast.Statement, error) { p.advance() // Consume USE + var rawParts []string + rawParts = append(rawParts, "USE") // Optional object kind. switch strings.ToUpper(p.currentToken.Token.Value) { case "WAREHOUSE", "DATABASE", "SCHEMA", "ROLE": + rawParts = append(rawParts, strings.ToUpper(p.currentToken.Token.Value)) p.advance() } name, err := p.parseQualifiedName() if err != nil { return nil, p.expectedError("name after USE") } - stmt := ast.GetDescribeStatement() - stmt.TableName = "USE " + name + rawParts = append(rawParts, name) + stmt := ast.GetUnsupportedStatement() + stmt.Kind = "USE" + stmt.RawSQL = strings.Join(rawParts, " ") return stmt, nil } @@ -776,17 +783,21 @@ func (p *Parser) parseSnowflakeUseStatement() (ast.Statement, error) { // REMOVE @/ // // The statement is consumed token-by-token (tracking balanced parens) until -// ';' or EOF and returned as a DescribeStatement placeholder tagged with the -// operation kind. No AST modeling yet; follow-up work. +// ';' or EOF and returned as an UnsupportedStatement tagged with the +// operation kind. No full AST modeling yet; follow-up work. func (p *Parser) parseSnowflakeStageStatement(kind string) (ast.Statement, error) { p.advance() // Consume leading kind token + var rawParts []string + rawParts = append(rawParts, kind) + // COPY INTO: consume the INTO keyword if present. if kind == "COPY" && p.isType(models.TokenTypeInto) { + rawParts = append(rawParts, "INTO") p.advance() } - // Consume the rest of the statement body. + // Consume the rest of the statement body, capturing tokens for RawSQL. depth := 0 for { t := p.currentToken.Token.Type @@ -796,6 +807,9 @@ func (p *Parser) parseSnowflakeStageStatement(kind string) (ast.Statement, error if t == models.TokenTypeSemicolon && depth == 0 { break } + if p.currentToken.Token.Value != "" { + rawParts = append(rawParts, p.currentToken.Token.Value) + } if t == models.TokenTypeLParen { depth++ } else if t == models.TokenTypeRParen { @@ -803,8 +817,9 @@ func (p *Parser) parseSnowflakeStageStatement(kind string) (ast.Statement, error } p.advance() } - stub := ast.GetDescribeStatement() - stub.TableName = kind + stub := ast.GetUnsupportedStatement() + stub.Kind = kind + stub.RawSQL = strings.Join(rawParts, " ") return stub, nil } diff --git a/pkg/sql/parser/parser_additional_coverage_test.go b/pkg/sql/parser/parser_additional_coverage_test.go index 03936401..cbf3a1ec 100644 --- a/pkg/sql/parser/parser_additional_coverage_test.go +++ b/pkg/sql/parser/parser_additional_coverage_test.go @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +// TODO(v2-cleanup): Coverage-push file with significant overlap against +// operators_test.go, parser_test.go, and ddl_test.go. Consolidate and remove. + package parser import ( diff --git a/pkg/sql/parser/parser_coverage_test.go b/pkg/sql/parser/parser_coverage_test.go index cdc48793..9b76bb85 100644 --- a/pkg/sql/parser/parser_coverage_test.go +++ b/pkg/sql/parser/parser_coverage_test.go @@ -21,8 +21,10 @@ import ( "github.com/ajitpratap0/GoSQLX/pkg/sql/token" ) -// NOTE: CREATE TABLE is not yet implemented in parseStatement() -// Tests for CREATE TABLE are skipped until the feature is implemented +// TODO(v2-cleanup): This file contains coverage-push tests that overlap with +// feature-specific test files (ddl_test.go, window_functions_test.go, +// cte_test.go, operators_test.go, set_operations_test.go). Tests should be +// consolidated into those files and this file removed. See #coverage-consolidation. // TestParser_AlterTable tests ALTER TABLE DDL statement // This covers parseAlterTableStmt, matchToken diff --git a/pkg/sql/parser/parser_final_coverage_test.go b/pkg/sql/parser/parser_final_coverage_test.go index c609df1f..372981e6 100644 --- a/pkg/sql/parser/parser_final_coverage_test.go +++ b/pkg/sql/parser/parser_final_coverage_test.go @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +// TODO(v2-cleanup): Coverage-push file. Tests should be consolidated into +// parser_test.go and error_recovery_test.go, then this file removed. + package parser import ( diff --git a/pkg/sql/parser/parser_targeted_coverage_test.go b/pkg/sql/parser/parser_targeted_coverage_test.go index e4706d69..96f8ea3e 100644 --- a/pkg/sql/parser/parser_targeted_coverage_test.go +++ b/pkg/sql/parser/parser_targeted_coverage_test.go @@ -12,6 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +// TODO(v2-cleanup): Coverage-push file. Tests should be consolidated into +// feature-specific files (ddl_test.go, window_functions_test.go, etc.) +// and this file removed. See #coverage-consolidation. + package parser import ( diff --git a/pkg/sql/parser/validate.go b/pkg/sql/parser/validate.go index 2c7d9d4e..b8595b98 100644 --- a/pkg/sql/parser/validate.go +++ b/pkg/sql/parser/validate.go @@ -36,10 +36,10 @@ func Validate(sql string) error { } // ValidateBytes is like Validate but accepts []byte to avoid a string copy. +// Empty or whitespace-only input is rejected as invalid SQL. func ValidateBytes(input []byte) error { - // Fast path: empty/whitespace-only input is valid if len(trimBytes(input)) == 0 { - return nil + return fmt.Errorf("invalid SQL: empty input") } tkz := tokenizer.GetTokenizer() diff --git a/pkg/sql/parser/validate_test.go b/pkg/sql/parser/validate_test.go index 94bba7de..9c8fc1c8 100644 --- a/pkg/sql/parser/validate_test.go +++ b/pkg/sql/parser/validate_test.go @@ -32,7 +32,7 @@ func TestValidate(t *testing.T) { {"select from", "SELECT * FROM users", false}, {"insert", "INSERT INTO t(a) VALUES(1)", false}, {"invalid", "SELECT FROM WHERE", true}, - {"empty", "", false}, + {"empty", "", true}, {"multiple statements", "SELECT 1; SELECT 2", false}, } for _, tt := range tests {