diff --git a/Makefile b/Makefile index fd6706d..34d3376 100644 --- a/Makefile +++ b/Makefile @@ -14,11 +14,13 @@ build: @echo "Building $(BINARY_NAME)..." @mkdir -p $(BUILD_DIR) go build -o $(BUILD_DIR)/$(BINARY_NAME) ./cmd/sqllexer + go build -o $(BUILD_DIR)/$(BINARY_NAME)v2 ./cmd/sqlprocessor # Install the binary to GOPATH/bin install: build @echo "Installing $(BINARY_NAME)..." cp $(BUILD_DIR)/$(BINARY_NAME) $(shell go env GOPATH)/bin/ + cp $(BUILD_DIR)/$(BINARY_NAME)v2 $(shell go env GOPATH)/bin/ # Run tests test: @@ -43,4 +45,4 @@ help: @echo " test - Run tests" @echo " bench - Run benchmarks" @echo " clean - Clean build artifacts" - @echo " help - Show this help message" \ No newline at end of file + @echo " help - Show this help message" diff --git a/cmd/sqlprocessor/main.go b/cmd/sqlprocessor/main.go new file mode 100644 index 0000000..c9520ff --- /dev/null +++ b/cmd/sqlprocessor/main.go @@ -0,0 +1,425 @@ +package main + +import ( + "bufio" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/DataDog/go-sqllexer" +) + +type tokenOut struct { + Type string `json:"type"` + Value string `json:"value"` +} + +type record struct { + Line int `json:"line"` + Query string `json:"query"` + Tokens []tokenOut `json:"tokens"` + HasError bool `json:"has_error,omitempty"` +} + +func main() { + format := flag.String("format", "json", "Output format: json, jsonl or txt") + input := flag.String("input", "", "Input file (optional; can also pass files as args)") + query := flag.String("query", "", "SQL query string input (optional)") + output := flag.String("output", "", "Output file (optional; default stdout for query/stdin)") + outDir := flag.String("outdir", "", "Output directory (default: same as input file)") + includeEmpty := flag.Bool("include-empty", false, "Include empty/whitespace-only lines") + mode := flag.String("mode", "analyze", "Processing mode: analyze or tokenize") + flag.Parse() + + inputs := 
make([]string, 0, 1+len(flag.Args())) + if *input != "" { + inputs = append(inputs, *input) + } + inputs = append(inputs, flag.Args()...) + + if *format != "json" && *format != "txt" && *format != "jsonl" { + fmt.Fprintf(os.Stderr, "Invalid -format %q (expected json, jsonl, or txt)\n", *format) + os.Exit(2) + } + + if *mode != "analyze" && *mode != "tokenize" { + fmt.Fprintf(os.Stderr, "Invalid -mode %q (expected analyze or tokenize)\n", *mode) + os.Exit(2) + } + + if *output != "" && *outDir != "" { + fmt.Fprintln(os.Stderr, "Cannot use -output with -outdir") + os.Exit(2) + } + + if *query != "" && len(inputs) > 0 { + fmt.Fprintln(os.Stderr, "Cannot use -query with file inputs") + os.Exit(2) + } + + if *output != "" && len(inputs) > 1 { + fmt.Fprintln(os.Stderr, "-output only supports a single input file or -query/stdin") + os.Exit(2) + } + + if *outDir != "" { + if err := os.MkdirAll(*outDir, 0o755); err != nil { + fmt.Fprintf(os.Stderr, "Failed to create outdir %q: %v\n", *outDir, err) + os.Exit(1) + } + } + + exitCode := 0 + if *query != "" { + if err := processReaderToPath(strings.NewReader(*query), *format, *output, *includeEmpty, *mode); err != nil { + fmt.Fprintf(os.Stderr, "Error processing query: %v\n", err) + exitCode = 1 + } + os.Exit(exitCode) + } + + if len(inputs) == 0 { + if err := processReaderToPath(os.Stdin, *format, *output, *includeEmpty, *mode); err != nil { + fmt.Fprintf(os.Stderr, "Error processing stdin: %v\n", err) + exitCode = 1 + } + os.Exit(exitCode) + } + + for _, path := range inputs { + if err := processFile(path, *format, *outDir, *output, *includeEmpty, *mode); err != nil { + fmt.Fprintf(os.Stderr, "Error processing %s: %v\n", path, err) + exitCode = 1 + } + } + + os.Exit(exitCode) +} + +func processFile(path, format, outDir, output string, includeEmpty bool, mode string) error { + inFile, err := os.Open(path) + if err != nil { + return err + } + defer inFile.Close() + + var out io.Writer + if output != "" { + outFile, err := 
os.Create(output) + if err != nil { + return err + } + defer outFile.Close() + out = outFile + } else { + outPath, err := outputPath(path, outDir, format) + if err != nil { + return err + } + outFile, err := os.Create(outPath) + if err != nil { + return err + } + defer outFile.Close() + out = outFile + } + + return processReader(inFile, format, out, includeEmpty, mode) +} + +func outputPath(inputPath, outDir, format string) (string, error) { + base := filepath.Base(inputPath) + base = strings.TrimSuffix(base, filepath.Ext(base)) + + suffix := ".tokens.json" + switch format { + case "txt": + suffix = ".tokens.txt" + case "jsonl": + suffix = ".tokens.jsonl" + } + filename := base + suffix + + if outDir == "" { + return filepath.Join(filepath.Dir(inputPath), filename), nil + } + return filepath.Join(outDir, filename), nil +} + +func processReaderToPath(r io.Reader, format, output string, includeEmpty bool, mode string) error { + if output == "" { + return processReader(r, format, os.Stdout, includeEmpty, mode) + } + outFile, err := os.Create(output) + if err != nil { + return err + } + defer outFile.Close() + return processReader(r, format, outFile, includeEmpty, mode) +} + +func processReader(r io.Reader, format string, out io.Writer, includeEmpty bool, mode string) error { + reader := bufio.NewReader(r) + lineNum := 0 + + switch format { + case "json": + if _, err := out.Write([]byte("[\n")); err != nil { + return err + } + first := true + for { + line, err := readLine(reader) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return err + } + lineNum++ + if !includeEmpty && strings.TrimSpace(line) == "" { + continue + } + + var v any + if mode == "tokenize" { + v = tokenizeLineTypesOnly(line) + } else { + v = tokenizeLine(line, lineNum) + } + if !first { + if _, err := out.Write([]byte(",\n")); err != nil { + return err + } + } + first = false + + blob, err := json.Marshal(v) + if err != nil { + return err + } + if _, err := out.Write(blob); err != nil { 
+ return err + } + } + if _, err := out.Write([]byte("\n]\n")); err != nil { + return err + } + case "jsonl": + writer := bufio.NewWriter(out) + for { + line, err := readLine(reader) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return err + } + lineNum++ + if !includeEmpty && strings.TrimSpace(line) == "" { + continue + } + + var v any + if mode == "tokenize" { + v = tokenizeLineTypesOnly(line) + } else { + v = tokenizeLine(line, lineNum) + } + blob, err := json.Marshal(v) + if err != nil { + return err + } + blob = append(blob, '\n') + if _, err := writer.Write(blob); err != nil { + return err + } + } + if err := writer.Flush(); err != nil { + return err + } + case "txt": + writer := bufio.NewWriter(out) + for { + line, err := readLine(reader) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return err + } + lineNum++ + if !includeEmpty && strings.TrimSpace(line) == "" { + continue + } + + if mode == "tokenize" { + if _, err := fmt.Fprintln(writer, tokenizeLineTypesOnly(line)); err != nil { + return err + } + } else { + rec := tokenizeLine(line, lineNum) + if err := writeTxtRecord(writer, rec); err != nil { + return err + } + } + } + if err := writer.Flush(); err != nil { + return err + } + default: + return fmt.Errorf("unsupported format: %s", format) + } + + return nil +} + +func tokenizeLineTypesOnly(line string) string { + lexer := sqllexer.New(line) + var types []string + for { + tok := lexer.Scan() + if tok.Type == sqllexer.EOF { + break + } + types = append(types, tokenTypeName(tok.Type)) + } + return strings.Join(types, " ") +} + +func tokenizeLine(line string, lineNum int) record { + lexer := sqllexer.New(line) + + tokens := make([]tokenOut, 0, 32) + hasError := false + for { + tok := lexer.Scan() + if tok.Type == sqllexer.EOF { + break + } + if tok.Type == sqllexer.ERROR { + hasError = true + } + tokens = append(tokens, tokenOut{ + Type: tokenTypeName(tok.Type), + Value: tok.Value, + }) + } + + return record{ + Line: 
lineNum, + Query: line, + Tokens: tokens, + HasError: hasError, + } +} + +func writeTxtRecord(w io.Writer, rec record) error { + parts := make([]string, 0, len(rec.Tokens)) + for _, tok := range rec.Tokens { + parts = append(parts, tok.Type+":"+strconv.Quote(tok.Value)) + } + line := strings.Join(parts, " ") + if rec.HasError { + line = line + " ERROR:true" + } + _, err := fmt.Fprintf(w, "%d\t%s\n", rec.Line, line) + return err +} + +func tokenTypeName(t sqllexer.TokenType) string { + switch t { + case sqllexer.ERROR: + return "ERROR" + case sqllexer.EOF: + return "EOF" + case sqllexer.SPACE: + return "SPACE" + case sqllexer.STRING: + return "STRING" + case sqllexer.INCOMPLETE_STRING: + return "INCOMPLETE_STRING" + case sqllexer.NUMBER: + return "NUMBER" + case sqllexer.IDENT: + return "IDENT" + case sqllexer.QUOTED_IDENT: + return "QUOTED_IDENT" + case sqllexer.OPERATOR: + return "OPERATOR" + case sqllexer.WILDCARD: + return "WILDCARD" + case sqllexer.COMMENT: + return "COMMENT" + case sqllexer.MULTILINE_COMMENT: + return "MULTILINE_COMMENT" + case sqllexer.PUNCTUATION: + return "PUNCTUATION" + case sqllexer.DOLLAR_QUOTED_FUNCTION: + return "DOLLAR_QUOTED_FUNCTION" + case sqllexer.DOLLAR_QUOTED_STRING: + return "DOLLAR_QUOTED_STRING" + case sqllexer.POSITIONAL_PARAMETER: + return "POSITIONAL_PARAMETER" + case sqllexer.BIND_PARAMETER: + return "BIND_PARAMETER" + case sqllexer.FUNCTION: + return "FUNCTION" + case sqllexer.SYSTEM_VARIABLE: + return "SYSTEM_VARIABLE" + case sqllexer.UNKNOWN: + return "UNKNOWN" + case sqllexer.COMMAND: + return "COMMAND" + case sqllexer.KEYWORD: + return "KEYWORD" + case sqllexer.JSON_OP: + return "JSON_OP" + case sqllexer.BOOLEAN: + return "BOOLEAN" + case sqllexer.NULL: + return "NULL" + case sqllexer.PROC_INDICATOR: + return "PROC_INDICATOR" + case sqllexer.CTE_INDICATOR: + return "CTE_INDICATOR" + case sqllexer.ALIAS_INDICATOR: + return "ALIAS_INDICATOR" + default: + return fmt.Sprintf("TokenType(%d)", int(t)) + } +} + +func 
readLine(reader *bufio.Reader) (string, error) {
+	var buf []byte
+	for {
+		chunk, err := reader.ReadSlice('\n')
+		buf = append(buf, chunk...)
+		if err == nil {
+			break
+		}
+		if errors.Is(err, bufio.ErrBufferFull) {
+			// Line longer than the reader's internal buffer: keep accumulating.
+			continue
+		}
+		if errors.Is(err, io.EOF) {
+			if len(buf) == 0 {
+				return "", io.EOF
+			}
+			// Final line without a trailing newline: return what we have.
+			break
+		}
+		return "", err
+	}
+
+	line := string(buf)
+	// Strip only the line terminator ("\n", or "\r\n" for CRLF input).
+	// Do NOT rewrite the payload itself: the previous quote-stripping
+	// (ReplaceAll of ' and ") corrupted the SQL before lexing, so
+	// STRING/INCOMPLETE_STRING/QUOTED_IDENT tokens could never appear and
+	// the emitted "query" field no longer matched the actual input line.
+	line = strings.TrimSuffix(line, "\n")
+	line = strings.TrimSuffix(line, "\r")
+	return line, nil
+}
diff --git a/data/benchmark.jsonl b/data/benchmark.jsonl
new file mode 100644
index 0000000..9c17016
--- /dev/null
+++ b/data/benchmark.jsonl
@@ -0,0 +1,22 @@
+{"text": "\u05d1\u05e1\u05d9\u05e1 \u05de\u05d9\u05d8\u05d4 violet", "label": 0}
+{"text": "2024-10-07t16:04:06.184z", "label": 0}
+{"text": "1') where 7334=7334 rlike sleep(5)-- clhm", "label": 1}
+{"text": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", "label": 0}
+{"text": "1111 ---- { selection de basket adidas }", "label": 0}
+{"text": "1---- basket adidas", "label": 0}
+{"text": "-3277\")) AS aJRX WHERE 2435=2435 ORDER BY 1#", "label": 1}
+{"text": "1/**/aND/**/987222222222222=55762222222#/**/Rrhc", "label": 1}
+{"text": "1\" OR \"1\"=\"1\" union select toto,foobar from tolrol where xxx==\"a", "label": 1}
+{"text": "1\" OR \"1\"=\"1\" union select /**/ toto,foobar from tolrol where xxx==\"a", "label": 1}
+{"text": "1\" OR \"1\"=\"1\" union select login,password from users where xxx==\"a", "label": 1}
+{"text": "1 unIoN/**/SELECT /**/ foo /**//**/ FROM bar", "label": 1}
+{"text": "1 unIoN SELECT foo /**//**/ FROM bar", "label": 1}
+{"text": "1 unIoN SELECT foo FROM bar", "label": 1}
+{"text": "1 unIoN SELECT login,password from users", "label": 1}
+{"text": "1\" OR \"1\"=\"1\" unIoN SELECT login,password from users where toto=\"a", "label": 1}
+{"text": 
"127893789719823987123\" OR \"42\"=\"42", "label": 1} +{"text": "1 unIoN/**/SELECT /**/ foo /**//**/ FROM bar", "label": 1} +{"text": "1 unIoN /**/ SELECT /**/ foo /**//**/ FROM bar", "label": 1} +{"text": "1 unIoN/**/SELECT/**/foo/**//**/FROM bar", "label": 1} +{"text": "1 unIoN/**/SELECT/**/foo/**/FROM/**/bar", "label": 1} + diff --git a/sqllexer.go b/sqllexer.go index f19e3c5..6d62f45 100644 --- a/sqllexer.go +++ b/sqllexer.go @@ -352,7 +352,7 @@ func (s *Lexer) scanIdentifier(ch rune) *Token { // If first character is Unicode, skip trie lookup if ch > 127 { - for isIdentifier(ch) { + for isIdentifier(ch, s.lookAhead(1)) { s.hasDigits = s.hasDigits || isDigit(ch) ch = s.nextBy(utf8.RuneLen(ch)) } @@ -395,14 +395,14 @@ func (s *Lexer) scanIdentifier(ch rune) *Token { } // If we found a complete keyword and next char is whitespace - if node.isEnd && (isPunctuation(ch) || isSpace(ch) || isEOF(ch)) { + if node.isEnd && (isPunctuation(ch) || isSpace(ch) || isMultiLineComment(ch, s.lookAhead(1)) || isEOF(ch)) { s.cursor = pos + 1 // Include the last matched character s.isTableIndicator = node.isTableIndicator return s.emit(node.tokenType) } // Continue scanning identifier if no keyword match - for isIdentifier(ch) { + for isIdentifier(ch, s.lookAhead(1)) { s.hasDigits = s.hasDigits || isDigit(ch) ch = s.nextBy(utf8.RuneLen(ch)) } diff --git a/sqllexer_test.go b/sqllexer_test.go index 43d984d..9237866 100644 --- a/sqllexer_test.go +++ b/sqllexer_test.go @@ -1027,6 +1027,32 @@ here */`, {STRING, `'\'`}, }, }, + { + name: "simple select with multiline comments as separators", + input: `SELECT/**/*/**/FROM/**/test`, + expected: []TokenSpec{ + {COMMAND, "SELECT"}, + {MULTILINE_COMMENT, "/**/"}, + {WILDCARD, "*"}, + {MULTILINE_COMMENT, "/**/"}, + {KEYWORD, "FROM"}, + {MULTILINE_COMMENT, "/**/"}, + {IDENT, "test"}, + }, + }, + { + name: "simple select with multiline comments as separators", + input: `SELECT/**/foo/**/FROM/**/test`, + expected: []TokenSpec{ + {COMMAND, 
"SELECT"}, + {MULTILINE_COMMENT, "/**/"}, + {IDENT, "foo"}, + {MULTILINE_COMMENT, "/**/"}, + {KEYWORD, "FROM"}, + {MULTILINE_COMMENT, "/**/"}, + {IDENT, "test"}, + }, + }, } for _, tt := range tests { diff --git a/sqllexer_utils.go b/sqllexer_utils.go index 7775e87..b49fff8 100644 --- a/sqllexer_utils.go +++ b/sqllexer_utils.go @@ -363,8 +363,8 @@ func isEOF(ch rune) bool { } // isIdentifier checks if a rune is an identifier -func isIdentifier(ch rune) bool { - return ch == '"' || ch == '.' || ch == '?' || ch == '$' || ch == '#' || ch == '/' || ch == '@' || ch == '!' || isLetter(ch) || isDigit(ch) +func isIdentifier(ch rune, nextCh rune) bool { + return ch == '"' || ch == '.' || ch == '?' || ch == '$' || ch == '#' || (ch == '/' && !isMultiLineComment(ch, nextCh)) || ch == '@' || ch == '!' || isLetter(ch) || isDigit(ch) } // isValueToken checks if a token is a value token