Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.io._
import java.util.{ArrayList => JArrayList, List => JList, Locale}
import java.util.concurrent.TimeUnit

import scala.collection.mutable
import scala.jdk.CollectionConverters._

import jline.console.ConsoleReader
Expand Down Expand Up @@ -268,9 +269,19 @@ private[hive] object SparkSQLCLIDriver extends Logging {

if (line.trim().endsWith(";") && !line.trim().endsWith("\\;")) {
line = prefix + line
ret = cli.processLine(line, true)
prefix = ""
currentPrompt = promptWithCurrentDB
// For SQL Scripting, a semicolon inside a compound block (BEGIN ... END,
// IF ... END IF, WHILE ... DO ... END WHILE, etc.) terminates a statement
// *within* the block, not the block itself. Only fire when the accumulated
// input is no longer inside any open scripting block.
val insideScriptingBlock = sqlScriptingBlockDepth(line) > 0
if (insideScriptingBlock) {
prefix = line
currentPrompt = continuedPromptWithDBSpaces
} else {
ret = cli.processLine(line, allowInterrupting = true)
prefix = ""
currentPrompt = promptWithCurrentDB
}
} else {
prefix = prefix + line
currentPrompt = continuedPromptWithDBSpaces
Expand Down Expand Up @@ -374,6 +385,131 @@ private[hive] object SparkSQLCLIDriver extends Logging {
}
Array[Completer](propCompleter, customCompleter)
}

// Keywords that may appear immediately after END as a decorative suffix
// (e.g. "END IF", "END WHILE", "END LOOP"). Such a keyword does not open or
// close a block by itself, so SqlScriptBlockTracker must not change depth for it.
private final val END_SUFFIX_KEYWORDS =
Set("IF", "CASE", "FOR", "WHILE", "LOOP", "REPEAT")

/**
* Tracks SQL Scripting block nesting depth by processing SQL keyword tokens.
*
* Call [[processChar]] for each character that is outside quoted strings and comments.
* Call [[clearWordBuffer]] when entering a quoted string or comment mid-word.
* Call [[flush]] after the last character to finalize any pending token.
* Read [[depth]] to obtain the current block nesting level.
*
* Depth tracking rules:
* - BEGIN -> depth++
* - CASE / IF (not `IF(`) / DO / LOOP / REPEAT -> depth++ only when depth > 0
* - Every END keyword -> depth-- (when depth > 0)
* - Keywords immediately following END (IF, CASE, FOR, WHILE, LOOP, REPEAT) are
* decorative suffixes (e.g. "END IF") and do NOT change the depth.
*/
/**
 * Tracks SQL Scripting block nesting depth from a stream of characters.
 *
 * Usage contract:
 *  - feed every character that lies outside quoted strings and comments to
 *    [[processChar]];
 *  - call [[clearWordBuffer]] when a quote or comment starts mid-word, so the
 *    partial token is discarded;
 *  - call [[flush]] once after the last character to finalize a trailing token;
 *  - read [[depth]] for the current nesting level.
 *
 * Depth rules:
 *  - BEGIN always opens a block;
 *  - CASE, DO, LOOP, REPEAT and IF (unless directly followed by '(', i.e. the
 *    IF() function) open a nested block, but only when already inside one;
 *  - END closes a block (never below zero);
 *  - a keyword directly after END (IF, CASE, FOR, WHILE, LOOP, REPEAT) is a
 *    decorative suffix such as "END IF" and leaves the depth unchanged.
 */
private[hive] class SqlScriptBlockTracker {
  var depth: Int = 0
  private val tokenBuilder = new StringBuilder()
  private var previousToken: String = ""

  // Identifier characters extend the current token; anything else terminates it.
  private def isIdentifierChar(ch: Char): Boolean = ch.isLetterOrDigit || ch == '_'

  /** Feed a single character that is known to be outside any quote or comment. */
  def processChar(c: Char): Unit = {
    if (isIdentifierChar(c)) {
      tokenBuilder.append(c)
    } else if (tokenBuilder.nonEmpty) {
      onWordEnd(tokenBuilder.toString, c)
      tokenBuilder.clear()
    }
  }

  /** Discard a partially accumulated word (e.g. when entering a quote/comment). */
  def clearWordBuffer(): Unit = tokenBuilder.clear()

  /** Flush any remaining word at the end of the input. */
  def flush(): Unit = {
    if (tokenBuilder.nonEmpty) {
      // A space stands in for "no following character", so IF at end-of-input
      // is treated as a block opener rather than the IF() function.
      onWordEnd(tokenBuilder.toString, ' ')
      tokenBuilder.clear()
    }
  }

  // Applies the depth rules to one completed keyword token. `nextChar` is the
  // character that terminated the token (used to tell IF blocks from IF()).
  private def onWordEnd(word: String, nextChar: Char): Unit = {
    val token = word.toUpperCase(Locale.ROOT)
    // "END IF" etc.: the suffix keyword is decorative and changes nothing.
    val isDecorativeSuffix = previousToken == "END" && END_SUFFIX_KEYWORDS.contains(token)
    previousToken = token
    if (!isDecorativeSuffix) {
      if (token == "BEGIN") {
        depth += 1
      } else if (depth > 0) {
        token match {
          case "END" => depth -= 1
          case "CASE" | "DO" | "LOOP" | "REPEAT" => depth += 1
          case "IF" if nextChar != '(' => depth += 1
          case _ =>
        }
      }
    }
  }
}

/**
* Computes the SQL scripting block depth of the given SQL text.
* Returns 0 when the text is not inside any scripting block, > 0 when still open.
*/
/**
 * Computes the SQL Scripting block nesting depth of the given SQL text.
 *
 * Scans the text once, skipping single-/double-quoted strings, simple comments
 * (`--` to end of line) and nested bracketed comments, and feeds every remaining
 * character to a [[SqlScriptBlockTracker]] which applies the keyword rules
 * (BEGIN/END, IF ... END IF, etc.).
 *
 * @param text the accumulated SQL input so far
 * @return 0 when the text is not inside any scripting block, > 0 when at least
 *         one block is still open
 */
private[hive] def sqlScriptingBlockDepth(text: String): Int = {
  var insideSingleQuote = false
  var insideDoubleQuote = false
  var insideSimpleComment = false
  var bracketedCommentLevel = 0
  var escape = false
  // The closing '/' of a bracketed comment still belongs to the comment, so the
  // level is decremented only at the start of the *next* iteration.
  var leavingBracketedComment = false
  val tracker = new SqlScriptBlockTracker()

  def insideBracketedComment: Boolean = bracketedCommentLevel > 0
  def insideComment: Boolean = insideSimpleComment || insideBracketedComment
  def insideAnyQuote: Boolean = insideSingleQuote || insideDoubleQuote

  for (index <- 0 until text.length) {
    if (leavingBracketedComment) {
      bracketedCommentLevel -= 1
      leavingBracketedComment = false
    }

    val c = text.charAt(index)

    // Keyword tracking runs *before* the quote/comment state is toggled below,
    // so a closing quote or the start of a `--` comment correctly flushes any
    // in-progress token first.
    if (!insideComment && !insideAnyQuote) {
      tracker.processChar(c)
    } else {
      // Suppress partial tokens inside quotes/comments. (The former
      // `tracker.depth >= 0` guard here was always true -- depth never goes
      // negative because decrements are guarded by `depth > 0` -- so it was dead
      // and has been removed.)
      tracker.clearWordBuffer()
    }

    if (c == '\'' && !insideComment) {
      // Toggle only for an unescaped quote that is not inside a "..." string.
      if (!escape && !insideDoubleQuote) insideSingleQuote = !insideSingleQuote
    } else if (c == '\"' && !insideComment) {
      if (!escape && !insideSingleQuote) insideDoubleQuote = !insideDoubleQuote
    } else if (c == '-') {
      val hasNext = index + 1 < text.length
      // `--` starts a simple comment only outside quotes and other comments.
      if (!insideAnyQuote && !insideComment && hasNext && text.charAt(index + 1) == '-') {
        insideSimpleComment = true
      }
    } else if (c == '\n' && !escape) {
      // An unescaped newline terminates a simple comment.
      insideSimpleComment = false
    } else if (c == '/' && !insideSimpleComment && !insideAnyQuote) {
      if (insideBracketedComment && index > 0 && text.charAt(index - 1) == '*') {
        // "*/": schedule the level decrement for the next iteration.
        leavingBracketedComment = true
      } else if (index + 1 < text.length && text.charAt(index + 1) == '*') {
        // "/*": open a (possibly nested) bracketed comment.
        bracketedCommentLevel += 1
      }
    }

    if (escape) escape = false; else if (c == '\\') escape = true
  }

  // Finalize a keyword that ends at the very last character of the input.
  tracker.flush()
  tracker.depth
}
}

