[ZEPPELIN-5479] %python.sql doesn't work with ipython interpreter

This commit is contained in:
Jeff Zhang 2021-08-07 13:13:35 +08:00
parent f9fb395942
commit c9e63fa4e0
3 changed files with 116 additions and 71 deletions

View file

@ -563,13 +563,23 @@ public class PythonInterpreter extends Interpreter {
String bootstrapCode =
IOUtils.toString(getClass().getClassLoader().getResourceAsStream(resourceName));
try {
// Add hook explicitly, otherwise python will fail to execute the statement
InterpreterResult result = interpret(bootstrapCode + "\n" + "__zeppelin__._displayhook()",
InterpreterContext.get());
if (result.code() != Code.SUCCESS) {
throw new IOException("Fail to run bootstrap script: " + resourceName + "\n" + result);
if (iPythonInterpreter != null) {
InterpreterResult result = iPythonInterpreter.interpret(bootstrapCode,
InterpreterContext.get());
if (result.code() != Code.SUCCESS) {
throw new IOException("Fail to run bootstrap script: " + resourceName + "\n" + result);
} else {
LOGGER.debug("Bootstrap python successfully.");
}
} else {
LOGGER.debug("Bootstrap python successfully.");
// Add hook explicitly, otherwise python will fail to execute the statement
InterpreterResult result = interpret(bootstrapCode + "\n" + "__zeppelin__._displayhook()",
InterpreterContext.get());
if (result.code() != Code.SUCCESS) {
throw new IOException("Fail to run bootstrap script: " + resourceName + "\n" + result);
} else {
LOGGER.debug("Bootstrap python successfully.");
}
}
} catch (InterpreterException e) {
throw new IOException(e);

View file

@ -30,12 +30,11 @@ import java.util.Properties;
/**
* SQL over Pandas DataFrame interpreter for %python group
* <p>
* Match experience of %sparpk.sql over Spark DataFrame
* Match experience of %spark.sql over Spark DataFrame
*/
public class PythonInterpreterPandasSql extends Interpreter {
private static final Logger LOG = LoggerFactory.getLogger(PythonInterpreterPandasSql.class);
private String SQL_BOOTSTRAP_FILE_PY = "python/bootstrap_sql.py";
private static final Logger LOGGER = LoggerFactory.getLogger(PythonInterpreterPandasSql.class);
private static String SQL_BOOTSTRAP_FILE_PY = "python/bootstrap_sql.py";
private PythonInterpreter pythonInterpreter;
@ -45,20 +44,19 @@ public class PythonInterpreterPandasSql extends Interpreter {
@Override
public void open() throws InterpreterException {
LOG.info("Open Python SQL interpreter instance: {}", this.toString());
LOGGER.info("Open Python SQL interpreter instance: PythonInterpreterPandasSql");
try {
LOG.info("Bootstrap {} interpreter with {}", this.toString(), SQL_BOOTSTRAP_FILE_PY);
LOGGER.info("Bootstrap PythonInterpreterPandasSql interpreter with {}", SQL_BOOTSTRAP_FILE_PY);
this.pythonInterpreter = getInterpreterInTheSameSessionByClassName(PythonInterpreter.class);
this.pythonInterpreter.bootstrapInterpreter(SQL_BOOTSTRAP_FILE_PY);
} catch (IOException e) {
LOG.error("Can't execute " + SQL_BOOTSTRAP_FILE_PY + " to import SQL dependencies", e);
LOGGER.error("Can't execute " + SQL_BOOTSTRAP_FILE_PY + " to import SQL dependencies", e);
}
}
@Override
public void close() throws InterpreterException {
LOG.info("Close Python SQL interpreter instance: {}", this.toString());
LOGGER.info("Close Python SQL interpreter instance: {}", this.toString());
if (pythonInterpreter != null) {
pythonInterpreter.close();
}
@ -67,14 +65,14 @@ public class PythonInterpreterPandasSql extends Interpreter {
@Override
public InterpreterResult interpret(String st, InterpreterContext context)
throws InterpreterException {
LOG.info("Running SQL query: '{}' over Pandas DataFrame", st);
LOGGER.info("Running SQL query: '{}' over Pandas DataFrame", st);
return pythonInterpreter.interpret(
"__zeppelin__.show(pysqldf('" + st + "'))\n__zeppelin__._displayhook()", context);
"z.show(pysqldf('" + st + "'))", context);
}
@Override
public void cancel(InterpreterContext context) {
public void cancel(InterpreterContext context) throws InterpreterException {
pythonInterpreter.cancel(context);
}
@Override
@ -83,8 +81,8 @@ public class PythonInterpreterPandasSql extends Interpreter {
}
@Override
public int getProgress(InterpreterContext context) {
return 0;
public int getProgress(InterpreterContext context) throws InterpreterException {
return pythonInterpreter.getProgress(context);
}
}

View file

@ -17,20 +17,26 @@
package org.apache.zeppelin.python;
import com.google.common.collect.Lists;
import org.apache.zeppelin.interpreter.Interpreter;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.InterpreterException;
import org.apache.zeppelin.interpreter.InterpreterGroup;
import org.apache.zeppelin.interpreter.InterpreterOutput;
import org.apache.zeppelin.interpreter.InterpreterOutputListener;
import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.interpreter.InterpreterResult.Type;
import org.apache.zeppelin.interpreter.InterpreterResultMessageOutput;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Properties;
import static org.junit.Assert.assertEquals;
@ -51,64 +57,86 @@ import static org.junit.Assert.assertTrue;
* mvn -Dpython.test.exclude='' test -pl python -am
* </code>
*/
public class PythonInterpreterPandasSqlTest implements InterpreterOutputListener {
@RunWith(value = Parameterized.class)
public class PythonInterpreterPandasSqlTest {
private static final Logger LOGGER = LoggerFactory.getLogger(PythonInterpreterPandasSqlTest.class);
@Parameterized.Parameters
public static List<Object[]> data() {
return Arrays.asList(new Object[][]{
{true},
{false}
});
}
private boolean useIPython;
private InterpreterGroup intpGroup;
private PythonInterpreterPandasSql sql;
private PythonInterpreter python;
private PythonInterpreterPandasSql pandasSqlInterpreter;
private PythonInterpreter pythonInterpreter;
private IPythonInterpreter ipythonInterpreter;
private InterpreterContext context;
InterpreterOutput out;
private InterpreterOutput out;
public PythonInterpreterPandasSqlTest(boolean useIPython) {
this.useIPython = useIPython;
LOGGER.info("Test PythonInterpreterPandasSqlTest while useIPython={}", useIPython);
}
@Before
public void setUp() throws Exception {
Properties p = new Properties();
p.setProperty("zeppelin.python", "python");
p.setProperty("zeppelin.python.maxResult", "100");
p.setProperty("zeppelin.python.useIPython", "false");
p.setProperty("zeppelin.python.useIPython", useIPython + "");
p.setProperty("zeppelin.python.gatewayserver_address", "127.0.0.1");
intpGroup = new InterpreterGroup();
out = new InterpreterOutput(this);
out = new InterpreterOutput();
context = InterpreterContext.builder()
.setInterpreterOut(out)
.build();
InterpreterContext.set(context);
python = new PythonInterpreter(p);
python.setInterpreterGroup(intpGroup);
python.open();
pythonInterpreter = new PythonInterpreter(p);
ipythonInterpreter = new IPythonInterpreter(p);
pandasSqlInterpreter = new PythonInterpreterPandasSql(p);
sql = new PythonInterpreterPandasSql(p);
sql.setInterpreterGroup(intpGroup);
pythonInterpreter.setInterpreterGroup(intpGroup);
ipythonInterpreter.setInterpreterGroup(intpGroup);
pandasSqlInterpreter.setInterpreterGroup(intpGroup);
intpGroup.put("note", Arrays.asList(python, sql));
List<Interpreter> interpreters =
Lists.newArrayList(pythonInterpreter, ipythonInterpreter, pandasSqlInterpreter);
intpGroup.put("session_1", interpreters);
pythonInterpreter.open();
// to make sure python is running.
InterpreterResult ret = python.interpret("print(\"python initialized\")\n", context);
InterpreterResult ret = pythonInterpreter.interpret("print(\"python initialized\")\n", context);
assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code());
sql.open();
pandasSqlInterpreter.open();
}
@After
public void afterTest() throws IOException, InterpreterException {
sql.close();
pandasSqlInterpreter.close();
}
@Test
public void dependenciesAreInstalled() throws InterpreterException {
InterpreterResult ret =
python.interpret("import pandas\nimport pandasql\nimport numpy\n", context);
pythonInterpreter.interpret("import pandas\nimport pandasql\nimport numpy\n", context);
assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code());
}
@Test
public void errorMessageIfDependenciesNotInstalled() throws InterpreterException {
InterpreterResult ret;
ret = sql.interpret("SELECT * from something", context);
ret = pandasSqlInterpreter.interpret("SELECT * from something", context);
assertNotNull(ret);
assertEquals(ret.message().get(0).getData(), InterpreterResult.Code.ERROR, ret.code());
@ -117,18 +145,16 @@ public class PythonInterpreterPandasSqlTest implements InterpreterOutputListener
@Test
public void sqlOverTestDataPrintsTable() throws IOException, InterpreterException {
InterpreterResult ret;
// given
//String expectedTable = "name\tage\n\nmoon\t33\n\npark\t34";
ret = python.interpret("import pandas as pd", context);
ret = python.interpret("import numpy as np", context);
InterpreterResult ret = pythonInterpreter.interpret("import pandas as pd\nimport numpy as np", context);
assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code());
// DataFrame df2 with test data
ret = python.interpret("df2 = pd.DataFrame({ 'age' : np.array([33, 51, 51, 34]), " +
ret = pythonInterpreter.interpret("df2 = pd.DataFrame({ 'age' : np.array([33, 51, 51, 34]), " +
"'name' : pd.Categorical(['moon','jobs','gates','park'])})", context);
assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code());
//when
ret = sql.interpret("select name, age from df2 where age < 40", context);
ret = pandasSqlInterpreter.interpret("select name, age from df2 where age < 40", context);
//then
assertEquals(new String(out.getOutputAt(1).toByteArray()),
@ -139,14 +165,40 @@ public class PythonInterpreterPandasSqlTest implements InterpreterOutputListener
assertTrue(new String(out.getOutputAt(1).toByteArray()).indexOf("park\t34") > 0);
assertEquals(InterpreterResult.Code.SUCCESS,
sql.interpret("select case when name==\"aa\" then name else name end from df2",
pandasSqlInterpreter.interpret("select case when name==\"aa\" then name else name end from df2",
context).code());
}
@Test
public void testInIPython() throws IOException, InterpreterException {
InterpreterResult ret =
pythonInterpreter.interpret("import pandas as pd\nimport numpy as np", context);
assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code());
// DataFrame df2 with test data
ret = pythonInterpreter.interpret("df2 = pd.DataFrame({ 'age' : np.array([33, 51, 51, 34]), " +
"'name' : pd.Categorical(['moon','jobs','gates','park'])})", context);
assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code());
//when
ret = pandasSqlInterpreter.interpret("select name, age from df2 where age < 40", context);
//then
assertEquals(new String(out.getOutputAt(1).toByteArray()),
InterpreterResult.Code.SUCCESS, ret.code());
assertEquals(new String(out.getOutputAt(1).toByteArray()), Type.TABLE,
out.getOutputAt(1).getType());
assertTrue(new String(out.getOutputAt(1).toByteArray()).indexOf("moon\t33") > 0);
assertTrue(new String(out.getOutputAt(1).toByteArray()).indexOf("park\t34") > 0);
assertEquals(InterpreterResult.Code.SUCCESS,
pandasSqlInterpreter.interpret("select case when name==\"aa\" then name else name end from df2",
context).code());
}
@Test
public void badSqlSyntaxFails() throws IOException, InterpreterException {
//when
InterpreterResult ret = sql.interpret("select wrong syntax", context);
InterpreterResult ret = pandasSqlInterpreter.interpret("select wrong syntax", context);
//then
assertNotNull("Interpreter returned 'null'", ret);
@ -156,17 +208,17 @@ public class PythonInterpreterPandasSqlTest implements InterpreterOutputListener
@Test
public void showDataFrame() throws IOException, InterpreterException {
InterpreterResult ret;
ret = python.interpret("import pandas as pd", context);
ret = python.interpret("import numpy as np", context);
ret = pythonInterpreter.interpret("import pandas as pd", context);
ret = pythonInterpreter.interpret("import numpy as np", context);
// given a Pandas DataFrame with an index and non-text data
ret = python.interpret("index = pd.Index([10, 11, 12, 13], name='index_name')", context);
ret = python.interpret("d1 = {1 : [np.nan, 1, 2, 3], 'two' : [3., 4., 5., 6.7]}", context);
ret = python.interpret("df1 = pd.DataFrame(d1, index=index)", context);
ret = pythonInterpreter.interpret("index = pd.Index([10, 11, 12, 13], name='index_name')", context);
ret = pythonInterpreter.interpret("d1 = {1 : [np.nan, 1, 2, 3], 'two' : [3., 4., 5., 6.7]}", context);
ret = pythonInterpreter.interpret("df1 = pd.DataFrame(d1, index=index)", context);
assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code());
// when
ret = python.interpret("z.show(df1, show_index=True)", context);
ret = pythonInterpreter.interpret("z.show(df1, show_index=True)", context);
// then
assertEquals(new String(out.getOutputAt(0).toByteArray()),
@ -177,19 +229,4 @@ public class PythonInterpreterPandasSqlTest implements InterpreterOutputListener
assertTrue(new String(out.getOutputAt(1).toByteArray()).contains("nan"));
assertTrue(new String(out.getOutputAt(1).toByteArray()).contains("6.7"));
}
@Override
public void onUpdateAll(InterpreterOutput out) {
}
@Override
public void onAppend(int index, InterpreterResultMessageOutput out, byte[] line) {
}
@Override
public void onUpdate(int index, InterpreterResultMessageOutput out) {
}
}