mirror of
https://github.com/apache/zeppelin
synced 2026-05-24 09:38:26 +00:00
[ZEPPELIN-4522]. Support multiple sql statements for SparkSqlInterpreter
This commit is contained in:
parent
aea44a4f0d
commit
bc3c1feff1
4 changed files with 70 additions and 19 deletions
|
|
@ -28,6 +28,7 @@ import org.apache.zeppelin.interpreter.InterpreterException;
|
|||
import org.apache.zeppelin.interpreter.InterpreterResult;
|
||||
import org.apache.zeppelin.interpreter.InterpreterResult.Code;
|
||||
import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
|
||||
import org.apache.zeppelin.interpreter.util.SqlSplitter;
|
||||
import org.apache.zeppelin.scheduler.Scheduler;
|
||||
import org.apache.zeppelin.scheduler.SchedulerFactory;
|
||||
import org.slf4j.Logger;
|
||||
|
|
@ -44,6 +45,7 @@ public class SparkSqlInterpreter extends AbstractInterpreter {
|
|||
private Logger logger = LoggerFactory.getLogger(SparkSqlInterpreter.class);
|
||||
|
||||
private SparkInterpreter sparkInterpreter;
|
||||
private SqlSplitter sqlSplitter;
|
||||
|
||||
public SparkSqlInterpreter(Properties property) {
|
||||
super(property);
|
||||
|
|
@ -52,6 +54,7 @@ public class SparkSqlInterpreter extends AbstractInterpreter {
|
|||
@Override
|
||||
public void open() throws InterpreterException {
|
||||
this.sparkInterpreter = getInterpreterInTheSameSessionByClassName(SparkInterpreter.class);
|
||||
this.sqlSplitter = new SqlSplitter();
|
||||
}
|
||||
|
||||
public boolean concurrentSQL() {
|
||||
|
|
@ -82,26 +85,35 @@ public class SparkSqlInterpreter extends AbstractInterpreter {
|
|||
sparkInterpreter.getZeppelinContext().setInterpreterContext(context);
|
||||
SQLContext sqlc = sparkInterpreter.getSQLContext();
|
||||
SparkContext sc = sqlc.sparkContext();
|
||||
sc.setLocalProperty("spark.scheduler.pool", context.getLocalProperties().get("pool"));
|
||||
sc.setJobGroup(Utils.buildJobGroupId(context), Utils.buildJobDesc(context), false);
|
||||
|
||||
try {
|
||||
Method method = sqlc.getClass().getMethod("sql", String.class);
|
||||
int maxResult = Integer.parseInt(context.getLocalProperties().getOrDefault("limit",
|
||||
"" + sparkInterpreter.getZeppelinContext().getMaxResult()));
|
||||
String msg = sparkInterpreter.getZeppelinContext().showData(
|
||||
method.invoke(sqlc, st), maxResult);
|
||||
sc.clearJobGroup();
|
||||
return new InterpreterResult(Code.SUCCESS, msg);
|
||||
} catch (Exception e) {
|
||||
if (Boolean.parseBoolean(getProperty("zeppelin.spark.sql.stacktrace"))) {
|
||||
return new InterpreterResult(Code.ERROR, ExceptionUtils.getStackTrace(e));
|
||||
StringBuilder builder = new StringBuilder();
|
||||
List<String> sqls = sqlSplitter.splitSql(st);
|
||||
for (String sql : sqls) {
|
||||
sc.setLocalProperty("spark.scheduler.pool", context.getLocalProperties().get("pool"));
|
||||
sc.setJobGroup(Utils.buildJobGroupId(context), Utils.buildJobDesc(context), false);
|
||||
|
||||
try {
|
||||
Method method = sqlc.getClass().getMethod("sql", String.class);
|
||||
int maxResult = Integer.parseInt(context.getLocalProperties().getOrDefault("limit",
|
||||
"" + sparkInterpreter.getZeppelinContext().getMaxResult()));
|
||||
String result = sparkInterpreter.getZeppelinContext().showData(
|
||||
method.invoke(sqlc, sql), maxResult);
|
||||
sc.clearJobGroup();
|
||||
builder.append(result);
|
||||
} catch (Exception e) {
|
||||
if (Boolean.parseBoolean(getProperty("zeppelin.spark.sql.stacktrace"))) {
|
||||
builder.append("%text " + ExceptionUtils.getStackTrace(e));
|
||||
} else {
|
||||
logger.error("Invocation target exception", e);
|
||||
String msg = e.getCause().getMessage()
|
||||
+ "\nset zeppelin.spark.sql.stacktrace = true to see full stacktrace";
|
||||
builder.append("\n%text " + msg);
|
||||
}
|
||||
return new InterpreterResult(Code.ERROR, builder.toString());
|
||||
}
|
||||
logger.error("Invocation target exception", e);
|
||||
String msg = e.getCause().getMessage()
|
||||
+ "\nset zeppelin.spark.sql.stacktrace = true to see full stacktrace";
|
||||
return new InterpreterResult(Code.ERROR, msg);
|
||||
}
|
||||
|
||||
return new InterpreterResult(Code.SUCCESS, builder.toString());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
|||
|
|
@ -170,6 +170,45 @@ public class SparkSqlInterpreterTest {
|
|||
assertEquals(6, ret.message().get(0).getData().split("\n").length);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMultipleStatements() throws InterpreterException {
|
||||
sparkInterpreter.interpret("case class P(age:Int)", context);
|
||||
sparkInterpreter.interpret(
|
||||
"val gr = sc.parallelize(Seq(P(1),P(2),P(3),P(4),P(5),P(6),P(7),P(8)))",
|
||||
context);
|
||||
sparkInterpreter.interpret("gr.toDF.registerTempTable(\"gr\")", context);
|
||||
|
||||
// Two correct sql
|
||||
InterpreterResult ret = sqlInterpreter.interpret(
|
||||
"select * --comment_1\nfrom gr;select count(1) from gr", context);
|
||||
assertEquals(InterpreterResult.Code.SUCCESS, ret.code());
|
||||
assertEquals(ret.message().toString(), 5, ret.message().size());
|
||||
assertEquals(ret.message().toString(), Type.TEXT, ret.message().get(0).getType());
|
||||
assertEquals(ret.message().toString(), "\n", ret.message().get(0).getData());
|
||||
assertEquals(ret.message().toString(), Type.TABLE, ret.message().get(1).getType());
|
||||
assertEquals(ret.message().toString(), Type.TEXT, ret.message().get(2).getType());
|
||||
assertEquals(ret.message().toString(), "\n", ret.message().get(2).getData());
|
||||
assertEquals(ret.message().toString(), Type.TABLE, ret.message().get(3).getType());
|
||||
assertEquals(ret.message().toString(), Type.TEXT, ret.message().get(4).getType());
|
||||
assertEquals(ret.message().toString(), "", ret.message().get(4).getData());
|
||||
|
||||
// One correct sql + One invalid sql
|
||||
ret = sqlInterpreter.interpret("select * from gr;invalid_sql", context);
|
||||
assertEquals(InterpreterResult.Code.ERROR, ret.code());
|
||||
assertEquals(ret.message().toString(), 3, ret.message().size());
|
||||
assertEquals(ret.message().toString(), Type.TABLE, ret.message().get(1).getType());
|
||||
assertEquals(ret.message().toString(), Type.TEXT, ret.message().get(2).getType());
|
||||
assertTrue(ret.message().toString(), ret.message().get(2).getData().contains("ParseException"));
|
||||
|
||||
// One correct sql + One invalid sql + One valid sql (skipped)
|
||||
ret = sqlInterpreter.interpret("select * from gr;invalid_sql; select count(1) from gr", context);
|
||||
assertEquals(InterpreterResult.Code.ERROR, ret.code());
|
||||
assertEquals(ret.message().toString(), 3, ret.message().size());
|
||||
assertEquals(ret.message().toString(), Type.TABLE, ret.message().get(1).getType());
|
||||
assertEquals(ret.message().toString(), Type.TEXT, ret.message().get(2).getType());
|
||||
assertTrue(ret.message().toString(), ret.message().get(2).getData().contains("ParseException"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConcurrentSQL() throws InterpreterException, InterruptedException {
|
||||
if (sparkInterpreter.getSparkVersion().isSpark2()) {
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ public class Spark1Shims extends SparkShims {
|
|||
// fetch maxResult+1 rows so that we can check whether it is larger than zeppelin.spark.maxResult
|
||||
List<Row> rows = df.takeAsList(maxResult + 1);
|
||||
StringBuilder msg = new StringBuilder();
|
||||
msg.append("%table ");
|
||||
msg.append("\n%table ");
|
||||
msg.append(StringUtils.join(columns, "\t"));
|
||||
msg.append("\n");
|
||||
boolean isLargerThanMaxResult = rows.size() > maxResult;
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ public class Spark2Shims extends SparkShims {
|
|||
// fetch maxResult+1 rows so that we can check whether it is larger than zeppelin.spark.maxResult
|
||||
List<Row> rows = df.takeAsList(maxResult + 1);
|
||||
StringBuilder msg = new StringBuilder();
|
||||
msg.append("%table ");
|
||||
msg.append("\n%table ");
|
||||
msg.append(StringUtils.join(columns, "\t"));
|
||||
msg.append("\n");
|
||||
boolean isLargerThanMaxResult = rows.size() > maxResult;
|
||||
|
|
|
|||
Loading…
Reference in a new issue