add unit test

This commit is contained in:
Jeff Zhang 2016-09-22 11:25:24 +08:00
parent a7ba67d9a7
commit 62dbcadfa7
3 changed files with 76 additions and 0 deletions

View file

@ -206,6 +206,12 @@ public class ZeppelinContext {
sc.setJobGroup(jobGroup, "Zeppelin", false);
try {
// convert it to DataFrame if it is Dataset, as we will iterate all the records
// and assume it is type Row.
if (df.getClass().getCanonicalName().equals("org.apache.spark.sql.Dataset")) {
Method convertToDFMethod = df.getClass().getMethod("toDF");
df = convertToDFMethod.invoke(df);
}
take = df.getClass().getMethod("take", int.class);
rows = (Object[]) take.invoke(df, maxResult + 1);
} catch (NoSuchMethodException | SecurityException | IllegalAccessException

View file

@ -504,11 +504,13 @@ public class RemoteInterpreterServer
return new InterpreterOutput(new InterpreterOutputListener() {
@Override
public void onAppend(InterpreterOutput out, byte[] line) {
logger.debug("Output Append:" + new String(line));
eventClient.onInterpreterOutputAppend(noteId, paragraphId, new String(line));
}
@Override
public void onUpdate(InterpreterOutput out, byte[] output) {
logger.debug("Output Update:" + new String(output));
eventClient.onInterpreterOutputUpdate(noteId, paragraphId, new String(output));
}
});

View file

@ -17,6 +17,7 @@
package org.apache.zeppelin.rest;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import java.io.File;
import java.io.IOException;
@ -24,6 +25,7 @@ import java.util.List;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.interpreter.InterpreterSetting;
import org.apache.zeppelin.notebook.Note;
import org.apache.zeppelin.notebook.Paragraph;
@ -82,6 +84,57 @@ public class ZeppelinSparkClusterTest extends AbstractTestRestApi {
ZeppelinServer.notebook.removeNote(note.getId(), null);
}
@Test
public void sparkSQLTest() throws IOException {
// create new note
Note note = ZeppelinServer.notebook.createNote(null);
int sparkVersion = getSparkVersionNumber(note);
// DataFrame API is available from spark 1.3
if (sparkVersion >= 13) {
// test basic dataframe api
Paragraph p = note.addParagraph();
Map config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%spark val df=sqlContext.createDataFrame(Seq((\"hello\",20)))\n" +
"df.collect()");
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertTrue(p.getResult().message().contains(
"Array[org.apache.spark.sql.Row] = Array([hello,20])"));
// test display DataFrame
p = note.addParagraph();
config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%spark val df=sqlContext.createDataFrame(Seq((\"hello\",20)))\n" +
"z.show(df)");
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals(InterpreterResult.Type.TABLE, p.getResult().type());
assertEquals("_1\t_2\nhello\t20\n", p.getResult().message());
// test display DataSet
if (sparkVersion >= 20) {
p = note.addParagraph();
config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%spark val ds=spark.createDataset(Seq((\"hello\",20)))\n" +
"z.show(ds)");
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals(InterpreterResult.Type.TABLE, p.getResult().type());
assertEquals("_1\t_2\nhello\t20\n", p.getResult().message());
}
ZeppelinServer.notebook.removeNote(note.getId(), null);
}
}
@Test
public void sparkRTest() throws IOException {
// create new note
@ -152,6 +205,21 @@ public class ZeppelinSparkClusterTest extends AbstractTestRestApi {
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals("[Row(age=20, id=1)]\n", p.getResult().message());
// test display Dataframe
p = note.addParagraph();
config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%pyspark from pyspark.sql import Row\n" +
"df=sqlContext.createDataFrame([Row(id=1, age=20)])\n" +
"z.show(df)");
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals(InterpreterResult.Type.TABLE, p.getResult().type());
// TODO (zjffdu), one more \n is appended, need to investigate why.
assertEquals("age\tid\n20\t1\n\n", p.getResult().message());
}
if (sparkVersion >= 20) {
// run SparkSession test