mirror of
https://github.com/apache/zeppelin
synced 2026-05-24 09:38:26 +00:00
[ZEPPELIN-5479] %python.sql doesn't work with ipython interpreter
This commit is contained in:
parent
f9fb395942
commit
c9e63fa4e0
3 changed files with 116 additions and 71 deletions
|
|
@ -563,13 +563,23 @@ public class PythonInterpreter extends Interpreter {
|
|||
String bootstrapCode =
|
||||
IOUtils.toString(getClass().getClassLoader().getResourceAsStream(resourceName));
|
||||
try {
|
||||
// Add hook explicitly, otherwise python will fail to execute the statement
|
||||
InterpreterResult result = interpret(bootstrapCode + "\n" + "__zeppelin__._displayhook()",
|
||||
InterpreterContext.get());
|
||||
if (result.code() != Code.SUCCESS) {
|
||||
throw new IOException("Fail to run bootstrap script: " + resourceName + "\n" + result);
|
||||
if (iPythonInterpreter != null) {
|
||||
InterpreterResult result = iPythonInterpreter.interpret(bootstrapCode,
|
||||
InterpreterContext.get());
|
||||
if (result.code() != Code.SUCCESS) {
|
||||
throw new IOException("Fail to run bootstrap script: " + resourceName + "\n" + result);
|
||||
} else {
|
||||
LOGGER.debug("Bootstrap python successfully.");
|
||||
}
|
||||
} else {
|
||||
LOGGER.debug("Bootstrap python successfully.");
|
||||
// Add hook explicitly, otherwise python will fail to execute the statement
|
||||
InterpreterResult result = interpret(bootstrapCode + "\n" + "__zeppelin__._displayhook()",
|
||||
InterpreterContext.get());
|
||||
if (result.code() != Code.SUCCESS) {
|
||||
throw new IOException("Fail to run bootstrap script: " + resourceName + "\n" + result);
|
||||
} else {
|
||||
LOGGER.debug("Bootstrap python successfully.");
|
||||
}
|
||||
}
|
||||
} catch (InterpreterException e) {
|
||||
throw new IOException(e);
|
||||
|
|
|
|||
|
|
@ -30,12 +30,11 @@ import java.util.Properties;
|
|||
/**
|
||||
* SQL over Pandas DataFrame interpreter for %python group
|
||||
* <p>
|
||||
* Match experience of %sparpk.sql over Spark DataFrame
|
||||
* Match experience of %spark.sql over Spark DataFrame
|
||||
*/
|
||||
public class PythonInterpreterPandasSql extends Interpreter {
|
||||
private static final Logger LOG = LoggerFactory.getLogger(PythonInterpreterPandasSql.class);
|
||||
|
||||
private String SQL_BOOTSTRAP_FILE_PY = "python/bootstrap_sql.py";
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(PythonInterpreterPandasSql.class);
|
||||
private static String SQL_BOOTSTRAP_FILE_PY = "python/bootstrap_sql.py";
|
||||
|
||||
private PythonInterpreter pythonInterpreter;
|
||||
|
||||
|
|
@ -45,20 +44,19 @@ public class PythonInterpreterPandasSql extends Interpreter {
|
|||
|
||||
@Override
|
||||
public void open() throws InterpreterException {
|
||||
LOG.info("Open Python SQL interpreter instance: {}", this.toString());
|
||||
|
||||
LOGGER.info("Open Python SQL interpreter instance: PythonInterpreterPandasSql");
|
||||
try {
|
||||
LOG.info("Bootstrap {} interpreter with {}", this.toString(), SQL_BOOTSTRAP_FILE_PY);
|
||||
LOGGER.info("Bootstrap PythonInterpreterPandasSql interpreter with {}", SQL_BOOTSTRAP_FILE_PY);
|
||||
this.pythonInterpreter = getInterpreterInTheSameSessionByClassName(PythonInterpreter.class);
|
||||
this.pythonInterpreter.bootstrapInterpreter(SQL_BOOTSTRAP_FILE_PY);
|
||||
} catch (IOException e) {
|
||||
LOG.error("Can't execute " + SQL_BOOTSTRAP_FILE_PY + " to import SQL dependencies", e);
|
||||
LOGGER.error("Can't execute " + SQL_BOOTSTRAP_FILE_PY + " to import SQL dependencies", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws InterpreterException {
|
||||
LOG.info("Close Python SQL interpreter instance: {}", this.toString());
|
||||
LOGGER.info("Close Python SQL interpreter instance: {}", this.toString());
|
||||
if (pythonInterpreter != null) {
|
||||
pythonInterpreter.close();
|
||||
}
|
||||
|
|
@ -67,14 +65,14 @@ public class PythonInterpreterPandasSql extends Interpreter {
|
|||
@Override
|
||||
public InterpreterResult interpret(String st, InterpreterContext context)
|
||||
throws InterpreterException {
|
||||
LOG.info("Running SQL query: '{}' over Pandas DataFrame", st);
|
||||
LOGGER.info("Running SQL query: '{}' over Pandas DataFrame", st);
|
||||
return pythonInterpreter.interpret(
|
||||
"__zeppelin__.show(pysqldf('" + st + "'))\n__zeppelin__._displayhook()", context);
|
||||
"z.show(pysqldf('" + st + "'))", context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void cancel(InterpreterContext context) {
|
||||
|
||||
public void cancel(InterpreterContext context) throws InterpreterException {
|
||||
pythonInterpreter.cancel(context);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -83,8 +81,8 @@ public class PythonInterpreterPandasSql extends Interpreter {
|
|||
}
|
||||
|
||||
@Override
|
||||
public int getProgress(InterpreterContext context) {
|
||||
return 0;
|
||||
public int getProgress(InterpreterContext context) throws InterpreterException {
|
||||
return pythonInterpreter.getProgress(context);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,20 +17,26 @@
|
|||
|
||||
package org.apache.zeppelin.python;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import org.apache.zeppelin.interpreter.Interpreter;
|
||||
import org.apache.zeppelin.interpreter.InterpreterContext;
|
||||
import org.apache.zeppelin.interpreter.InterpreterException;
|
||||
import org.apache.zeppelin.interpreter.InterpreterGroup;
|
||||
import org.apache.zeppelin.interpreter.InterpreterOutput;
|
||||
import org.apache.zeppelin.interpreter.InterpreterOutputListener;
|
||||
import org.apache.zeppelin.interpreter.InterpreterResult;
|
||||
import org.apache.zeppelin.interpreter.InterpreterResult.Type;
|
||||
import org.apache.zeppelin.interpreter.InterpreterResultMessageOutput;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Properties;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
|
@ -51,64 +57,86 @@ import static org.junit.Assert.assertTrue;
|
|||
* mvn -Dpython.test.exclude='' test -pl python -am
|
||||
* </code>
|
||||
*/
|
||||
public class PythonInterpreterPandasSqlTest implements InterpreterOutputListener {
|
||||
@RunWith(value = Parameterized.class)
|
||||
public class PythonInterpreterPandasSqlTest {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(PythonInterpreterPandasSqlTest.class);
|
||||
|
||||
@Parameterized.Parameters
|
||||
public static List<Object[]> data() {
|
||||
return Arrays.asList(new Object[][]{
|
||||
{true},
|
||||
{false}
|
||||
});
|
||||
}
|
||||
|
||||
private boolean useIPython;
|
||||
private InterpreterGroup intpGroup;
|
||||
private PythonInterpreterPandasSql sql;
|
||||
private PythonInterpreter python;
|
||||
private PythonInterpreterPandasSql pandasSqlInterpreter;
|
||||
private PythonInterpreter pythonInterpreter;
|
||||
private IPythonInterpreter ipythonInterpreter;
|
||||
|
||||
private InterpreterContext context;
|
||||
InterpreterOutput out;
|
||||
private InterpreterOutput out;
|
||||
|
||||
public PythonInterpreterPandasSqlTest(boolean useIPython) {
|
||||
this.useIPython = useIPython;
|
||||
LOGGER.info("Test PythonInterpreterPandasSqlTest while useIPython={}", useIPython);
|
||||
}
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
Properties p = new Properties();
|
||||
p.setProperty("zeppelin.python", "python");
|
||||
p.setProperty("zeppelin.python.maxResult", "100");
|
||||
p.setProperty("zeppelin.python.useIPython", "false");
|
||||
p.setProperty("zeppelin.python.useIPython", useIPython + "");
|
||||
p.setProperty("zeppelin.python.gatewayserver_address", "127.0.0.1");
|
||||
|
||||
intpGroup = new InterpreterGroup();
|
||||
|
||||
out = new InterpreterOutput(this);
|
||||
out = new InterpreterOutput();
|
||||
context = InterpreterContext.builder()
|
||||
.setInterpreterOut(out)
|
||||
.build();
|
||||
InterpreterContext.set(context);
|
||||
|
||||
python = new PythonInterpreter(p);
|
||||
python.setInterpreterGroup(intpGroup);
|
||||
python.open();
|
||||
pythonInterpreter = new PythonInterpreter(p);
|
||||
ipythonInterpreter = new IPythonInterpreter(p);
|
||||
pandasSqlInterpreter = new PythonInterpreterPandasSql(p);
|
||||
|
||||
sql = new PythonInterpreterPandasSql(p);
|
||||
sql.setInterpreterGroup(intpGroup);
|
||||
pythonInterpreter.setInterpreterGroup(intpGroup);
|
||||
ipythonInterpreter.setInterpreterGroup(intpGroup);
|
||||
pandasSqlInterpreter.setInterpreterGroup(intpGroup);
|
||||
|
||||
intpGroup.put("note", Arrays.asList(python, sql));
|
||||
List<Interpreter> interpreters =
|
||||
Lists.newArrayList(pythonInterpreter, ipythonInterpreter, pandasSqlInterpreter);
|
||||
|
||||
intpGroup.put("session_1", interpreters);
|
||||
|
||||
pythonInterpreter.open();
|
||||
|
||||
// to make sure python is running.
|
||||
InterpreterResult ret = python.interpret("print(\"python initialized\")\n", context);
|
||||
InterpreterResult ret = pythonInterpreter.interpret("print(\"python initialized\")\n", context);
|
||||
assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code());
|
||||
|
||||
sql.open();
|
||||
pandasSqlInterpreter.open();
|
||||
}
|
||||
|
||||
@After
|
||||
public void afterTest() throws IOException, InterpreterException {
|
||||
sql.close();
|
||||
pandasSqlInterpreter.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void dependenciesAreInstalled() throws InterpreterException {
|
||||
InterpreterResult ret =
|
||||
python.interpret("import pandas\nimport pandasql\nimport numpy\n", context);
|
||||
pythonInterpreter.interpret("import pandas\nimport pandasql\nimport numpy\n", context);
|
||||
assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void errorMessageIfDependenciesNotInstalled() throws InterpreterException {
|
||||
InterpreterResult ret;
|
||||
ret = sql.interpret("SELECT * from something", context);
|
||||
ret = pandasSqlInterpreter.interpret("SELECT * from something", context);
|
||||
|
||||
assertNotNull(ret);
|
||||
assertEquals(ret.message().get(0).getData(), InterpreterResult.Code.ERROR, ret.code());
|
||||
|
|
@ -117,18 +145,16 @@ public class PythonInterpreterPandasSqlTest implements InterpreterOutputListener
|
|||
|
||||
@Test
|
||||
public void sqlOverTestDataPrintsTable() throws IOException, InterpreterException {
|
||||
InterpreterResult ret;
|
||||
// given
|
||||
//String expectedTable = "name\tage\n\nmoon\t33\n\npark\t34";
|
||||
ret = python.interpret("import pandas as pd", context);
|
||||
ret = python.interpret("import numpy as np", context);
|
||||
InterpreterResult ret = pythonInterpreter.interpret("import pandas as pd\nimport numpy as np", context);
|
||||
assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code());
|
||||
|
||||
// DataFrame df2 \w test data
|
||||
ret = python.interpret("df2 = pd.DataFrame({ 'age' : np.array([33, 51, 51, 34]), " +
|
||||
ret = pythonInterpreter.interpret("df2 = pd.DataFrame({ 'age' : np.array([33, 51, 51, 34]), " +
|
||||
"'name' : pd.Categorical(['moon','jobs','gates','park'])})", context);
|
||||
assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code());
|
||||
|
||||
//when
|
||||
ret = sql.interpret("select name, age from df2 where age < 40", context);
|
||||
ret = pandasSqlInterpreter.interpret("select name, age from df2 where age < 40", context);
|
||||
|
||||
//then
|
||||
assertEquals(new String(out.getOutputAt(1).toByteArray()),
|
||||
|
|
@ -139,14 +165,40 @@ public class PythonInterpreterPandasSqlTest implements InterpreterOutputListener
|
|||
assertTrue(new String(out.getOutputAt(1).toByteArray()).indexOf("park\t34") > 0);
|
||||
|
||||
assertEquals(InterpreterResult.Code.SUCCESS,
|
||||
sql.interpret("select case when name==\"aa\" then name else name end from df2",
|
||||
pandasSqlInterpreter.interpret("select case when name==\"aa\" then name else name end from df2",
|
||||
context).code());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInIPython() throws IOException, InterpreterException {
|
||||
InterpreterResult ret =
|
||||
pythonInterpreter.interpret("import pandas as pd\nimport numpy as np", context);
|
||||
assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code());
|
||||
// DataFrame df2 \w test data
|
||||
ret = pythonInterpreter.interpret("df2 = pd.DataFrame({ 'age' : np.array([33, 51, 51, 34]), " +
|
||||
"'name' : pd.Categorical(['moon','jobs','gates','park'])})", context);
|
||||
assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code());
|
||||
|
||||
//when
|
||||
ret = pandasSqlInterpreter.interpret("select name, age from df2 where age < 40", context);
|
||||
|
||||
//then
|
||||
assertEquals(new String(out.getOutputAt(1).toByteArray()),
|
||||
InterpreterResult.Code.SUCCESS, ret.code());
|
||||
assertEquals(new String(out.getOutputAt(1).toByteArray()), Type.TABLE,
|
||||
out.getOutputAt(1).getType());
|
||||
assertTrue(new String(out.getOutputAt(1).toByteArray()).indexOf("moon\t33") > 0);
|
||||
assertTrue(new String(out.getOutputAt(1).toByteArray()).indexOf("park\t34") > 0);
|
||||
|
||||
assertEquals(InterpreterResult.Code.SUCCESS,
|
||||
pandasSqlInterpreter.interpret("select case when name==\"aa\" then name else name end from df2",
|
||||
context).code());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void badSqlSyntaxFails() throws IOException, InterpreterException {
|
||||
//when
|
||||
InterpreterResult ret = sql.interpret("select wrong syntax", context);
|
||||
InterpreterResult ret = pandasSqlInterpreter.interpret("select wrong syntax", context);
|
||||
|
||||
//then
|
||||
assertNotNull("Interpreter returned 'null'", ret);
|
||||
|
|
@ -156,17 +208,17 @@ public class PythonInterpreterPandasSqlTest implements InterpreterOutputListener
|
|||
@Test
|
||||
public void showDataFrame() throws IOException, InterpreterException {
|
||||
InterpreterResult ret;
|
||||
ret = python.interpret("import pandas as pd", context);
|
||||
ret = python.interpret("import numpy as np", context);
|
||||
ret = pythonInterpreter.interpret("import pandas as pd", context);
|
||||
ret = pythonInterpreter.interpret("import numpy as np", context);
|
||||
|
||||
// given a Pandas DataFrame with an index and non-text data
|
||||
ret = python.interpret("index = pd.Index([10, 11, 12, 13], name='index_name')", context);
|
||||
ret = python.interpret("d1 = {1 : [np.nan, 1, 2, 3], 'two' : [3., 4., 5., 6.7]}", context);
|
||||
ret = python.interpret("df1 = pd.DataFrame(d1, index=index)", context);
|
||||
ret = pythonInterpreter.interpret("index = pd.Index([10, 11, 12, 13], name='index_name')", context);
|
||||
ret = pythonInterpreter.interpret("d1 = {1 : [np.nan, 1, 2, 3], 'two' : [3., 4., 5., 6.7]}", context);
|
||||
ret = pythonInterpreter.interpret("df1 = pd.DataFrame(d1, index=index)", context);
|
||||
assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code());
|
||||
|
||||
// when
|
||||
ret = python.interpret("z.show(df1, show_index=True)", context);
|
||||
ret = pythonInterpreter.interpret("z.show(df1, show_index=True)", context);
|
||||
|
||||
// then
|
||||
assertEquals(new String(out.getOutputAt(0).toByteArray()),
|
||||
|
|
@ -177,19 +229,4 @@ public class PythonInterpreterPandasSqlTest implements InterpreterOutputListener
|
|||
assertTrue(new String(out.getOutputAt(1).toByteArray()).contains("nan"));
|
||||
assertTrue(new String(out.getOutputAt(1).toByteArray()).contains("6.7"));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onUpdateAll(InterpreterOutput out) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onAppend(int index, InterpreterResultMessageOutput out, byte[] line) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onUpdate(int index, InterpreterResultMessageOutput out) {
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue