ZEPPELIN-3617. Allow to specify saving resourceName as paragraph property

This commit is contained in:
Jeff Zhang 2018-07-12 14:28:02 +08:00
parent 64bf4bb3d0
commit 024965e645
17 changed files with 249 additions and 28 deletions

View file

@ -559,6 +559,8 @@ public class PythonInterpreter extends Interpreter {
InterpreterContext.get());
if (result.code() != Code.SUCCESS) {
throw new IOException("Fail to run bootstrap script: " + resourceName + "\n" + result);
} else {
LOGGER.debug("Bootstrap python successfully.");
}
} catch (InterpreterException e) {
throw new IOException(e);

View file

@ -61,6 +61,14 @@ class PyZeppelinContext(object):
def get(self, key):
return self.__getitem__(key)
def getAsDataFrame(self, key):
value = self.get(key)
try:
import pandas as pd
except ImportError:
print("fail to call getAsDataFrame as pandas is not installed")
return pd.read_csv(StringIO(value), sep="\t")
def angular(self, key, noteId = None, paragraphId = None):
return self.z.angular(key, noteId, paragraphId)
@ -158,6 +166,7 @@ class PyZeppelinContext(object):
body_buf = StringIO("")
rows = df.head(self.max_result).values if exceed_limit else df.values
rowNumber = len(rows)
index = df.index.values
for idx, row in zip(index, rows):
if show_index:
@ -167,13 +176,16 @@ class PyZeppelinContext(object):
for cell in row[1:]:
body_buf.write("\t")
body_buf.write(str(cell))
body_buf.write("\n")
# don't print '\n' after the last row
if idx != (rowNumber - 1):
body_buf.write("\n")
body_buf.seek(0)
header_buf.seek(0)
print("%table " + header_buf.read() + body_buf.read())
body_buf.close(); header_buf.close()
body_buf.close()
header_buf.close()
if exceed_limit:
print("%html <font color=red>Results are limited by {}.</font>".format(self.max_result))
print("\n%html <font color=red>Results are limited by {}.</font>".format(self.max_result))
def show_matplotlib(self, p, fmt="png", width="auto", height="auto",
**kwargs):

View file

@ -98,6 +98,7 @@ public class PySparkInterpreter extends PythonInterpreter {
try {
bootstrapInterpreter("python/zeppelin_pyspark.py");
} catch (IOException e) {
LOGGER.error("Fail to bootstrap pyspark", e);
throw new InterpreterException("Fail to bootstrap pyspark", e);
}
}

View file

@ -66,6 +66,12 @@ z.put <- function(name, object) {
z.get <- function(name) {
SparkR:::callJMethod(.zeppelinContext, "get", name)
}
z.getAsDataFrame <- function(name) {
stringValue <- z.get(name)
read.table(text=stringValue, header=TRUE, sep="\t")
}
z.angular <- function(name, noteId=NULL, paragraphId=NULL) {
SparkR:::callJMethod(.zeppelinContext, "angular", name, noteId, paragraphId)
}

View file

@ -99,6 +99,11 @@ public class SparkShimsTest {
public String showDataFrame(Object obj, int maxResult) {
return null;
}
@Override
public Object getAsDataFrame(String value) {
return null;
}
};
assertEquals(expected, sparkShims.supportYarn6615(version));
}
@ -121,9 +126,9 @@ public class SparkShimsTest {
when(mockContext.getIntpEventClient()).thenReturn(mockIntpEventClient);
doNothing().when(mockIntpEventClient).onParaInfosReceived(argumentCaptor.capture());
try {
sparkShims = SparkShims.getInstance(SparkVersion.SPARK_2_0_0.toString(), new Properties());
sparkShims = SparkShims.getInstance(SparkVersion.SPARK_2_0_0.toString(), new Properties(), null);
} catch (Throwable ignore) {
sparkShims = SparkShims.getInstance(SparkVersion.SPARK_1_6_0.toString(), new Properties());
sparkShims = SparkShims.getInstance(SparkVersion.SPARK_1_6_0.toString(), new Properties(), null);
}
}

View file

@ -303,7 +303,13 @@ abstract class BaseSparkScalaInterpreter(val conf: SparkConf,
}
protected def createZeppelinContext(): Unit = {
val sparkShims = SparkShims.getInstance(sc.version, properties)
var sparkShims: SparkShims = null
if (isSparkSessionPresent()) {
sparkShims = SparkShims.getInstance(sc.version, properties, sparkSession)
} else {
sparkShims = SparkShims.getInstance(sc.version, properties, sc)
}
var webUiUrl = properties.getProperty("zeppelin.spark.uiWebUrl");
if (StringUtils.isBlank(webUiUrl)) {
webUiUrl = sparkUrl;

View file

@ -20,6 +20,7 @@ package org.apache.zeppelin.spark
import java.util
import org.apache.spark.SparkContext
import org.apache.spark.sql.DataFrame
import org.apache.zeppelin.annotation.ZeppelinApi
import org.apache.zeppelin.display.AngularObjectWatcher
import org.apache.zeppelin.display.ui.OptionInput.ParamOption
@ -146,4 +147,8 @@ class SparkZeppelinContext(val sc: SparkContext,
}
angularWatch(name, noteId, w)
}
def getAsDataFrame(name: String): Object = {
sparkShims.getAsDataFrame(get(name).toString)
}
}

View file

@ -54,7 +54,7 @@ public abstract class SparkShims {
this.properties = properties;
}
private static SparkShims loadShims(String sparkVersion, Properties properties)
private static SparkShims loadShims(String sparkVersion, Properties properties, Object entryPoint)
throws ReflectiveOperationException {
Class<?> sparkShimsClass;
if ("2".equals(sparkVersion)) {
@ -65,15 +65,22 @@ public abstract class SparkShims {
sparkShimsClass = Class.forName("org.apache.zeppelin.spark.Spark1Shims");
}
Constructor c = sparkShimsClass.getConstructor(Properties.class);
return (SparkShims) c.newInstance(properties);
Constructor c = sparkShimsClass.getConstructor(Properties.class, Object.class);
return (SparkShims) c.newInstance(properties, entryPoint);
}
public static SparkShims getInstance(String sparkVersion, Properties properties) {
/**
*
* @param sparkVersion
* @param properties
* @param entryPoint entryPoint is SparkContext for Spark 1.x SparkSession for Spark 2.x
* @return
*/
public static SparkShims getInstance(String sparkVersion, Properties properties, Object entryPoint) {
if (sparkShims == null) {
String sparkMajorVersion = getSparkMajorVersion(sparkVersion);
try {
sparkShims = loadShims(sparkMajorVersion, properties);
sparkShims = loadShims(sparkMajorVersion, properties, entryPoint);
} catch (ReflectiveOperationException e) {
throw new RuntimeException(e);
}
@ -95,6 +102,7 @@ public abstract class SparkShims {
public abstract String showDataFrame(Object obj, int maxResult);
public abstract Object getAsDataFrame(String value);
protected void buildSparkJobUrl(String master,
String sparkWebUrl,

View file

@ -23,17 +23,24 @@ import org.apache.spark.SparkContext;
import org.apache.spark.scheduler.SparkListenerJobStart;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.ui.jobs.JobProgressListener;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.ResultMessages;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
public class Spark1Shims extends SparkShims {
public Spark1Shims(Properties properties) {
private SparkContext sc;
public Spark1Shims(Properties properties, Object entryPoint) {
super(properties);
this.sc = (SparkContext) entryPoint;
}
public void setupSparkListener(final String master,
@ -91,4 +98,24 @@ public class Spark1Shims extends SparkShims {
return obj.toString();
}
}
@Override
public DataFrame getAsDataFrame(String value) {
String[] lines = value.split("\\n");
String head = lines[0];
String[] columns = head.split("\t");
StructType schema = new StructType();
for (String column : columns) {
schema = schema.add(column, "String");
}
List<Row> rows = new ArrayList<>();
for (int i = 1; i < lines.length; ++i) {
String[] tokens = lines[i].split("\t");
Row row = new GenericRow(tokens);
rows.add(row);
}
return SQLContext.getOrCreate(sc)
.createDataFrame(rows, schema);
}
}

View file

@ -24,16 +24,23 @@ import org.apache.spark.scheduler.SparkListener;
import org.apache.spark.scheduler.SparkListenerJobStart;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
import org.apache.spark.sql.types.StructType;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.ResultMessages;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
public class Spark2Shims extends SparkShims {
public Spark2Shims(Properties properties) {
private SparkSession sparkSession;
public Spark2Shims(Properties properties, Object entryPoint) {
super(properties);
this.sparkSession = (SparkSession) entryPoint;
}
public void setupSparkListener(final String master,
@ -93,4 +100,22 @@ public class Spark2Shims extends SparkShims {
}
}
@Override
public Dataset<Row> getAsDataFrame(String value) {
String[] lines = value.split("\\n");
String head = lines[0];
String[] columns = head.split("\t");
StructType schema = new StructType();
for (String column : columns) {
schema = schema.add(column, "String");
}
List<Row> rows = new ArrayList<>();
for (int i = 1; i < lines.length; ++i) {
String[] tokens = lines[i].split("\t");
Row row = new GenericRow(tokens);
rows.add(row);
}
return sparkSession.createDataFrame(rows, schema);
}
}

View file

@ -64,6 +64,7 @@ public class JdbcIntegrationTest {
interpreterSetting.setProperty("default.driver", "com.mysql.jdbc.Driver");
interpreterSetting.setProperty("default.url", "jdbc:mysql://localhost:3306/");
interpreterSetting.setProperty("default.user", "root");
Dependency dependency = new Dependency("mysql:mysql-connector-java:5.1.46");
interpreterSetting.setDependencies(Lists.newArrayList(dependency));
interpreterSettingManager.restart(interpreterSetting.getId());
@ -78,5 +79,27 @@ public class JdbcIntegrationTest {
.build();
InterpreterResult interpreterResult = jdbcInterpreter.interpret("show databases;", context);
assertEquals(interpreterResult.toString(), InterpreterResult.Code.SUCCESS, interpreterResult.code());
context.getLocalProperties().put("saveAs", "table_1");
interpreterResult = jdbcInterpreter.interpret("SELECT 1 as c1, 2 as c2;", context);
assertEquals(interpreterResult.toString(), InterpreterResult.Code.SUCCESS, interpreterResult.code());
assertEquals(1, interpreterResult.message().size());
assertEquals(InterpreterResult.Type.TABLE, interpreterResult.message().get(0).getType());
assertEquals("c1\tc2\n1\t2\n", interpreterResult.message().get(0).getData());
// read table_1 from python interpreter
Interpreter pythonInterpreter = interpreterFactory.getInterpreter("user1", "note1", "python", "test");
assertNotNull("PythonInterpreter is null", pythonInterpreter);
context = new InterpreterContext.Builder()
.setNoteId("note1")
.setParagraphId("paragraph_1")
.setAuthenticationInfo(AuthenticationInfo.ANONYMOUS)
.build();
interpreterResult = pythonInterpreter.interpret("df=z.getAsDataFrame('table_1')\nz.show(df)", context);
assertEquals(interpreterResult.toString(), InterpreterResult.Code.SUCCESS, interpreterResult.code());
assertEquals(1, interpreterResult.message().size());
assertEquals(InterpreterResult.Type.TABLE, interpreterResult.message().get(0).getType());
assertEquals("c1\tc2\n1\t2\n", interpreterResult.message().get(0).getData());
}
}

View file

@ -279,8 +279,15 @@ public abstract class ZeppelinSparkClusterTest extends AbstractTestRestApi {
note = TestUtils.getInstance(Notebook.class).createNote("note1", anonymous);
// test basic dataframe api
Paragraph p = note.addNewParagraph(anonymous);
p.setText("%spark val df=sqlContext.createDataFrame(Seq((\"hello\",20)))\n" +
"df.collect()");
if (isSpark2()) {
p.setText("%spark val df=spark.createDataFrame(Seq((\"hello\",20)))" +
".toDF(\"name\", \"age\")\n" +
"df.collect()");
} else {
p.setText("%spark val df=sqlContext.createDataFrame(Seq((\"hello\",20)))" +
".toDF(\"name\", \"age\")\n" +
"df.collect()");
}
note.run(p.getId(), true);
assertEquals(Status.FINISHED, p.getStatus());
assertTrue(p.getReturn().message().get(0).getData().contains(
@ -288,12 +295,62 @@ public abstract class ZeppelinSparkClusterTest extends AbstractTestRestApi {
// test display DataFrame
p = note.addNewParagraph(anonymous);
p.setText("%spark val df=sqlContext.createDataFrame(Seq((\"hello\",20)))\n" +
"z.show(df)");
if (isSpark2()) {
p.setText("%spark val df=spark.createDataFrame(Seq((\"hello\",20)))" +
".toDF(\"name\", \"age\")\n" +
"df.createOrReplaceTempView(\"test_table\")\n" +
"z.show(df)");
} else {
p.setText("%spark val df=sqlContext.createDataFrame(Seq((\"hello\",20)))" +
".toDF(\"name\", \"age\")\n" +
"df.registerTempTable(\"test_table\")\n" +
"z.show(df)");
}
note.run(p.getId(), true);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals(InterpreterResult.Type.TABLE, p.getReturn().message().get(0).getType());
assertEquals("_1\t_2\nhello\t20\n", p.getReturn().message().get(0).getData());
assertEquals("name\tage\nhello\t20\n", p.getReturn().message().get(0).getData());
// run sql and save it into resource pool
p = note.addNewParagraph(anonymous);
p.setText("%spark.sql(saveAs=table_result) select * from test_table");
note.run(p.getId(), true);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals(InterpreterResult.Type.TABLE, p.getReturn().message().get(0).getType());
assertEquals("name\tage\nhello\t20\n", p.getReturn().message().get(0).getData());
// get resource from spark
p = note.addNewParagraph(anonymous);
p.setText("%spark val df=z.getAsDataFrame(\"table_result\")\nz.show(df)");
note.run(p.getId(), true);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals(InterpreterResult.Type.TABLE, p.getReturn().message().get(0).getType());
assertEquals("name\tage\nhello\t20\n", p.getReturn().message().get(0).getData());
// get resource from pyspark
p = note.addNewParagraph(anonymous);
p.setText("%spark.pyspark df=z.getAsDataFrame('table_result')\nz.show(df)");
note.run(p.getId(), true);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals(InterpreterResult.Type.TABLE, p.getReturn().message().get(0).getType());
assertEquals("name\tage\nhello\t20\n", p.getReturn().message().get(0).getData());
// get resource from ipyspark
p = note.addNewParagraph(anonymous);
p.setText("%spark.ipyspark df=z.getAsDataFrame('table_result')\nz.show(df)");
note.run(p.getId(), true);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals(InterpreterResult.Type.TABLE, p.getReturn().message().get(0).getType());
assertEquals("name\tage\nhello\t20\n", p.getReturn().message().get(0).getData());
// get resource from sparkr
p = note.addNewParagraph(anonymous);
p.setText("%spark.r df=z.getAsDataFrame('table_result')\ndf");
note.run(p.getId(), true);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals(InterpreterResult.Type.TEXT, p.getReturn().message().get(0).getType());
assertTrue(p.getReturn().toString(),
p.getReturn().message().get(0).getData().contains("name age\n1 hello 20"));
// test display DataSet
if (isSpark2()) {
@ -592,6 +649,13 @@ public abstract class ZeppelinSparkClusterTest extends AbstractTestRestApi {
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals(sparkVersion, p.getReturn().message().get(0).getData());
p.setText("%spark.pyspark sc.version");
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertTrue(p.getReturn().toString(),
p.getReturn().message().get(0).getData().contains(sparkVersion));
} finally {
if (null != note) {
TestUtils.getInstance(Notebook.class).removeNote(note.getId(), anonymous);

View file

@ -852,6 +852,25 @@ public abstract class BaseZeppelinContext {
}
}
/**
* Get object from resource pool
* Search local process first and then the other processes
*
* @param name
* @param clazz The class of the returned value
* @return null if resource not found
*/
@ZeppelinApi
public <T> T get(String name, Class<T> clazz) {
ResourcePool resourcePool = interpreterContext.getResourcePool();
Resource resource = resourcePool.get(name);
if (resource != null) {
return resource.get(clazz);
} else {
return null;
}
}
/**
* Remove object from resourcePool
*

View file

@ -69,7 +69,6 @@ import org.apache.zeppelin.resource.DistributedResourcePool;
import org.apache.zeppelin.resource.Resource;
import org.apache.zeppelin.resource.ResourcePool;
import org.apache.zeppelin.resource.ResourceSet;
import org.apache.zeppelin.resource.WellKnownResourceName;
import org.apache.zeppelin.scheduler.Job;
import org.apache.zeppelin.scheduler.Job.Status;
import org.apache.zeppelin.scheduler.JobListener;
@ -88,6 +87,7 @@ import java.lang.reflect.Method;
import java.net.URL;
import java.nio.ByteBuffer;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
@ -679,24 +679,35 @@ public class RemoteInterpreterServer extends Thread
// data from context.out is prepended to InterpreterResult if both defined
context.out.flush();
List<InterpreterResultMessage> resultMessages = context.out.toInterpreterResultMessage();
resultMessages.addAll(result.message());
for (InterpreterResultMessage resultMessage : result.message()) {
// only add non-empty InterpreterResultMessage
if (!StringUtils.isBlank(resultMessage.getData())) {
resultMessages.add(resultMessage);
}
}
List<String> stringResult = new ArrayList<>();
for (InterpreterResultMessage msg : resultMessages) {
if (msg.getType() == InterpreterResult.Type.IMG) {
logger.debug("InterpreterResultMessage: IMAGE_DATA");
} else {
logger.debug("InterpreterResultMessage: " + msg.toString());
}
stringResult.add(msg.getData());
}
// put result into resource pool
if (resultMessages.size() > 0) {
int lastMessageIndex = resultMessages.size() - 1;
if (resultMessages.get(lastMessageIndex).getType() == InterpreterResult.Type.TABLE) {
if (context.getLocalProperties().containsKey("saveAs")) {
if (stringResult.size() == 1) {
logger.info("Saving result into ResourcePool as single string: " +
context.getLocalProperties().get("saveAs"));
context.getResourcePool().put(
context.getNoteId(),
context.getParagraphId(),
WellKnownResourceName.ZeppelinTableResult.toString(),
resultMessages.get(lastMessageIndex));
context.getLocalProperties().get("saveAs"), stringResult.get(0));
} else {
logger.info("Saving result into ResourcePool as string list: " +
context.getLocalProperties().get("saveAs"));
context.getResourcePool().put(
context.getLocalProperties().get("saveAs"), stringResult);
}
}
return new InterpreterResult(result.code(), resultMessages);

View file

@ -56,6 +56,8 @@ public class DistributedResourcePool extends LocalResourcePool {
if (resources.isEmpty()) {
return null;
} else {
// TODO(zjffdu) just assume there's no dupicated resources with the same name, but
// this assumption is false
return resources.get(0);
}
} else {

View file

@ -17,10 +17,10 @@
package org.apache.zeppelin.resource;
import com.google.gson.Gson;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.ArrayList;
import com.google.gson.internal.Primitives;
import org.apache.zeppelin.common.JsonSerializable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -94,6 +94,10 @@ public class Resource implements JsonSerializable, Serializable {
}
}
public <T> T get(Class<T> clazz) {
return Primitives.wrap(clazz).cast(r);
}
public boolean isSerializable() {
return serializable;
}

View file

@ -26,6 +26,7 @@ import org.apache.zeppelin.common.JsonSerializable;
public class ResourceId implements JsonSerializable, Serializable {
private static final Gson gson = new Gson();
// resourcePoolId is the interpreterGroupId which is unique across one Zeppelin instance
private final String resourcePoolId;
private final String name;
private final String noteId;