From bc3c1feff1ea2fa6ea9fabb01693df49a44fb719 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Sat, 28 Dec 2019 18:05:18 +0800 Subject: [PATCH] [ZEPPELIN-4522]. Support multiple sql statements for SparkSqlInterpreter --- .../zeppelin/spark/SparkSqlInterpreter.java | 46 ++++++++++++------- .../spark/SparkSqlInterpreterTest.java | 39 ++++++++++++++++ .../apache/zeppelin/spark/Spark1Shims.java | 2 +- .../apache/zeppelin/spark/Spark2Shims.java | 2 +- 4 files changed, 70 insertions(+), 19 deletions(-) diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java index 4e63760c09..889710c385 100644 --- a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java +++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java @@ -28,6 +28,7 @@ import org.apache.zeppelin.interpreter.InterpreterException; import org.apache.zeppelin.interpreter.InterpreterResult; import org.apache.zeppelin.interpreter.InterpreterResult.Code; import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; +import org.apache.zeppelin.interpreter.util.SqlSplitter; import org.apache.zeppelin.scheduler.Scheduler; import org.apache.zeppelin.scheduler.SchedulerFactory; import org.slf4j.Logger; @@ -44,6 +45,7 @@ public class SparkSqlInterpreter extends AbstractInterpreter { private Logger logger = LoggerFactory.getLogger(SparkSqlInterpreter.class); private SparkInterpreter sparkInterpreter; + private SqlSplitter sqlSplitter; public SparkSqlInterpreter(Properties property) { super(property); @@ -52,6 +54,7 @@ public class SparkSqlInterpreter extends AbstractInterpreter { @Override public void open() throws InterpreterException { this.sparkInterpreter = getInterpreterInTheSameSessionByClassName(SparkInterpreter.class); + this.sqlSplitter = new SqlSplitter(); } public boolean concurrentSQL() { @@ -82,26 +85,35 @@ public class SparkSqlInterpreter extends AbstractInterpreter { sparkInterpreter.getZeppelinContext().setInterpreterContext(context); SQLContext sqlc = sparkInterpreter.getSQLContext(); SparkContext sc = sqlc.sparkContext(); - sc.setLocalProperty("spark.scheduler.pool", context.getLocalProperties().get("pool")); - sc.setJobGroup(Utils.buildJobGroupId(context), Utils.buildJobDesc(context), false); - try { - Method method = sqlc.getClass().getMethod("sql", String.class); - int maxResult = Integer.parseInt(context.getLocalProperties().getOrDefault("limit", - "" + sparkInterpreter.getZeppelinContext().getMaxResult())); - String msg = sparkInterpreter.getZeppelinContext().showData( - method.invoke(sqlc, st), maxResult); - sc.clearJobGroup(); - return new InterpreterResult(Code.SUCCESS, msg); - } catch (Exception e) { - if (Boolean.parseBoolean(getProperty("zeppelin.spark.sql.stacktrace"))) { - return new InterpreterResult(Code.ERROR, ExceptionUtils.getStackTrace(e)); + StringBuilder builder = new StringBuilder(); + List sqls = sqlSplitter.splitSql(st); + for (String sql : sqls) { + sc.setLocalProperty("spark.scheduler.pool", context.getLocalProperties().get("pool")); + sc.setJobGroup(Utils.buildJobGroupId(context), Utils.buildJobDesc(context), false); + + try { + Method method = sqlc.getClass().getMethod("sql", String.class); + int maxResult = Integer.parseInt(context.getLocalProperties().getOrDefault("limit", + "" + sparkInterpreter.getZeppelinContext().getMaxResult())); + String result = sparkInterpreter.getZeppelinContext().showData( + method.invoke(sqlc, sql), maxResult); + sc.clearJobGroup(); + builder.append(result); + } catch (Exception e) { + if (Boolean.parseBoolean(getProperty("zeppelin.spark.sql.stacktrace"))) { + builder.append("%text " + ExceptionUtils.getStackTrace(e)); + } else { + logger.error("Invocation target exception", e); + String msg = e.getCause().getMessage() + + "\nset zeppelin.spark.sql.stacktrace = true to see full stacktrace"; + builder.append("\n%text " + msg); + } + return new InterpreterResult(Code.ERROR, builder.toString()); } - logger.error("Invocation target exception", e); - String msg = e.getCause().getMessage() - + "\nset zeppelin.spark.sql.stacktrace = true to see full stacktrace"; - return new InterpreterResult(Code.ERROR, msg); } + + return new InterpreterResult(Code.SUCCESS, builder.toString()); } @Override diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkSqlInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkSqlInterpreterTest.java index cab5b1b47e..71b075bca7 100644 --- a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkSqlInterpreterTest.java +++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkSqlInterpreterTest.java @@ -170,6 +170,45 @@ public class SparkSqlInterpreterTest { assertEquals(6, ret.message().get(0).getData().split("\n").length); } + @Test + public void testMultipleStatements() throws InterpreterException { + sparkInterpreter.interpret("case class P(age:Int)", context); + sparkInterpreter.interpret( + "val gr = sc.parallelize(Seq(P(1),P(2),P(3),P(4),P(5),P(6),P(7),P(8)))", + context); + sparkInterpreter.interpret("gr.toDF.registerTempTable(\"gr\")", context); + + // Two correct sql + InterpreterResult ret = sqlInterpreter.interpret( + "select * --comment_1\nfrom gr;select count(1) from gr", context); + assertEquals(InterpreterResult.Code.SUCCESS, ret.code()); + assertEquals(ret.message().toString(), 5, ret.message().size()); + assertEquals(ret.message().toString(), Type.TEXT, ret.message().get(0).getType()); + assertEquals(ret.message().toString(), "\n", ret.message().get(0).getData()); + assertEquals(ret.message().toString(), Type.TABLE, ret.message().get(1).getType()); + assertEquals(ret.message().toString(), Type.TEXT, ret.message().get(2).getType()); + assertEquals(ret.message().toString(), "\n", ret.message().get(2).getData()); + assertEquals(ret.message().toString(), Type.TABLE, ret.message().get(3).getType()); + assertEquals(ret.message().toString(), Type.TEXT, ret.message().get(4).getType()); + assertEquals(ret.message().toString(), "", ret.message().get(4).getData()); + + // One correct sql + One invalid sql + ret = sqlInterpreter.interpret("select * from gr;invalid_sql", context); + assertEquals(InterpreterResult.Code.ERROR, ret.code()); + assertEquals(ret.message().toString(), 3, ret.message().size()); + assertEquals(ret.message().toString(), Type.TABLE, ret.message().get(1).getType()); + assertEquals(ret.message().toString(), Type.TEXT, ret.message().get(2).getType()); + assertTrue(ret.message().toString(), ret.message().get(2).getData().contains("ParseException")); + + // One correct sql + One invalid sql + One valid sql (skipped) + ret = sqlInterpreter.interpret("select * from gr;invalid_sql; select count(1) from gr", context); + assertEquals(InterpreterResult.Code.ERROR, ret.code()); + assertEquals(ret.message().toString(), 3, ret.message().size()); + assertEquals(ret.message().toString(), Type.TABLE, ret.message().get(1).getType()); + assertEquals(ret.message().toString(), Type.TEXT, ret.message().get(2).getType()); + assertTrue(ret.message().toString(), ret.message().get(2).getData().contains("ParseException")); + } + @Test public void testConcurrentSQL() throws InterpreterException, InterruptedException { if (sparkInterpreter.getSparkVersion().isSpark2()) { diff --git a/spark/spark1-shims/src/main/scala/org/apache/zeppelin/spark/Spark1Shims.java b/spark/spark1-shims/src/main/scala/org/apache/zeppelin/spark/Spark1Shims.java index 8e60ed07fd..6119647bc5 100644 --- a/spark/spark1-shims/src/main/scala/org/apache/zeppelin/spark/Spark1Shims.java +++ b/spark/spark1-shims/src/main/scala/org/apache/zeppelin/spark/Spark1Shims.java @@ -70,7 +70,7 @@ public class Spark1Shims extends SparkShims { // fetch maxResult+1 rows so that we can check whether it is larger than zeppelin.spark.maxResult List rows = df.takeAsList(maxResult + 1); StringBuilder msg = new StringBuilder(); - msg.append("%table "); + msg.append("\n%table "); msg.append(StringUtils.join(columns, "\t")); msg.append("\n"); boolean isLargerThanMaxResult = rows.size() > maxResult; diff --git a/spark/spark2-shims/src/main/scala/org/apache/zeppelin/spark/Spark2Shims.java b/spark/spark2-shims/src/main/scala/org/apache/zeppelin/spark/Spark2Shims.java index a7304c5cac..b7b1cf9a93 100644 --- a/spark/spark2-shims/src/main/scala/org/apache/zeppelin/spark/Spark2Shims.java +++ b/spark/spark2-shims/src/main/scala/org/apache/zeppelin/spark/Spark2Shims.java @@ -71,7 +71,7 @@ public class Spark2Shims extends SparkShims { // fetch maxResult+1 rows so that we can check whether it is larger than zeppelin.spark.maxResult List rows = df.takeAsList(maxResult + 1); StringBuilder msg = new StringBuilder(); - msg.append("%table "); + msg.append("\n%table "); msg.append(StringUtils.join(columns, "\t")); msg.append("\n"); boolean isLargerThanMaxResult = rows.size() > maxResult;