private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
Expand Down Expand Up @@ -583,7 +719,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
var lastRet: Int = 0

// we can not use "split" function directly as ";" may be quoted
val commands = splitSemiColon(line).asScala
val commands = splitSemiColon(line)
var command: String = ""
for (oneCmd <- commands) {
if (oneCmd.endsWith("\\")) {
Expand Down Expand Up @@ -618,7 +754,12 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
// string, the origin implementation from Hive will not drop the trailing semicolon as expected,
// hence we refined this function a little bit.
// Note: [SPARK-33100] Ignore a semicolon inside a bracketed comment in spark-sql.
private[hive] def splitSemiColon(line: String): JList[String] = {
// Note: [SPARK-56147] Semicolons inside a SQL Scripting compound block (BEGIN ... END,
// IF ... END IF, WHILE/FOR ... DO ... END WHILE/FOR, LOOP ... END LOOP, REPEAT ... END REPEAT,
// CASE ... END CASE, and nested/labeled variants) terminate individual statements
// *within* the block and must not be used as split points. Block depth is tracked
// with the same keyword-aware scanner used by the interactive input loop.
private[hive] def splitSemiColon(line: String): Array[String] = {
var insideSingleQuote = false
var insideDoubleQuote = false
var insideSimpleComment = false
Expand All @@ -627,12 +768,15 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
var beginIndex = 0
var leavingBracketedComment = false
var isStatement = false
val ret = new JArrayList[String]
val ret = mutable.ArrayBuilder.make[String]

// For SQL Scripting block-depth tracking.
val tracker = new SparkSQLCLIDriver.SqlScriptBlockTracker()

def insideBracketedComment: Boolean = bracketedCommentLevel > 0
def insideComment: Boolean = insideSimpleComment || insideBracketedComment
def statementInProgress(index: Int): Boolean = isStatement || (!insideComment &&
index > beginIndex && !s"${line.charAt(index)}".trim.isEmpty)
index > beginIndex && s"${line.charAt(index)}".trim.nonEmpty)

for (index <- 0 until line.length) {
// Checks if we need to decrement a bracketed comment level; the last character '/' of
Expand All @@ -643,6 +787,18 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
leavingBracketedComment = false
}

val c = line.charAt(index)

// Accumulate keyword tokens for SQL Scripting block-depth tracking.
// The tracker is updated *before* the quote/comment state below is toggled,
// so that a closing quote or the start of a `--` comment correctly flushes any
// in-progress token first.
if (!insideComment && !insideSingleQuote && !insideDoubleQuote) {
tracker.processChar(c)
} else {
tracker.clearWordBuffer()
}

if (line.charAt(index) == '\'' && !insideComment) {
// take a look to see if it is escaped
// See the comment above about SPARK-31595
Expand Down Expand Up @@ -671,10 +827,13 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
} else if (line.charAt(index) == ';') {
if (insideSingleQuote || insideDoubleQuote || insideComment) {
// do not split
} else if (tracker.depth > 0) {
// do not split: this semicolon is a statement terminator inside a SQL Scripting
// compound block, not a boundary between top-level commands.
} else {
if (isStatement) {
// split, do not include ; itself
ret.add(line.substring(beginIndex, index))
ret += line.substring(beginIndex, index)
}
beginIndex = index + 1
isStatement = false
Expand Down Expand Up @@ -704,6 +863,10 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {

isStatement = statementInProgress(index)
}

// Flush any word that ends at the very last character of the input.
tracker.flush()

// Check the last char is end of nested bracketed comment.
val endOfBracketedComment = leavingBracketedComment && bracketedCommentLevel == 1
// Spark SQL support simple comment and nested bracketed comment in query body.
Expand All @@ -715,8 +878,8 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
// CLI should also pass this part to the backend engine, which may throw an exception
// with clear error message.
if (!endOfBracketedComment && (isStatement || insideBracketedComment)) {
ret.add(line.substring(beginIndex))
ret += line.substring(beginIndex)
}
ret
ret.result()
}
}
Loading