Skip to content

Commit 29b42e5

Browse files
committed
[SPARK-48338][SQL] spark-sql cli correctly handles SQL Scripting compound blocks
1 parent c55ff8f commit 29b42e5

2 files changed

Lines changed: 436 additions & 12 deletions

File tree

sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala

Lines changed: 174 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import java.io._
2121
import java.util.{ArrayList => JArrayList, List => JList, Locale}
2222
import java.util.concurrent.TimeUnit
2323

24+
import scala.collection.mutable
2425
import scala.jdk.CollectionConverters._
2526

2627
import jline.console.ConsoleReader
@@ -268,9 +269,19 @@ private[hive] object SparkSQLCLIDriver extends Logging {
268269

269270
if (line.trim().endsWith(";") && !line.trim().endsWith("\\;")) {
270271
line = prefix + line
271-
ret = cli.processLine(line, true)
272-
prefix = ""
273-
currentPrompt = promptWithCurrentDB
272+
// For SQL Scripting, a semicolon inside a compound block (BEGIN...END,
273+
// IF...END IF, WHILE...DO...END WHILE, etc.) terminates a statement *within*
274+
// the block, not the block itself. Only fire when the accumulated input
275+
// is no longer inside any open scripting block.
276+
val insideScriptingBlock = sqlScriptingBlockDepth(line) > 0
277+
if (insideScriptingBlock) {
278+
prefix = line
279+
currentPrompt = continuedPromptWithDBSpaces
280+
} else {
281+
ret = cli.processLine(line, allowInterrupting = true)
282+
prefix = ""
283+
currentPrompt = promptWithCurrentDB
284+
}
274285
} else {
275286
prefix = prefix + line
276287
currentPrompt = continuedPromptWithDBSpaces
@@ -374,6 +385,132 @@ private[hive] object SparkSQLCLIDriver extends Logging {
374385
}
375386
Array[Completer](propCompleter, customCompleter)
376387
}
388+
389+
/** Keywords that may appear directly after END as a decorative suffix (e.g. "END IF"). */
private final val END_SUFFIX_KEYWORDS =
  Set("IF", "CASE", "FOR", "WHILE", "LOOP", "REPEAT")

/**
 * Keyword-driven tracker of SQL Scripting compound-block nesting depth.
 *
 * Usage protocol: feed every character that lies outside quoted strings and
 * comments to [[processChar]]; call [[clearWordBuffer]] whenever a quote or
 * comment starts mid-word; call [[flush]] once after the final character so a
 * trailing keyword is counted; then read [[depth]].
 *
 * Depth rules:
 *  - BEGIN always opens a block (depth++).
 *  - CASE / DO / LOOP / REPEAT, and IF when not immediately followed by '('
 *    (which is the `IF(...)` conditional function), open a nested block but
 *    only when already inside one (depth > 0).
 *  - END closes a block when depth > 0.
 *  - A keyword from [[END_SUFFIX_KEYWORDS]] right after END ("END IF",
 *    "END WHILE", ...) is decorative and leaves the depth unchanged.
 */
private[hive] class SqlScriptBlockTracker {
  var depth: Int = 0
  private val tokenBuilder = new StringBuilder()
  private var lastToken: String = ""

  private def isWordChar(c: Char): Boolean = c.isLetterOrDigit || c == '_'

  /** Feed one character known to be outside any quote or comment. */
  def processChar(c: Char): Unit = {
    if (isWordChar(c)) {
      tokenBuilder.append(c)
    } else if (tokenBuilder.nonEmpty) {
      onWordEnd(tokenBuilder.toString, c)
      tokenBuilder.clear()
    }
  }

  /** Drop a partially accumulated word (used when a quote/comment begins mid-word). */
  def clearWordBuffer(): Unit = tokenBuilder.clear()

  /** Finalize a word that runs up to the very end of the input. */
  def flush(): Unit = {
    if (tokenBuilder.nonEmpty) {
      onWordEnd(tokenBuilder.toString, ' ')
      tokenBuilder.clear()
    }
  }

  private def onWordEnd(word: String, nextChar: Char): Unit = {
    val token = word.toUpperCase(Locale.ROOT)
    // Decide first whether this token is a decorative suffix of the previous
    // END, then remember it as the new previous token.
    val isEndSuffix = lastToken == "END" && END_SUFFIX_KEYWORDS.contains(token)
    lastToken = token
    if (!isEndSuffix) {
      if (token == "BEGIN") {
        depth += 1
      } else if (depth > 0) {
        token match {
          case "END" => depth -= 1
          case "CASE" | "DO" | "LOOP" | "REPEAT" => depth += 1
          // `IF(` is the conditional function, not a compound-block opener.
          case "IF" if nextChar != '(' => depth += 1
          case _ =>
        }
      }
    }
  }
}
455+
456+
/**
 * Computes the SQL Scripting compound-block depth of the given SQL text.
 *
 * Scans `text` character by character, maintaining quote ('...' / "..."),
 * `--` simple-comment, nested bracketed-comment and backslash-escape state,
 * and feeds every character that lies outside quotes and comments to a
 * [[SqlScriptBlockTracker]].
 *
 * @param text the accumulated SQL input, e.g. from the interactive prompt
 * @return 0 when the text is not inside any scripting block, > 0 when one or
 *         more compound blocks (BEGIN...END etc.) are still open
 */
private[hive] def sqlScriptingBlockDepth(text: String): Int = {
  var insideSingleQuote = false
  var insideDoubleQuote = false
  var insideSimpleComment = false
  var bracketedCommentLevel = 0
  var escape = false
  var leavingBracketedComment = false
  val tracker = new SqlScriptBlockTracker()

  def insideBracketedComment: Boolean = bracketedCommentLevel > 0
  def insideComment: Boolean = insideSimpleComment || insideBracketedComment
  def insideAnyQuote: Boolean = insideSingleQuote || insideDoubleQuote

  for (index <- 0 until text.length) {
    // The closing '/' of a bracketed comment was seen in the previous
    // iteration; drop the comment level before handling the current char.
    if (leavingBracketedComment) {
      bracketedCommentLevel -= 1
      leavingBracketedComment = false
    }

    val c = text.charAt(index)

    if (!insideComment && !insideAnyQuote) {
      tracker.processChar(c)
    } else {
      // Suppress partial tokens inside quotes/comments. (The previous guard
      // `tracker.depth >= 0` was vacuously true -- depth starts at 0 and is
      // only ever decremented when positive -- so it is dropped.)
      tracker.clearWordBuffer()
    }

    if (c == '\'' && !insideComment) {
      if (!escape && !insideDoubleQuote) insideSingleQuote = !insideSingleQuote
    } else if (c == '\"' && !insideComment) {
      if (!escape && !insideSingleQuote) insideDoubleQuote = !insideDoubleQuote
    } else if (c == '-') {
      val hasNext = index + 1 < text.length
      // `--` starts a simple comment when not quoted or already in a comment.
      if (!insideAnyQuote && !insideComment && hasNext && text.charAt(index + 1) == '-') {
        insideSimpleComment = true
      }
    } else if (c == '\n' && !escape) {
      // A (non-escaped) newline terminates a simple comment.
      insideSimpleComment = false
    } else if (c == '/' && !insideSimpleComment && !insideAnyQuote) {
      if (insideBracketedComment && index > 0 && text.charAt(index - 1) == '*') {
        // End of a bracketed comment; the level is decremented next iteration.
        leavingBracketedComment = true
      } else if (index + 1 < text.length && text.charAt(index + 1) == '*') {
        // `/*` opens a (possibly nested) bracketed comment.
        bracketedCommentLevel += 1
      }
    }

    if (escape) escape = false
    else if (c == '\\') escape = true
  }

  // Count a keyword that runs up to the very last character (e.g. "... END").
  tracker.flush()
  tracker.depth
}
377514
}
378515

