diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 0a024fb10ee01..7b254cd00b625 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -21,6 +21,7 @@ import java.io._ import java.util.{ArrayList => JArrayList, List => JList, Locale} import java.util.concurrent.TimeUnit +import scala.collection.mutable import scala.jdk.CollectionConverters._ import jline.console.ConsoleReader @@ -268,9 +269,19 @@ private[hive] object SparkSQLCLIDriver extends Logging { if (line.trim().endsWith(";") && !line.trim().endsWith("\\;")) { line = prefix + line - ret = cli.processLine(line, true) - prefix = "" - currentPrompt = promptWithCurrentDB + // For SQL Scripting, a semicolon inside a compound block (BEGIN ... END, + // IF ... END IF, WHILE ... DO ... END WHILE, etc.) terminates a statement + // *within* the block, not the block itself. Only fire when the accumulated + // input is no longer inside any open scripting block. + val insideScriptingBlock = sqlScriptingBlockDepth(line) > 0 + if (insideScriptingBlock) { + prefix = line + currentPrompt = continuedPromptWithDBSpaces + } else { + ret = cli.processLine(line, allowInterrupting = true) + prefix = "" + currentPrompt = promptWithCurrentDB + } } else { prefix = prefix + line currentPrompt = continuedPromptWithDBSpaces @@ -374,6 +385,131 @@ private[hive] object SparkSQLCLIDriver extends Logging { } Array[Completer](propCompleter, customCompleter) } + + private final val END_SUFFIX_KEYWORDS = + Set("IF", "CASE", "FOR", "WHILE", "LOOP", "REPEAT") + + /** + * Tracks SQL Scripting block nesting depth by processing SQL keyword tokens. + * + * Call [[processChar]] for each character that is outside quoted strings and comments. + * Call [[clearWordBuffer]] when entering a quoted string or comment mid-word. + * Call [[flush]] after the last character to finalize any pending token. + * Read [[depth]] to obtain the current block nesting level. + * + * Depth tracking rules: + * - BEGIN -> depth++ + * - CASE / IF (not `IF(`) / DO / LOOP / REPEAT -> depth++ only when depth > 0 + * - Every END keyword -> depth-- (when depth > 0) + * - Keywords immediately following END (IF, CASE, FOR, WHILE, LOOP, REPEAT) are + * decorative suffixes (e.g. "END IF") and do NOT change the depth. + */ + private[hive] class SqlScriptBlockTracker { + var depth: Int = 0 + private val wordBuf = new StringBuilder() + private var prevWord: String = "" + + private def isWordChar(c: Char): Boolean = c.isLetterOrDigit || c == '_' + + /** Feed a single character that is known to be outside any quote or comment. */ + def processChar(c: Char): Unit = { + if (isWordChar(c)) { + wordBuf.append(c) + } else if (wordBuf.nonEmpty) { + onWordEnd(wordBuf.toString, c) + wordBuf.clear() + } + } + + /** Discard a partially accumulated word (e.g. when entering a quote/comment). */ + def clearWordBuffer(): Unit = wordBuf.clear() + + /** Flush any remaining word at the end of the input. */ + def flush(): Unit = { + if (wordBuf.nonEmpty) { + onWordEnd(wordBuf.toString, ' ') + wordBuf.clear() + } + } + + private def onWordEnd(word: String, nextChar: Char): Unit = { + val upper = word.toUpperCase(Locale.ROOT) + if (prevWord == "END" && END_SUFFIX_KEYWORDS.contains(upper)) { + // Decorative suffix after END (e.g. "END IF") -- do not change depth. + prevWord = upper + return + } + prevWord = upper + upper match { + case "BEGIN" => depth += 1 + case "END" if depth > 0 => depth -= 1 + case "CASE" if depth > 0 => depth += 1 + case "IF" if depth > 0 && nextChar != '(' => depth += 1 + case "DO" if depth > 0 => depth += 1 + case "LOOP" if depth > 0 => depth += 1 + case "REPEAT" if depth > 0 => depth += 1 + case _ => + } + } + } + + /** + * Computes the SQL scripting block depth of the given SQL text. + * Returns 0 when the text is not inside any scripting block, > 0 when still open. + */ + private[hive] def sqlScriptingBlockDepth(text: String): Int = { + var insideSingleQuote = false + var insideDoubleQuote = false + var insideSimpleComment = false + var bracketedCommentLevel = 0 + var escape = false + var leavingBracketedComment = false + val tracker = new SqlScriptBlockTracker() + + def insideBracketedComment: Boolean = bracketedCommentLevel > 0 + def insideComment: Boolean = insideSimpleComment || insideBracketedComment + def insideAnyQuote: Boolean = insideSingleQuote || insideDoubleQuote + + for (index <- 0 until text.length) { + if (leavingBracketedComment) { + bracketedCommentLevel -= 1 + leavingBracketedComment = false + } + + val c = text.charAt(index) + + if (!insideComment && !insideAnyQuote) { + tracker.processChar(c) + } else if (tracker.depth >= 0) { + // Suppress partial tokens inside quotes/comments. + tracker.clearWordBuffer() + } + + if (c == '\'' && !insideComment) { + if (!escape && !insideDoubleQuote) insideSingleQuote = !insideSingleQuote + } else if (c == '\"' && !insideComment) { + if (!escape && !insideSingleQuote) insideDoubleQuote = !insideDoubleQuote + } else if (c == '-') { + val hasNext = index + 1 < text.length + if (!insideAnyQuote && !insideComment && hasNext && text.charAt(index + 1) == '-') { + insideSimpleComment = true + } + } else if (c == '\n' && !escape) { + insideSimpleComment = false + } else if (c == '/' && !insideSimpleComment && !insideAnyQuote) { + if (insideBracketedComment && index > 0 && text.charAt(index - 1) == '*') { + leavingBracketedComment = true + } else if (index + 1 < text.length && text.charAt(index + 1) == '*') { + bracketedCommentLevel += 1 + } + } + + if (escape) escape = false; else if (c == '\\') escape = true + } + + tracker.flush() + tracker.depth + } } private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { @@ -583,7 +719,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { var lastRet: Int = 0 // we can not use "split" function directly as ";" may be quoted - val commands = splitSemiColon(line).asScala + val commands = splitSemiColon(line) var command: String = "" for (oneCmd <- commands) { if (oneCmd.endsWith("\\")) { @@ -618,7 +754,12 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { // string, the origin implementation from Hive will not drop the trailing semicolon as expected, // hence we refined this function a little bit. // Note: [SPARK-33100] Ignore a semicolon inside a bracketed comment in spark-sql. - private[hive] def splitSemiColon(line: String): JList[String] = { + // Note: [SPARK-56147] Semicolons inside a SQL Scripting compound block (BEGIN ... END, + // IF ... END IF, WHILE/FOR ... DO ... END WHILE/FOR, LOOP ... END LOOP, REPEAT ... END REPEAT, + // CASE ... END CASE, and nested/labeled variants) terminate individual statements + // *within* the block and must not be used as split points. Block depth is tracked + // with the same keyword-aware scanner used by the interactive input loop. + private[hive] def splitSemiColon(line: String): Array[String] = { var insideSingleQuote = false var insideDoubleQuote = false var insideSimpleComment = false @@ -627,12 +768,15 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { var beginIndex = 0 var leavingBracketedComment = false var isStatement = false - val ret = new JArrayList[String] + val ret = mutable.ArrayBuilder.make[String] + + // For SQL Scripting block-depth tracking. + val tracker = new SparkSQLCLIDriver.SqlScriptBlockTracker() def insideBracketedComment: Boolean = bracketedCommentLevel > 0 def insideComment: Boolean = insideSimpleComment || insideBracketedComment def statementInProgress(index: Int): Boolean = isStatement || (!insideComment && - index > beginIndex && !s"${line.charAt(index)}".trim.isEmpty) + index > beginIndex && s"${line.charAt(index)}".trim.nonEmpty) for (index <- 0 until line.length) { // Checks if we need to decrement a bracketed comment level; the last character '/' of @@ -643,6 +787,18 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { leavingBracketedComment = false } + val c = line.charAt(index) + + // Accumulate keyword tokens for SQL Scripting block-depth tracking. + // The tracker is updated *before* the quote/comment state below is toggled, + // so that a closing quote or the start of a `--` comment correctly flushes any + // in-progress token first. + if (!insideComment && !insideSingleQuote && !insideDoubleQuote) { + tracker.processChar(c) + } else { + tracker.clearWordBuffer() + } + if (line.charAt(index) == '\'' && !insideComment) { // take a look to see if it is escaped // See the comment above about SPARK-31595 @@ -671,10 +827,13 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { } else if (line.charAt(index) == ';') { if (insideSingleQuote || insideDoubleQuote || insideComment) { // do not split + } else if (tracker.depth > 0) { + // do not split: this semicolon is a statement terminator inside a SQL Scripting + // compound block, not a boundary between top-level commands. } else { if (isStatement) { // split, do not include ; itself - ret.add(line.substring(beginIndex, index)) + ret += line.substring(beginIndex, index) } beginIndex = index + 1 isStatement = false @@ -704,6 +863,10 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { isStatement = statementInProgress(index) } + + // Flush any word that ends at the very last character of the input. + tracker.flush() + // Check the last char is end of nested bracketed comment. val endOfBracketedComment = leavingBracketedComment && bracketedCommentLevel == 1 // Spark SQL support simple comment and nested bracketed comment in query body. @@ -715,8 +878,8 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { // CLI should also pass this part to the backend engine, which may throw an exception // with clear error message. if (!endOfBracketedComment && (isStatement || insideBracketedComment)) { - ret.add(line.substring(beginIndex)) + ret += line.substring(beginIndex) } - ret + ret.result() } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index d12f9fdd1900d..a28147b7a6f89 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -25,7 +25,6 @@ import java.util.concurrent.CountDownLatch import scala.collection.mutable.ArrayBuffer import scala.concurrent.Promise import scala.concurrent.duration._ -import scala.jdk.CollectionConverters._ import org.apache.hadoop.hive.cli.CliSessionState import org.apache.hadoop.hive.ql.session.SessionState @@ -680,12 +679,273 @@ class CliSuite extends SparkFunSuite { "-- comment \nSELECT 1" -> Seq("-- comment \nSELECT 1"), "/* comment */ " -> Seq() ).foreach { case (query, ret) => - assert(cli.splitSemiColon(query).asScala === ret) + assert(cli.splitSemiColon(query) === ret) } sessionState.close() SparkSQLEnv.stop() } + test("SQL Scripting: splitSemiColon should not split inside compound blocks") { + val sparkConf = new SparkConf(loadDefaults = true) + .setMaster("local-cluster[1,1,1024]") + .setAppName("sql-scripting-split") + val sparkContext = new SparkContext(sparkConf) + SparkSQLEnv.sparkContext = sparkContext + val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) + val cliConf = HiveClientImpl.newHiveConf(sparkConf, hadoopConf) + val sessionState = new CliSessionState(cliConf) + SessionState.setCurrentSessionState(sessionState) + val cli = new SparkSQLCLIDriver + + // Simple BEGIN...END + assert(cli.splitSemiColon( + """BEGIN + | SELECT 1; + |END""".stripMargin) === + Seq( + """BEGIN + | SELECT 1; + |END""".stripMargin)) + + // BEGIN...END with trailing semicolon + assert(cli.splitSemiColon( + """BEGIN + | SELECT 1; + |END;""".stripMargin) === + Seq( + """BEGIN + | SELECT 1; + |END""".stripMargin)) + + // Multiple statements inside the block + assert(cli.splitSemiColon( + """BEGIN + | SELECT 1; + | SELECT 2; + |END;""".stripMargin) === + Seq( + """BEGIN + | SELECT 1; + | SELECT 2; + |END""".stripMargin)) + + // Regular statements before and after a scripting block should still be split + assert(cli.splitSemiColon( + """SELECT 0; + |BEGIN + | SELECT 1; + | SELECT 2; + |END; + |SELECT 3;""".stripMargin) === + Seq( + "SELECT 0", + """ + |BEGIN + | SELECT 1; + | SELECT 2; + |END""".stripMargin, + """ + |SELECT 3""".stripMargin)) + + // IF...END IF inside a block + assert(cli.splitSemiColon( + """BEGIN + | IF x = 1 THEN + | SELECT 1; + | END IF; + |END;""".stripMargin) === + Seq( + """BEGIN + | IF x = 1 THEN + | SELECT 1; + | END IF; + |END""".stripMargin)) + + // WHILE...DO...END WHILE inside a block + assert(cli.splitSemiColon( + """BEGIN + | WHILE x > 0 DO + | SET x = x - 1; + | END WHILE; + |END;""".stripMargin) === + Seq( + """BEGIN + | WHILE x > 0 DO + | SET x = x - 1; + | END WHILE; + |END""".stripMargin)) + + // FOR...DO...END FOR inside a block + assert(cli.splitSemiColon( + """BEGIN + | FOR r AS SELECT * FROM t DO + | SELECT r.id; + | END FOR; + |END;""".stripMargin) === + Seq( + """BEGIN + | FOR r AS SELECT * FROM t DO + | SELECT r.id; + | END FOR; + |END""".stripMargin)) + + // LOOP...END LOOP inside a block + assert(cli.splitSemiColon( + """BEGIN + | LOOP + | SELECT 1; + | END LOOP; + |END;""".stripMargin) === + Seq( + """BEGIN + | LOOP + | SELECT 1; + | END LOOP; + |END""".stripMargin)) + + // REPEAT...END REPEAT inside a block + assert(cli.splitSemiColon( + """BEGIN + | REPEAT + | SELECT 1; + | UNTIL x > 0 END REPEAT; + |END;""".stripMargin) === + Seq( + """BEGIN + | REPEAT + | SELECT 1; + | UNTIL x > 0 END REPEAT; + |END""".stripMargin)) + + // CASE statement inside a block + assert(cli.splitSemiColon( + """BEGIN + | CASE x + | WHEN 1 THEN SELECT 'one'; + | END CASE; + |END;""".stripMargin) === + Seq( + """BEGIN + | CASE x + | WHEN 1 THEN SELECT 'one'; + | END CASE; + |END""".stripMargin)) + + // CASE expression (not a scripting block) inside a scripting block + assert(cli.splitSemiColon( + """BEGIN + | SELECT CASE WHEN x=1 THEN 'a' ELSE 'b' END; + |END;""".stripMargin) === + Seq( + """BEGIN + | SELECT CASE WHEN x=1 THEN 'a' ELSE 'b' END; + |END""".stripMargin)) + + // Nested BEGIN...END + assert(cli.splitSemiColon( + """BEGIN + | BEGIN + | SELECT 1; + | END; + | SELECT 2; + |END;""".stripMargin) === + Seq( + """BEGIN + | BEGIN + | SELECT 1; + | END; + | SELECT 2; + |END""".stripMargin)) + + // `IF(` is the Spark SQL built-in function, not a scripting IF -- the block must + // not be kept unsplit merely because IF appears inside it. + assert(cli.splitSemiColon( + """BEGIN + | SELECT IF(x > 0, 'pos', 'non-pos'); + |END;""".stripMargin) === + Seq( + """BEGIN + | SELECT IF(x > 0, 'pos', 'non-pos'); + |END""".stripMargin)) + + // Labeled block: label: BEGIN ... END label + assert(cli.splitSemiColon( + """outer: BEGIN + | SELECT 1; + |END outer;""".stripMargin) === + Seq( + """outer: BEGIN + | SELECT 1; + |END outer""".stripMargin)) + + sessionState.close() + SparkSQLEnv.stop() + } + + test("SQL Scripting: sqlScriptingBlockDepth correctly tracks nesting") { + import SparkSQLCLIDriver.sqlScriptingBlockDepth + + // Not a script + assert(sqlScriptingBlockDepth("SELECT 1") === 0) + // Simple open block + assert(sqlScriptingBlockDepth( + """BEGIN + | SELECT 1;""".stripMargin) === 1) + // Closed block + assert(sqlScriptingBlockDepth( + """BEGIN + | SELECT 1; + |END""".stripMargin) === 0) + // Trailing semicolon after END still closes + assert(sqlScriptingBlockDepth("BEGIN SELECT 1; END;") === 0) + // Nested block open + assert(sqlScriptingBlockDepth( + """BEGIN + | BEGIN + | SELECT 1;""".stripMargin) === 2) + // IF still open + assert(sqlScriptingBlockDepth( + """BEGIN + | IF x=1 THEN + | SELECT 1;""".stripMargin) === 2) + // IF closed + assert(sqlScriptingBlockDepth( + """BEGIN + | IF x=1 THEN + | SELECT 1; + | END IF;""".stripMargin) === 1) + // WHILE/DO still open + assert(sqlScriptingBlockDepth( + """BEGIN + | WHILE x>0 DO + | SET x=x-1;""".stripMargin) === 2) + // WHILE closed + assert(sqlScriptingBlockDepth( + """BEGIN + | WHILE x>0 DO + | SET x=x-1; + | END WHILE;""".stripMargin) === 1) + // CASE expression: CASE increments, END decrements -- net zero inside the block + assert(sqlScriptingBlockDepth( + """BEGIN + | SELECT CASE WHEN 1=1 THEN 'a' END;""".stripMargin) === 1) + // IF( is a Spark SQL function call, not a scripting IF -- must not increment + assert(sqlScriptingBlockDepth( + """BEGIN + | SELECT IF(1=1, 'a', 'b');""".stripMargin) === 1) + // Keywords inside string literals must be ignored + assert(sqlScriptingBlockDepth( + """BEGIN + | SELECT 'END';""".stripMargin) === 1) + // Keywords inside line comments must be ignored + assert(sqlScriptingBlockDepth( + """BEGIN + | -- END + | SELECT 1;""".stripMargin) === 1) + // Keywords inside bracketed comments must be ignored + assert(sqlScriptingBlockDepth("BEGIN /* END */ SELECT 1;") === 1) + } + testRetry("SPARK-39068: support in-memory catalog and running concurrently") { val extraConf = Seq("-c", s"${StaticSQLConf.CATALOG_IMPLEMENTATION.key}=in-memory") val cd = new CountDownLatch(2)