379516
private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
@@ -583,7 +720,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
583720
var lastRet: Int = 0
584721

585722
// we can not use "split" function directly as ";" may be quoted
586-
val commands = splitSemiColon(line).asScala
723+
val commands = splitSemiColon(line)
587724
var command: String = ""
588725
for (oneCmd <- commands) {
589726
if (oneCmd.endsWith("\\")) {
@@ -618,7 +755,12 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
618755
// string, the origin implementation from Hive will not drop the trailing semicolon as expected,
619756
// hence we refined this function a little bit.
620757
// Note: [SPARK-33100] Ignore a semicolon inside a bracketed comment in spark-sql.
621-
private[hive] def splitSemiColon(line: String): JList[String] = {
758+
// Note: For SQL Scripting, semicolons inside a SQL Scripting compound block (BEGIN...END,
759+
// IF...END IF, WHILE/FOR...DO...END WHILE/FOR, LOOP...END LOOP, REPEAT...END REPEAT,
760+
// CASE...END CASE, and nested/labeled variants) terminate individual statements
761+
// *within* the block and must not be used as split points. Block depth is tracked
762+
// with the same keyword-aware scanner used by the interactive input loop.
763+
private[hive] def splitSemiColon(line: String): Array[String] = {
622764
var insideSingleQuote = false
623765
var insideDoubleQuote = false
624766
var insideSimpleComment = false
@@ -627,12 +769,15 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
627769
var beginIndex = 0
628770
var leavingBracketedComment = false
629771
var isStatement = false
630-
val ret = new JArrayList[String]
772+
val ret = mutable.ArrayBuilder.make[String]
773+
774+
// For SQL Scripting block-depth tracking.
775+
val tracker = new SparkSQLCLIDriver.SqlScriptBlockTracker()
631776

632777
def insideBracketedComment: Boolean = bracketedCommentLevel > 0
633778
def insideComment: Boolean = insideSimpleComment || insideBracketedComment
634779
def statementInProgress(index: Int): Boolean = isStatement || (!insideComment &&
635-
index > beginIndex && !s"${line.charAt(index)}".trim.isEmpty)
780+
index > beginIndex && s"${line.charAt(index)}".trim.nonEmpty)
636781

637782
for (index <- 0 until line.length) {
638783
// Checks if we need to decrement a bracketed comment level; the last character '/' of
@@ -643,6 +788,18 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
643788
leavingBracketedComment = false
644789
}
645790

791+
val c = line.charAt(index)
792+
793+
// Accumulate keyword tokens for SQL Scripting block-depth tracking.
794+
// The tracker is updated *before* the quote/comment state below is toggled,
795+
// so that a closing quote or the start of a `--` comment correctly flushes any
796+
// in-progress token first.
797+
if (!insideComment && !insideSingleQuote && !insideDoubleQuote) {
798+
tracker.processChar(c)
799+
} else {
800+
tracker.clearWordBuffer()
801+
}
802+
646803
if (line.charAt(index) == '\'' && !insideComment) {
647804
// take a look to see if it is escaped
648805
// See the comment above about SPARK-31595
@@ -671,10 +828,13 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
671828
} else if (line.charAt(index) == ';') {
672829
if (insideSingleQuote || insideDoubleQuote || insideComment) {
673830
// do not split
831+
} else if (tracker.depth > 0) {
832+
// do not split: this semicolon is a statement terminator inside a SQL Scripting
833+
// compound block, not a boundary between top-level commands.
674834
} else {
675835
if (isStatement) {
676836
// split, do not include ; itself
677-
ret.add(line.substring(beginIndex, index))
837+
ret += line.substring(beginIndex, index)
678838
}
679839
beginIndex = index + 1
680840
isStatement = false
@@ -704,6 +864,10 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
704864

705865
isStatement = statementInProgress(index)
706866
}
867+
868+
// Flush any word that ends at the very last character of the input.
869+
tracker.flush()
870+
707871
// Check the last char is end of nested bracketed comment.
708872
val endOfBracketedComment = leavingBracketedComment && bracketedCommentLevel == 1
709873
// Spark SQL support simple comment and nested bracketed comment in query body.
@@ -715,8 +879,8 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
715879
// CLI should also pass this part to the backend engine, which may throw an exception
716880
// with clear error message.
717881
if (!endOfBracketedComment && (isStatement || insideBracketedComment)) {
718-
ret.add(line.substring(beginIndex))
882+
ret += line.substring(beginIndex)
719883
}
720-
ret
884+
ret.result()
721885
}
722886
}

0 commit comments

Comments
 (0)