ZEPPELIN-3375. Make PySparkInterpreter extends PythonInterpreter

This commit is contained in:
Jeff Zhang 2018-04-01 21:26:28 +08:00
parent 7a67138a9e
commit 738c6c5140
29 changed files with 1282 additions and 1900 deletions

View file

@ -40,8 +40,6 @@
**/PythonInterpreterPandasSqlTest.java,
**/PythonInterpreterMatplotlibTest.java
</python.test.exclude>
<pypi.repo.url>https://pypi.python.org/packages</pypi.repo.url>
<python.py4j.repo.folder>/64/5c/01e13b68e8caafece40d549f232c9b5677ad1016071a48d04cc3895acaa3</python.py4j.repo.folder>
<grpc.version>1.4.0</grpc.version>
<plugin.shade.version>2.4.1</plugin.shade.version>
</properties>
@ -137,35 +135,12 @@
</executions>
</plugin>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>wagon-maven-plugin</artifactId>
<version>1.0</version>
<executions>
<execution>
<phase>package</phase>
<goals><goal>download-single</goal></goals>
<configuration>
<url>${pypi.repo.url}${python.py4j.repo.folder}</url>
<fromFile>py4j-${python.py4j.version}.zip</fromFile>
<toFile>${project.build.directory}/../../interpreter/python/py4j-${python.py4j.version}.zip</toFile>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-antrun-plugin</artifactId>
<version>1.7</version>
<executions>
<execution>
<phase>package</phase>
<configuration>
<target>
<unzip src="${project.build.directory}/../../interpreter/python/py4j-${python.py4j.version}.zip"
dest="${project.build.directory}/../../interpreter/python"/>
</target>
</configuration>
<goals>
<goal>run</goal>
</goals>

View file

@ -243,16 +243,21 @@ public class IPythonInterpreter extends Interpreter implements ExecuteResultHand
private void launchIPythonKernel(int ipythonPort)
throws IOException, URISyntaxException {
// copy the python scripts to a temp directory, then launch ipython kernel in that folder
File tmpPythonScriptFolder = Files.createTempDirectory("zeppelin_ipython").toFile();
File pythonWorkDir = Files.createTempDirectory("zeppelin_ipython").toFile();
String[] ipythonScripts = {"ipython_server.py", "ipython_pb2.py", "ipython_pb2_grpc.py"};
for (String ipythonScript : ipythonScripts) {
URL url = getClass().getClassLoader().getResource("grpc/python"
+ "/" + ipythonScript);
FileUtils.copyURLToFile(url, new File(tmpPythonScriptFolder, ipythonScript));
FileUtils.copyURLToFile(url, new File(pythonWorkDir, ipythonScript));
}
//TODO(zjffdu) don't do hard code on py4j here
File py4jDestFile = new File(pythonWorkDir, "py4j-src-0.9.2.zip");
FileUtils.copyURLToFile(getClass().getClassLoader().getResource(
"python/py4j-src-0.9.2.zip"), py4jDestFile);
CommandLine cmd = CommandLine.parse(pythonExecutable);
cmd.addArgument(tmpPythonScriptFolder.getAbsolutePath() + "/ipython_server.py");
cmd.addArgument(pythonWorkDir.getAbsolutePath() + "/ipython_server.py");
cmd.addArgument(ipythonPort + "");
DefaultExecutor executor = new DefaultExecutor();
ProcessLogOutputStream processOutput = new ProcessLogOutputStream(LOGGER);
@ -261,20 +266,12 @@ public class IPythonInterpreter extends Interpreter implements ExecuteResultHand
executor.setWatchdog(watchDog);
if (useBuiltinPy4j) {
String py4jLibPath = null;
if (System.getenv("ZEPPELIN_HOME") != null) {
py4jLibPath = System.getenv("ZEPPELIN_HOME") + File.separator
+ PythonInterpreter.ZEPPELIN_PY4JPATH;
} else {
Path workingPath = Paths.get("..").toAbsolutePath();
py4jLibPath = workingPath + File.separator + PythonInterpreter.ZEPPELIN_PY4JPATH;
}
if (additionalPythonPath != null) {
// put the py4j at the end, because additionalPythonPath may already contain py4j.
// e.g. PySparkInterpreter
additionalPythonPath = additionalPythonPath + ":" + py4jLibPath;
additionalPythonPath = additionalPythonPath + ":" + py4jDestFile.getAbsolutePath();
} else {
additionalPythonPath = py4jLibPath;
additionalPythonPath = py4jDestFile.getAbsolutePath();
}
}
@ -326,7 +323,7 @@ public class IPythonInterpreter extends Interpreter implements ExecuteResultHand
@Override
public void close() throws InterpreterException {
if (watchDog != null) {
LOGGER.debug("Kill IPython Process");
LOGGER.info("Kill IPython Process");
ipythonClient.stop(StopRequest.newBuilder().build());
watchDog.destroyProcess();
gatewayServer.shutdown();

View file

@ -31,9 +31,10 @@ import java.util.regex.Pattern;
/**
* Conda support
* TODO(zjffdu) Add removing conda env
*/
public class PythonCondaInterpreter extends Interpreter {
Logger logger = LoggerFactory.getLogger(PythonCondaInterpreter.class);
private static Logger logger = LoggerFactory.getLogger(PythonCondaInterpreter.class);
public static final String ZEPPELIN_PYTHON = "zeppelin.python";
public static final String CONDA_PYTHON_PATH = "/bin/python";
public static final String DEFAULT_ZEPPELIN_PYTHON = "python";
@ -145,33 +146,22 @@ public class PythonCondaInterpreter extends Interpreter {
}
}
setCurrentCondaEnvName(envName);
python.setPythonCommand(binPath);
python.setPythonExec(binPath);
}
private void restartPythonProcess() throws InterpreterException {
PythonInterpreter python = getPythonInterpreter();
logger.debug("Restarting PythonInterpreter");
Interpreter python =
getInterpreterInTheSameSessionByClassName(PythonInterpreter.class.getName());
python.close();
python.open();
}
protected PythonInterpreter getPythonInterpreter() throws InterpreterException {
LazyOpenInterpreter lazy = null;
PythonInterpreter python = null;
Interpreter p =
getInterpreterInTheSameSessionByClassName(PythonInterpreter.class.getName());
while (p instanceof WrappedInterpreter) {
if (p instanceof LazyOpenInterpreter) {
lazy = (LazyOpenInterpreter) p;
}
p = ((WrappedInterpreter) p).getInnerInterpreter();
}
python = (PythonInterpreter) p;
if (lazy != null) {
lazy.open();
}
return python;
return (PythonInterpreter) ((LazyOpenInterpreter)p).getInnerInterpreter();
}
public static String runCondaCommandForTextOutput(String title, List<String> commands)
@ -392,27 +382,50 @@ public class PythonCondaInterpreter extends Interpreter {
public static String runCommand(List<String> commands)
throws IOException, InterruptedException {
StringBuilder sb = new StringBuilder();
ProcessBuilder builder = new ProcessBuilder(commands);
builder.redirectErrorStream(true);
Process process = builder.start();
InputStream stdout = process.getInputStream();
BufferedReader br = new BufferedReader(new InputStreamReader(stdout));
String line;
while ((line = br.readLine()) != null) {
sb.append(line);
sb.append("\n");
logger.info("Starting shell commands: " + StringUtils.join(commands, " "));
Process process = Runtime.getRuntime().exec(commands.toArray(new String[0]));
StreamGobbler errorGobbler = new StreamGobbler(process.getErrorStream());
StreamGobbler outputGobbler = new StreamGobbler(process.getInputStream());
errorGobbler.start();
outputGobbler.start();
if (process.waitFor() != 0) {
throw new IOException("Fail to run shell commands: " + StringUtils.join(commands, " "));
}
int r = process.waitFor(); // Let the process finish.
logger.info("Complete shell commands: " + StringUtils.join(commands, " "));
return outputGobbler.getOutput();
}
if (r != 0) {
throw new RuntimeException("Failed to execute `" +
StringUtils.join(commands, " ") + "` exited with " + r);
private static class StreamGobbler extends Thread {
InputStream is;
StringBuilder output = new StringBuilder();
// reads everything from is until empty.
StreamGobbler(InputStream is) {
this.is = is;
}
return sb.toString();
public void run() {
try {
InputStreamReader isr = new InputStreamReader(is);
BufferedReader br = new BufferedReader(isr);
String line = null;
long startTime = System.currentTimeMillis();
while ( (line = br.readLine()) != null) {
output.append(line + "\n");
// logging per 5 seconds
if ((System.currentTimeMillis() - startTime) > 5000) {
logger.info(line);
startTime = System.currentTimeMillis();
}
}
} catch (IOException ioe) {
ioe.printStackTrace();
}
}
public String getOutput() {
return output.toString();
}
}
public static String runCommand(String ... command)

View file

@ -58,7 +58,7 @@ public class PythonDockerInterpreter extends Interpreter {
@Override
public InterpreterResult interpret(String st, InterpreterContext context)
throws InterpreterException {
File pythonScript = new File(getPythonInterpreter().getScriptPath());
File pythonWorkDir = getPythonInterpreter().getPythonWorkDir();
InterpreterOutput out = context.out;
Matcher activateMatcher = activatePattern.matcher(st);
@ -73,26 +73,23 @@ public class PythonDockerInterpreter extends Interpreter {
pull(out, image);
// mount pythonscript dir
String mountPythonScript = "-v " +
pythonScript.getParentFile().getAbsolutePath() +
":/_zeppelin_tmp ";
String mountPythonScript = "-v " + pythonWorkDir.getAbsolutePath() +
":/_python_workdir ";
// mount zeppelin dir
String mountPy4j = "-v " +
zeppelinHome.getAbsolutePath() +
String mountPy4j = "-v " + zeppelinHome.getAbsolutePath() +
":/_zeppelin ";
// set PYTHONPATH
String pythonPath = ":/_zeppelin/" + PythonInterpreter.ZEPPELIN_PY4JPATH + ":" +
":/_zeppelin/" + PythonInterpreter.ZEPPELIN_PYTHON_LIBS;
String pythonPath = ".:/_python_workdir/py4j-src-0.9.2.zip:/_python_workdir";
setPythonCommand("docker run -i --rm " +
mountPythonScript +
mountPy4j +
"-e PYTHONPATH=\"" + pythonPath + "\" " +
image + " " +
getPythonInterpreter().getPythonBindPath() + " " +
"/_zeppelin_tmp/" + pythonScript.getName());
getPythonInterpreter().getPythonExec() + " " +
"/_python_workdir/zeppelin_python.py");
restartPythonProcess();
out.clear();
return new InterpreterResult(InterpreterResult.Code.SUCCESS, "\"" + image + "\" activated");
@ -108,7 +105,7 @@ public class PythonDockerInterpreter extends Interpreter {
public void setPythonCommand(String cmd) throws InterpreterException {
PythonInterpreter python = getPythonInterpreter();
python.setPythonCommand(cmd);
python.setPythonExec(cmd);
}
private void printUsage(InterpreterOutput out) {

View file

@ -1,41 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.zeppelin.python;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.net.*;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.regex.Pattern;
import com.google.common.io.Files;
import com.google.gson.Gson;
import org.apache.commons.exec.CommandLine;
import org.apache.commons.exec.DefaultExecutor;
import org.apache.commons.exec.ExecuteException;
@ -45,239 +28,233 @@ import org.apache.commons.exec.PumpStreamHandler;
import org.apache.commons.exec.environment.EnvironmentUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.zeppelin.display.GUI;
import org.apache.zeppelin.interpreter.*;
import org.apache.zeppelin.interpreter.InterpreterResult.Code;
import org.apache.zeppelin.interpreter.BaseZeppelinContext;
import org.apache.zeppelin.interpreter.Interpreter;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.InterpreterException;
import org.apache.zeppelin.interpreter.InterpreterGroup;
import org.apache.zeppelin.interpreter.InterpreterHookRegistry.HookType;
import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.interpreter.InterpreterResult.Code;
import org.apache.zeppelin.interpreter.InterpreterResultMessage;
import org.apache.zeppelin.interpreter.InvalidHookException;
import org.apache.zeppelin.interpreter.LazyOpenInterpreter;
import org.apache.zeppelin.interpreter.WrappedInterpreter;
import org.apache.zeppelin.interpreter.remote.RemoteInterpreterUtils;
import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
import org.apache.zeppelin.interpreter.util.InterpreterOutputStream;
import org.apache.zeppelin.scheduler.Job;
import org.apache.zeppelin.scheduler.Scheduler;
import org.apache.zeppelin.scheduler.SchedulerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import py4j.GatewayServer;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* Python interpreter for Zeppelin.
* Interpreter for Python, it is the first implementation of interpreter for Python, so with less
* features compared to IPythonInterpreter, but requires less prerequisites than
* IPythonInterpreter, only python installation is required.
*/
public class PythonInterpreter extends Interpreter implements ExecuteResultHandler {
private static final Logger LOG = LoggerFactory.getLogger(PythonInterpreter.class);
public static final String ZEPPELIN_PYTHON = "python/zeppelin_python.py";
public static final String ZEPPELIN_CONTEXT = "python/zeppelin_context.py";
public static final String ZEPPELIN_PY4JPATH = "interpreter/python/py4j-0.9.2/src";
public static final String ZEPPELIN_PYTHON_LIBS = "interpreter/lib/python";
public static final String DEFAULT_ZEPPELIN_PYTHON = "python";
public static final String MAX_RESULT = "zeppelin.python.maxResult";
private PythonZeppelinContext zeppelinContext;
private InterpreterContext context;
private Pattern errorInLastLine = Pattern.compile(".*(Error|Exception): .*$");
private String pythonPath;
private int maxResult;
private String py4jLibPath;
private String pythonLibPath;
private String pythonCommand;
private static final Logger LOGGER = LoggerFactory.getLogger(PythonInterpreter.class);
private static final int MAX_TIMEOUT_SEC = 10;
private GatewayServer gatewayServer;
private DefaultExecutor executor;
private int port;
private File pythonWorkDir;
protected boolean useBuiltinPy4j = true;
// used to forward output from python process to InterpreterOutput
private InterpreterOutputStream outputStream;
private BufferedWriter ins;
private PipedInputStream in;
private ByteArrayOutputStream input;
private String scriptPath;
boolean pythonscriptRunning = false;
private static final int MAX_TIMEOUT_SEC = 10;
private long pythonPid = 0;
private AtomicBoolean pythonScriptRunning = new AtomicBoolean(false);
private AtomicBoolean pythonScriptInitialized = new AtomicBoolean(false);
private long pythonPid = -1;
private IPythonInterpreter iPythonInterpreter;
Integer statementSetNotifier = new Integer(0);
private BaseZeppelinContext zeppelinContext;
private String condaPythonExec; // set by PythonCondaInterpreter
public PythonInterpreter(Properties property) {
super(property);
try {
File scriptFile = File.createTempFile("zeppelin_python-", ".py", new File("/tmp"));
scriptPath = scriptFile.getAbsolutePath();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
private String workingDir() {
URL myURL = getClass().getProtectionDomain().getCodeSource().getLocation();
java.net.URI myURI = null;
try {
myURI = myURL.toURI();
} catch (URISyntaxException e1)
{}
String path = java.nio.file.Paths.get(myURI).toFile().toString();
return path;
}
private void createPythonScript() throws InterpreterException {
File out = new File(scriptPath);
if (out.exists() && out.isDirectory()) {
throw new InterpreterException("Can't create python script " + out.getAbsolutePath());
}
copyFile(out, ZEPPELIN_PYTHON);
// copy zeppelin_context.py as well
File zOut = new File(out.getParent() + "/zeppelin_context.py");
copyFile(zOut, ZEPPELIN_CONTEXT);
logger.info("File {} , {} created", scriptPath, zOut.getAbsolutePath());
}
public String getScriptPath() {
return scriptPath;
}
private void copyFile(File out, String sourceFile) throws InterpreterException {
ClassLoader classLoader = getClass().getClassLoader();
try {
FileOutputStream outStream = new FileOutputStream(out);
IOUtils.copy(
classLoader.getResourceAsStream(sourceFile),
outStream);
outStream.close();
} catch (IOException e) {
throw new InterpreterException(e);
}
}
private void createGatewayServerAndStartScript()
throws UnknownHostException, InterpreterException {
createPythonScript();
if (System.getenv("ZEPPELIN_HOME") != null) {
py4jLibPath = System.getenv("ZEPPELIN_HOME") + File.separator + ZEPPELIN_PY4JPATH;
pythonLibPath = System.getenv("ZEPPELIN_HOME") + File.separator + ZEPPELIN_PYTHON_LIBS;
} else {
Path workingPath = Paths.get("..").toAbsolutePath();
py4jLibPath = workingPath + File.separator + ZEPPELIN_PY4JPATH;
pythonLibPath = workingPath + File.separator + ZEPPELIN_PYTHON_LIBS;
}
port = findRandomOpenPortOnAllLocalInterfaces();
gatewayServer = new GatewayServer(this,
port,
GatewayServer.DEFAULT_PYTHON_PORT,
InetAddress.getByName("0.0.0.0"),
InetAddress.getByName("0.0.0.0"),
GatewayServer.DEFAULT_CONNECT_TIMEOUT,
GatewayServer.DEFAULT_READ_TIMEOUT,
(List) null);
gatewayServer.start();
// Run python shell
String pythonCmd = getPythonCommand();
CommandLine cmd = CommandLine.parse(pythonCmd);
if (!pythonCmd.endsWith(".py")) {
// PythonDockerInterpreter set pythoncmd with script
cmd.addArgument(getScriptPath(), false);
}
cmd.addArgument(Integer.toString(port), false);
cmd.addArgument(getLocalIp(), false);
executor = new DefaultExecutor();
outputStream = new InterpreterOutputStream(LOG);
PipedOutputStream ps = new PipedOutputStream();
in = null;
try {
in = new PipedInputStream(ps);
} catch (IOException e1) {
throw new InterpreterException(e1);
}
ins = new BufferedWriter(new OutputStreamWriter(ps));
input = new ByteArrayOutputStream();
PumpStreamHandler streamHandler = new PumpStreamHandler(outputStream, outputStream, in);
executor.setStreamHandler(streamHandler);
executor.setWatchdog(new ExecuteWatchdog(ExecuteWatchdog.INFINITE_TIMEOUT));
try {
Map env = EnvironmentUtils.getProcEnvironment();
if (!env.containsKey("PYTHONPATH")) {
env.put("PYTHONPATH", py4jLibPath + File.pathSeparator + pythonLibPath);
} else {
env.put("PYTHONPATH", env.get("PYTHONPATH") + File.pathSeparator +
py4jLibPath + File.pathSeparator + pythonLibPath);
}
logger.info("cmd = {}", cmd.toString());
executor.execute(cmd, env, this);
pythonscriptRunning = true;
} catch (IOException e) {
throw new InterpreterException(e);
}
try {
input.write("import sys, getopt\n".getBytes());
ins.flush();
} catch (IOException e) {
throw new InterpreterException(e);
}
}
@Override
public void open() throws InterpreterException {
// try IPythonInterpreter first. If it is not available, we will fallback to the original
// python interpreter implementation.
// try IPythonInterpreter first
iPythonInterpreter = getIPythonInterpreter();
this.zeppelinContext = new PythonZeppelinContext(
getInterpreterGroup().getInterpreterHookRegistry(),
Integer.parseInt(getProperty("zeppelin.python.maxResult", "1000")));
if (getProperty("zeppelin.python.useIPython", "true").equals("true") &&
StringUtils.isEmpty(iPythonInterpreter.checkIPythonPrerequisite(getPythonBindPath()))) {
StringUtils.isEmpty(
iPythonInterpreter.checkIPythonPrerequisite(getPythonExec()))) {
try {
iPythonInterpreter.open();
LOG.info("IPython is available, Use IPythonInterpreter to replace PythonInterpreter");
LOGGER.info("IPython is available, Use IPythonInterpreter to replace PythonInterpreter");
return;
} catch (Exception e) {
iPythonInterpreter = null;
LOG.warn("Fail to open IPythonInterpreter", e);
LOGGER.warn("Fail to open IPythonInterpreter", e);
}
}
// reset iPythonInterpreter to null as it is not available
iPythonInterpreter = null;
LOG.info("IPython is not available, use the native PythonInterpreter");
LOGGER.info("IPython is not available, use the native PythonInterpreter");
// Add matplotlib display hook
InterpreterGroup intpGroup = getInterpreterGroup();
if (intpGroup != null && intpGroup.getInterpreterHookRegistry() != null) {
try {
// just for unit test I believe (zjffdu)
registerHook(HookType.POST_EXEC_DEV.getName(), "__zeppelin__._displayhook()");
} catch (InvalidHookException e) {
throw new InterpreterException(e);
}
}
// Add matplotlib display hook
try {
createGatewayServerAndStartScript();
} catch (UnknownHostException e) {
throw new InterpreterException(e);
} catch (IOException e) {
LOGGER.error("Fail to open PythonInterpreter", e);
throw new InterpreterException("Fail to open PythonInterpreter", e);
}
}
private IPythonInterpreter getIPythonInterpreter() {
LazyOpenInterpreter lazy = null;
IPythonInterpreter ipython = null;
Interpreter p = getInterpreterInTheSameSessionByClassName(IPythonInterpreter.class.getName());
// start gateway sever and start python process
private void createGatewayServerAndStartScript() throws IOException {
// start gateway server in JVM side
int port = RemoteInterpreterUtils.findRandomAvailablePortOnAllLocalInterfaces();
// use the FQDN as the server address instead of 127.0.0.1 so that python process in docker
// container can also connect to this gateway server.
String serverAddress = getLocalIP();
gatewayServer = new GatewayServer(this,
port,
GatewayServer.DEFAULT_PYTHON_PORT,
InetAddress.getByName(serverAddress),
InetAddress.getByName(serverAddress),
GatewayServer.DEFAULT_CONNECT_TIMEOUT,
GatewayServer.DEFAULT_READ_TIMEOUT,
(List) null);;
gatewayServer.start();
LOGGER.info("Starting GatewayServer at " + serverAddress + ":" + port);
while (p instanceof WrappedInterpreter) {
if (p instanceof LazyOpenInterpreter) {
lazy = (LazyOpenInterpreter) p;
}
p = ((WrappedInterpreter) p).getInnerInterpreter();
// launch python process to connect to the gateway server in JVM side
createPythonScript();
String pythonExec = getPythonExec();
CommandLine cmd = CommandLine.parse(pythonExec);
if (!pythonExec.endsWith(".py")) {
// PythonDockerInterpreter set pythonExec with script
cmd.addArgument(pythonWorkDir + "/zeppelin_python.py", false);
}
ipython = (IPythonInterpreter) p;
return ipython;
cmd.addArgument(serverAddress, false);
cmd.addArgument(Integer.toString(port), false);
executor = new DefaultExecutor();
outputStream = new InterpreterOutputStream(LOGGER);
PumpStreamHandler streamHandler = new PumpStreamHandler(outputStream);
executor.setStreamHandler(streamHandler);
executor.setWatchdog(new ExecuteWatchdog(ExecuteWatchdog.INFINITE_TIMEOUT));
Map<String, String> env = setupPythonEnv();
LOGGER.info("Launching Python Process Command: " + cmd.getExecutable() +
" " + StringUtils.join(cmd.getArguments(), " "));
executor.execute(cmd, env, this);
pythonScriptRunning.set(true);
}
private void createPythonScript() throws IOException {
// set java.io.tmpdir to /tmp on MacOS, because docker can not share the /var folder which will
// cause PythonDockerInterpreter fails.
// https://stackoverflow.com/questions/45122459/docker-mounts-denied-the-paths-are-not-shared-
// from-os-x-and-are-not-known
if (System.getProperty("os.name", "").contains("Mac")) {
System.setProperty("java.io.tmpdir", "/tmp");
}
this.pythonWorkDir = Files.createTempDir();
this.pythonWorkDir.deleteOnExit();
LOGGER.info("Create Python working dir: " + pythonWorkDir.getAbsolutePath());
copyResourceToPythonWorkDir("python/zeppelin_python.py", "zeppelin_python.py");
copyResourceToPythonWorkDir("python/zeppelin_context.py", "zeppelin_context.py");
copyResourceToPythonWorkDir("python/backend_zinline.py", "backend_zinline.py");
copyResourceToPythonWorkDir("python/mpl_config.py", "mpl_config.py");
copyResourceToPythonWorkDir("python/py4j-src-0.9.2.zip", "py4j-src-0.9.2.zip");
}
protected boolean useIPython() {
return this.iPythonInterpreter != null;
}
private String getLocalIP() {
// zeppelin.python.gatewayserver_address is only for unit test on travis.
// Because the FQDN would fail unit test on travis ci.
String gatewayserver_address =
properties.getProperty("zeppelin.python.gatewayserver_address");
if (gatewayserver_address != null) {
return gatewayserver_address;
}
try {
return Inet4Address.getLocalHost().getHostAddress();
} catch (UnknownHostException e) {
LOGGER.warn("can't get local IP", e);
}
// fall back to loopback addreess
return "127.0.0.1";
}
private void copyResourceToPythonWorkDir(String srcResourceName,
String dstFileName) throws IOException {
FileOutputStream out = null;
try {
out = new FileOutputStream(pythonWorkDir.getAbsoluteFile() + "/" + dstFileName);
IOUtils.copy(
getClass().getClassLoader().getResourceAsStream(srcResourceName),
out);
} finally {
if (out != null) {
out.close();
}
}
}
protected Map<String, String> setupPythonEnv() throws IOException {
Map<String, String> env = EnvironmentUtils.getProcEnvironment();
appendToPythonPath(env, pythonWorkDir.getAbsolutePath());
if (useBuiltinPy4j) {
appendToPythonPath(env, pythonWorkDir.getAbsolutePath() + "/py4j-src-0.9.2.zip");
}
LOGGER.info("PYTHONPATH: " + env.get("PYTHONPATH"));
return env;
}
private void appendToPythonPath(Map<String, String> env, String path) {
if (!env.containsKey("PYTHONPATH")) {
env.put("PYTHONPATH", path);
} else {
env.put("PYTHONPATH", env.get("PYTHONPATH") + ":" + path);
}
}
// Run python script
// Choose python in the order of
// condaPythonExec > zeppelin.python
protected String getPythonExec() {
if (condaPythonExec != null) {
return condaPythonExec;
} else {
return getProperty("zeppelin.python", "python");
}
}
public File getPythonWorkDir() {
return pythonWorkDir;
}
@Override
@ -286,54 +263,58 @@ public class PythonInterpreter extends Interpreter implements ExecuteResultHandl
iPythonInterpreter.close();
return;
}
pythonscriptRunning = false;
pythonScriptInitialized = false;
try {
ins.flush();
ins.close();
input.flush();
input.close();
} catch (IOException e) {
e.printStackTrace();
}
pythonScriptRunning.set(false);
pythonScriptInitialized.set(false);
executor.getWatchdog().destroyProcess();
new File(scriptPath).delete();
gatewayServer.shutdown();
// wait until getStatements stop
synchronized (statementSetNotifier) {
try {
statementSetNotifier.wait(1500);
} catch (InterruptedException e) {
}
statementSetNotifier.notify();
}
// reset these 2 monitors otherwise when you restart PythonInterpreter it would fails to execute
// python code as these 2 objects are in incorrect state.
statementSetNotifier = new Integer(0);
statementFinishedNotifier = new Integer(0);
}
private PythonInterpretRequest pythonInterpretRequest = null;
private Integer statementSetNotifier = new Integer(0);
private Integer statementFinishedNotifier = new Integer(0);
private String statementOutput = null;
private boolean statementError = false;
public void setPythonExec(String pythonExec) {
LOGGER.info("Set Python Command : {}", pythonExec);
this.condaPythonExec = pythonExec;
}
PythonInterpretRequest pythonInterpretRequest = null;
/**
* Result class of python interpreter
* Request send to Python Daemon
*/
public class PythonInterpretRequest {
public String statements;
public boolean isForCompletion;
public PythonInterpretRequest(String statements) {
public PythonInterpretRequest(String statements, boolean isForCompletion) {
this.statements = statements;
this.isForCompletion = isForCompletion;
}
public String statements() {
return statements;
}
public boolean isForCompletion() {
return isForCompletion;
}
}
// called by Python Process
public PythonInterpretRequest getStatements() {
synchronized (statementSetNotifier) {
while (pythonInterpretRequest == null && pythonscriptRunning && pythonScriptInitialized) {
while (pythonInterpretRequest == null) {
try {
statementSetNotifier.wait(1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
PythonInterpretRequest req = pythonInterpretRequest;
@ -342,65 +323,78 @@ public class PythonInterpreter extends Interpreter implements ExecuteResultHandl
}
}
String statementOutput = null;
boolean statementError = false;
Integer statementFinishedNotifier = new Integer(0);
// called by Python Process
public void setStatementsFinished(String out, boolean error) {
synchronized (statementFinishedNotifier) {
LOGGER.debug("Setting python statement output: " + out + ", error: " + error);
statementOutput = out;
statementError = error;
statementFinishedNotifier.notify();
}
}
boolean pythonScriptInitialized = false;
Integer pythonScriptInitializeNotifier = new Integer(0);
// called by Python Process
public void onPythonScriptInitialized(long pid) {
pythonPid = pid;
synchronized (pythonScriptInitializeNotifier) {
pythonScriptInitialized = true;
pythonScriptInitializeNotifier.notifyAll();
synchronized (pythonScriptInitialized) {
LOGGER.debug("onPythonScriptInitialized is called");
pythonScriptInitialized.set(true);
pythonScriptInitialized.notifyAll();
}
}
// called by Python Process
public void appendOutput(String message) throws IOException {
LOGGER.debug("Output from python process: " + message);
outputStream.getInterpreterOutput().write(message);
}
// used by subclass such as PySparkInterpreter to set JobGroup before executing spark code
protected void preCallPython(InterpreterContext context) {
}
// blocking call. Send python code to python process and get response
protected void callPython(PythonInterpretRequest request) {
synchronized (statementSetNotifier) {
this.pythonInterpretRequest = request;
statementOutput = null;
statementSetNotifier.notify();
}
synchronized (statementFinishedNotifier) {
while (statementOutput == null) {
try {
statementFinishedNotifier.wait(1000);
} catch (InterruptedException e) {
}
}
}
}
@Override
public InterpreterResult interpret(String cmd, InterpreterContext contextInterpreter)
public InterpreterResult interpret(String st, InterpreterContext context)
throws InterpreterException {
if (iPythonInterpreter != null) {
return iPythonInterpreter.interpret(cmd, contextInterpreter);
return iPythonInterpreter.interpret(st, context);
}
if (cmd == null || cmd.isEmpty()) {
return new InterpreterResult(Code.SUCCESS, "");
}
this.context = contextInterpreter;
zeppelinContext.setGui(context.getGui());
zeppelinContext.setNoteGui(context.getNoteGui());
zeppelinContext.setInterpreterContext(context);
if (!pythonscriptRunning) {
return new InterpreterResult(Code.ERROR, "python process not running"
+ outputStream.toString());
if (!pythonScriptRunning.get()) {
return new InterpreterResult(Code.ERROR, "python process not running "
+ outputStream.toString());
}
outputStream.setInterpreterOutput(context.out);
synchronized (pythonScriptInitializeNotifier) {
synchronized (pythonScriptInitialized) {
long startTime = System.currentTimeMillis();
while (pythonScriptInitialized == false
&& pythonscriptRunning
&& System.currentTimeMillis() - startTime < MAX_TIMEOUT_SEC * 1000) {
while (!pythonScriptInitialized.get()
&& System.currentTimeMillis() - startTime < MAX_TIMEOUT_SEC * 1000) {
try {
pythonScriptInitializeNotifier.wait(1000);
LOGGER.info("Wait for PythonScript initialized");
pythonScriptInitialized.wait(100);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
@ -413,59 +407,40 @@ public class PythonInterpreter extends Interpreter implements ExecuteResultHandl
throw new InterpreterException(e);
}
if (pythonscriptRunning == false) {
// python script failed to initialize and terminated
errorMessage.add(new InterpreterResultMessage(
InterpreterResult.Type.TEXT, "failed to start python"));
return new InterpreterResult(Code.ERROR, errorMessage);
}
if (pythonScriptInitialized == false) {
if (!pythonScriptInitialized.get()) {
// timeout. didn't get initialized message
errorMessage.add(new InterpreterResultMessage(
InterpreterResult.Type.TEXT, "python is not responding"));
InterpreterResult.Type.TEXT, "Failed to initialize Python"));
return new InterpreterResult(Code.ERROR, errorMessage);
}
pythonInterpretRequest = new PythonInterpretRequest(cmd);
statementOutput = null;
BaseZeppelinContext z = getZeppelinContext();
z.setInterpreterContext(context);
z.setGui(context.getGui());
z.setNoteGui(context.getNoteGui());
InterpreterContext.set(context);
synchronized (statementSetNotifier) {
statementSetNotifier.notify();
}
synchronized (statementFinishedNotifier) {
while (statementOutput == null) {
try {
statementFinishedNotifier.wait(1000);
} catch (InterruptedException e) {
}
}
}
preCallPython(context);
callPython(new PythonInterpretRequest(st, false));
if (statementError) {
return new InterpreterResult(Code.ERROR, statementOutput);
} else {
try {
context.out.flush();
} catch (IOException e) {
throw new InterpreterException(e);
}
return new InterpreterResult(Code.SUCCESS);
}
}
public InterpreterContext getCurrentInterpreterContext() {
return context;
}
public void interrupt() throws IOException, InterpreterException {
if (pythonPid > -1) {
logger.info("Sending SIGINT signal to PID : " + pythonPid);
LOGGER.info("Sending SIGINT signal to PID : " + pythonPid);
Runtime.getRuntime().exec("kill -SIGINT " + pythonPid);
} else {
logger.warn("Non UNIX/Linux system, close the interpreter");
LOGGER.warn("Non UNIX/Linux system, close the interpreter");
close();
}
}
@ -474,11 +449,12 @@ public class PythonInterpreter extends Interpreter implements ExecuteResultHandl
public void cancel(InterpreterContext context) throws InterpreterException {
if (iPythonInterpreter != null) {
iPythonInterpreter.cancel(context);
return;
}
try {
interrupt();
} catch (IOException e) {
e.printStackTrace();
LOGGER.error("Error", e);
}
}
@ -495,114 +471,162 @@ public class PythonInterpreter extends Interpreter implements ExecuteResultHandl
return 0;
}
@Override
public Scheduler getScheduler() {
if (iPythonInterpreter != null) {
return iPythonInterpreter.getScheduler();
}
return SchedulerFactory.singleton().createOrGetFIFOScheduler(
PythonInterpreter.class.getName() + this.hashCode());
}
@Override
public List<InterpreterCompletion> completion(String buf, int cursor,
InterpreterContext interpreterContext) {
InterpreterContext interpreterContext)
throws InterpreterException {
if (iPythonInterpreter != null) {
return iPythonInterpreter.completion(buf, cursor, interpreterContext);
}
return null;
}
public void setPythonCommand(String cmd) {
logger.info("Set Python Command : {}", cmd);
pythonCommand = cmd;
}
private String getPythonCommand() {
if (pythonCommand == null) {
return getPythonBindPath();
} else {
return pythonCommand;
if (buf.length() < cursor) {
cursor = buf.length();
}
}
String completionString = getCompletionTargetString(buf, cursor);
String completionCommand = "__zeppelin_completion__.getCompletion('" + completionString + "')";
LOGGER.debug("completionCommand: " + completionCommand);
public String getPythonBindPath() {
String path = getProperty("zeppelin.python");
if (path == null) {
return DEFAULT_ZEPPELIN_PYTHON;
} else {
return path;
pythonInterpretRequest = new PythonInterpretRequest(completionCommand, true);
statementOutput = null;
synchronized (statementSetNotifier) {
statementSetNotifier.notify();
}
}
private Job getRunningJob(String paragraphId) {
Job foundJob = null;
Collection<Job> jobsRunning = getScheduler().getJobsRunning();
for (Job job : jobsRunning) {
if (job.getId().equals(paragraphId)) {
foundJob = job;
break;
String[] completionList = null;
synchronized (statementFinishedNotifier) {
long startTime = System.currentTimeMillis();
while (statementOutput == null
&& pythonScriptRunning.get()) {
try {
if (System.currentTimeMillis() - startTime > MAX_TIMEOUT_SEC * 1000) {
LOGGER.error("Python completion didn't have response for {}sec.", MAX_TIMEOUT_SEC);
break;
}
statementFinishedNotifier.wait(1000);
} catch (InterruptedException e) {
// not working
LOGGER.info("wait drop");
return new LinkedList<>();
}
}
if (statementError) {
return new LinkedList<>();
}
Gson gson = new Gson();
completionList = gson.fromJson(statementOutput, String[].class);
}
return foundJob;
//end code for completion
if (completionList == null) {
return new LinkedList<>();
}
List<InterpreterCompletion> results = new LinkedList<>();
for (String name: completionList) {
results.add(new InterpreterCompletion(name, name, StringUtils.EMPTY));
}
return results;
}
void bootStrapInterpreter(String file) throws IOException {
BufferedReader bootstrapReader = new BufferedReader(
new InputStreamReader(
PythonInterpreter.class.getResourceAsStream(file)));
String line = null;
String bootstrapCode = "";
private String getCompletionTargetString(String text, int cursor) {
String[] completionSeqCharaters = {" ", "\n", "\t"};
int completionEndPosition = cursor;
int completionStartPosition = cursor;
int indexOfReverseSeqPostion = cursor;
String resultCompletionText = "";
String completionScriptText = "";
try {
completionScriptText = text.substring(0, cursor);
}
catch (Exception e) {
LOGGER.error(e.toString());
return null;
}
completionEndPosition = completionScriptText.length();
String tempReverseCompletionText = new StringBuilder(completionScriptText).reverse().toString();
for (String seqCharacter : completionSeqCharaters) {
indexOfReverseSeqPostion = tempReverseCompletionText.indexOf(seqCharacter);
if (indexOfReverseSeqPostion < completionStartPosition && indexOfReverseSeqPostion > 0) {
completionStartPosition = indexOfReverseSeqPostion;
}
while ((line = bootstrapReader.readLine()) != null) {
bootstrapCode += line + "\n";
}
if (completionStartPosition == completionEndPosition) {
completionStartPosition = 0;
}
else
{
completionStartPosition = completionEndPosition - completionStartPosition;
}
resultCompletionText = completionScriptText.substring(
completionStartPosition , completionEndPosition);
return resultCompletionText;
}
protected IPythonInterpreter getIPythonInterpreter() {
LazyOpenInterpreter lazy = null;
IPythonInterpreter iPython = null;
Interpreter p = getInterpreterInTheSameSessionByClassName(IPythonInterpreter.class.getName());
while (p instanceof WrappedInterpreter) {
if (p instanceof LazyOpenInterpreter) {
lazy = (LazyOpenInterpreter) p;
}
p = ((WrappedInterpreter) p).getInnerInterpreter();
}
iPython = (IPythonInterpreter) p;
return iPython;
}
protected BaseZeppelinContext createZeppelinContext() {
return new PythonZeppelinContext(
getInterpreterGroup().getInterpreterHookRegistry(),
Integer.parseInt(getProperty("zeppelin.python.maxResult", "1000")));
}
public BaseZeppelinContext getZeppelinContext() {
if (zeppelinContext == null) {
zeppelinContext = createZeppelinContext();
}
return zeppelinContext;
}
protected void bootstrapInterpreter(String resourceName) throws IOException {
LOGGER.info("Bootstrap interpreter via " + resourceName);
String bootstrapCode =
IOUtils.toString(getClass().getClassLoader().getResourceAsStream(resourceName));
try {
interpret(bootstrapCode, context);
InterpreterResult result = interpret(bootstrapCode, InterpreterContext.get());
if (result.code() != Code.SUCCESS) {
throw new IOException("Fail to run bootstrap script: " + resourceName);
}
} catch (InterpreterException e) {
throw new IOException(e);
}
}
public PythonZeppelinContext getZeppelinContext() {
return zeppelinContext;
}
String getLocalIp() {
try {
return Inet4Address.getLocalHost().getHostAddress();
} catch (UnknownHostException e) {
logger.error("can't get local IP", e);
}
// fall back to loopback addreess
return "127.0.0.1";
}
private int findRandomOpenPortOnAllLocalInterfaces() {
Integer port = -1;
try (ServerSocket socket = new ServerSocket(0);) {
port = socket.getLocalPort();
socket.close();
} catch (IOException e) {
LOG.error("Can't find an open port", e);
}
return port;
}
public int getMaxResult() {
return maxResult;
}
@Override
public void onProcessComplete(int exitValue) {
pythonscriptRunning = false;
logger.info("python process terminated. exit code " + exitValue);
LOGGER.info("python process terminated. exit code " + exitValue);
pythonScriptRunning.set(false);
pythonScriptInitialized.set(false);
}
@Override
public void onProcessFailed(ExecuteException e) {
pythonscriptRunning = false;
logger.error("python process failed", e);
LOGGER.error("python process failed", e);
pythonScriptRunning.set(false);
pythonScriptInitialized.set(false);
}
// Called by Python Process, used for debugging purpose
public void logPythonOutput(String message) {
LOGGER.debug("Python Process Output: " + message);
}
}

View file

@ -70,7 +70,7 @@ public class PythonInterpreterPandasSql extends Interpreter {
LOG.info("Bootstrap {} interpreter with {}", this.toString(), SQL_BOOTSTRAP_FILE_PY);
PythonInterpreter python = getPythonInterpreter();
python.bootStrapInterpreter(SQL_BOOTSTRAP_FILE_PY);
python.bootstrapInterpreter(SQL_BOOTSTRAP_FILE_PY);
} catch (IOException e) {
LOG.error("Can't execute " + SQL_BOOTSTRAP_FILE_PY + " to import SQL dependencies", e);
}

Binary file not shown.

View file

@ -15,24 +15,12 @@
# limitations under the License.
#
import os, sys, getopt, traceback, json, re
import os, sys, traceback, json, re
from py4j.java_gateway import java_import, JavaGateway, GatewayClient
from py4j.protocol import Py4JJavaError, Py4JNetworkError
import warnings
from py4j.protocol import Py4JJavaError
import ast
import traceback
import warnings
import signal
import base64
from io import BytesIO
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
# for back compatibility
class Logger(object):
def __init__(self):
@ -47,46 +35,79 @@ class Logger(object):
def flush(self):
pass
def handler_stop_signals(sig, frame):
sys.exit("Got signal : " + str(sig))
class PythonCompletion:
def __init__(self, interpreter, userNameSpace):
self.interpreter = interpreter
self.userNameSpace = userNameSpace
signal.signal(signal.SIGINT, handler_stop_signals)
def getObjectCompletion(self, text_value):
completions = [completion for completion in list(self.userNameSpace.keys()) if completion.startswith(text_value)]
builtinCompletions = [completion for completion in dir(__builtins__) if completion.startswith(text_value)]
return completions + builtinCompletions
host = "127.0.0.1"
if len(sys.argv) >= 3:
host = sys.argv[2]
def getMethodCompletion(self, objName, methodName):
execResult = locals()
try:
exec("{} = dir({})".format("objectDefList", objName), _zcUserQueryNameSpace, execResult)
except:
self.interpreter.logPythonOutput("Fail to run dir on " + objName)
self.interpreter.logPythonOutput(traceback.format_exc())
return None
else:
objectDefList = execResult['objectDefList']
return [completion for completion in execResult['objectDefList'] if completion.startswith(methodName)]
_zcUserQueryNameSpace = {}
client = GatewayClient(address=host, port=int(sys.argv[1]))
def getCompletion(self, text_value):
if text_value == None:
return None
gateway = JavaGateway(client)
dotPos = text_value.find(".")
if dotPos == -1:
objName = text_value
completionList = self.getObjectCompletion(objName)
else:
objName = text_value[:dotPos]
methodName = text_value[dotPos + 1:]
completionList = self.getMethodCompletion(objName, methodName)
if completionList is None or len(completionList) <= 0:
self.interpreter.setStatementsFinished("", False)
else:
result = json.dumps(list(filter(lambda x : not re.match("^__.*", x), list(completionList))))
self.interpreter.setStatementsFinished(result, False)
host = sys.argv[1]
port = int(sys.argv[2])
client = GatewayClient(address=host, port=port)
gateway = JavaGateway(client, auto_convert = True)
intp = gateway.entry_point
intp.onPythonScriptInitialized(os.getpid())
java_import(gateway.jvm, "org.apache.zeppelin.display.Input")
from zeppelin_context import PyZeppelinContext
z = __zeppelin__ = PyZeppelinContext(intp.getZeppelinContext(), gateway)
__zeppelin__._setup_matplotlib()
_zcUserQueryNameSpace["__zeppelin__"] = __zeppelin__
_zcUserQueryNameSpace["z"] = z
# redirect stdout/stderr to java side so that PythonInterpreter can capture the python execution result
output = Logger()
sys.stdout = output
#sys.stderr = output
sys.stderr = output
_zcUserQueryNameSpace = {}
completion = PythonCompletion(intp, _zcUserQueryNameSpace)
_zcUserQueryNameSpace["__zeppelin_completion__"] = completion
_zcUserQueryNameSpace["gateway"] = gateway
from zeppelin_context import PyZeppelinContext
if intp.getZeppelinContext():
z = __zeppelin__ = PyZeppelinContext(intp.getZeppelinContext(), gateway)
__zeppelin__._setup_matplotlib()
_zcUserQueryNameSpace["z"] = z
_zcUserQueryNameSpace["__zeppelin__"] = __zeppelin__
intp.onPythonScriptInitialized(os.getpid())
while True :
req = intp.getStatements()
if req == None:
break
try:
stmts = req.statements().split("\n")
final_code = []
isForCompletion = req.isForCompletion()
# Get post-execute hooks
try:
@ -98,35 +119,23 @@ while True :
user_hook = __zeppelin__.getHook('post_exec')
except:
user_hook = None
nhooks = 0
for hook in (global_hook, user_hook):
if hook:
nhooks += 1
if not isForCompletion:
for hook in (global_hook, user_hook):
if hook:
nhooks += 1
for s in stmts:
if s == None:
continue
# skip comment
s_stripped = s.strip()
if len(s_stripped) == 0 or s_stripped.startswith("#"):
continue
final_code.append(s)
if final_code:
if stmts:
# use exec mode to compile the statements except the last statement,
# so that the last statement's evaluation will be printed to stdout
code = compile('\n'.join(final_code), '<stdin>', 'exec', ast.PyCF_ONLY_AST, 1)
code = compile('\n'.join(stmts), '<stdin>', 'exec', ast.PyCF_ONLY_AST, 1)
to_run_hooks = []
if (nhooks > 0):
to_run_hooks = code.body[-nhooks:]
to_run_exec, to_run_single = (code.body[:-(nhooks + 1)],
[code.body[-(nhooks + 1)]])
try:
for node in to_run_exec:
mod = ast.Module([node])
@ -142,19 +151,37 @@ while True :
mod = ast.Module([node])
code = compile(mod, '<stdin>', 'exec')
exec(code, _zcUserQueryNameSpace)
except:
raise Exception(traceback.format_exc())
intp.setStatementsFinished("", False)
if not isForCompletion:
# only call it when it is not for code completion. code completion will call it in
# PythonCompletion.getCompletion
intp.setStatementsFinished("", False)
except Py4JJavaError:
# raise it to outside try except
raise
except:
if not isForCompletion:
# extract which line incur error from error message. e.g.
# Traceback (most recent call last):
# File "<stdin>", line 1, in <module>
# ZeroDivisionError: integer division or modulo by zero
exception = traceback.format_exc()
m = re.search("File \"<stdin>\", line (\d+).*", exception)
if m:
line_no = int(m.group(1))
intp.setStatementsFinished(
"Fail to execute line {}: {}\n".format(line_no, stmts[line_no - 1]) + exception, True)
else:
intp.setStatementsFinished(exception, True)
else:
intp.setStatementsFinished("", False)
except Py4JJavaError:
excInnerError = traceback.format_exc() # format_tb() does not return the inner exception
innerErrorStart = excInnerError.find("Py4JJavaError:")
if innerErrorStart > -1:
excInnerError = excInnerError[innerErrorStart:]
excInnerError = excInnerError[innerErrorStart:]
intp.setStatementsFinished(excInnerError + str(sys.exc_info()), True)
except Py4JNetworkError:
# lost connection from gateway server. exit
sys.exit(1)
except:
intp.setStatementsFinished(traceback.format_exc(), True)

View file

@ -0,0 +1,331 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.zeppelin.python;
import org.apache.zeppelin.display.GUI;
import org.apache.zeppelin.display.ui.CheckBox;
import org.apache.zeppelin.display.ui.Select;
import org.apache.zeppelin.display.ui.TextBox;
import org.apache.zeppelin.interpreter.Interpreter;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.InterpreterException;
import org.apache.zeppelin.interpreter.InterpreterGroup;
import org.apache.zeppelin.interpreter.InterpreterOutput;
import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.interpreter.InterpreterResultMessage;
import org.apache.zeppelin.interpreter.remote.RemoteEventClient;
import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
import org.apache.zeppelin.user.AuthenticationInfo;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals;
public abstract class BasePythonInterpreterTest {
protected InterpreterGroup intpGroup;
protected Interpreter interpreter;
@Before
public abstract void setUp() throws InterpreterException;
@After
public abstract void tearDown() throws InterpreterException;
@Test
public void testPythonBasics() throws InterpreterException, InterruptedException, IOException {
InterpreterContext context = getInterpreterContext();
InterpreterResult result = interpreter.interpret("import sys\nprint(sys.version[0])", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
Thread.sleep(100);
List<InterpreterResultMessage> interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
// single output without print
context = getInterpreterContext();
result = interpreter.interpret("'hello world'", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals("'hello world'", interpreterResultMessages.get(0).getData().trim());
// unicode
context = getInterpreterContext();
result = interpreter.interpret("print(u'你好')", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals("你好\n", interpreterResultMessages.get(0).getData());
// only the last statement is printed
context = getInterpreterContext();
result = interpreter.interpret("'hello world'\n'hello world2'", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals("'hello world2'", interpreterResultMessages.get(0).getData().trim());
// single output
context = getInterpreterContext();
result = interpreter.interpret("print('hello world')", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals("hello world\n", interpreterResultMessages.get(0).getData());
// multiple output
context = getInterpreterContext();
result = interpreter.interpret("print('hello world')\nprint('hello world2')", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals("hello world\nhello world2\n", interpreterResultMessages.get(0).getData());
// assignment
context = getInterpreterContext();
result = interpreter.interpret("abc=1",context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(0, interpreterResultMessages.size());
// if block
context = getInterpreterContext();
result = interpreter.interpret("if abc > 0:\n\tprint('True')\nelse:\n\tprint('False')", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals("True\n", interpreterResultMessages.get(0).getData());
// for loop
context = getInterpreterContext();
result = interpreter.interpret("for i in range(3):\n\tprint(i)", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals("0\n1\n2\n", interpreterResultMessages.get(0).getData());
// syntax error
context = getInterpreterContext();
result = interpreter.interpret("print(unknown)", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.ERROR, result.code());
if (interpreter instanceof IPythonInterpreter) {
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertTrue(interpreterResultMessages.get(0).getData().contains("name 'unknown' is not defined"));
} else if (interpreter instanceof PythonInterpreter) {
assertTrue(result.message().get(0).getData().contains("name 'unknown' is not defined"));
}
// raise runtime exception
context = getInterpreterContext();
result = interpreter.interpret("1/0", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.ERROR, result.code());
if (interpreter instanceof IPythonInterpreter) {
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertTrue(interpreterResultMessages.get(0).getData().contains("ZeroDivisionError"));
} else if (interpreter instanceof PythonInterpreter) {
assertTrue(result.message().get(0).getData().contains("ZeroDivisionError"));
}
// ZEPPELIN-1133
context = getInterpreterContext();
result = interpreter.interpret(
"from __future__ import print_function\n" +
"def greet(name):\n" +
" print('Hello', name)\n" +
"greet('Jack')", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals("Hello Jack\n",interpreterResultMessages.get(0).getData());
// ZEPPELIN-1114
context = getInterpreterContext();
result = interpreter.interpret("print('there is no Error: ok')", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals("there is no Error: ok\n", interpreterResultMessages.get(0).getData());
}
@Test
public void testCodeCompletion() throws InterpreterException, IOException, InterruptedException {
// there's no completion for 'a.' because it is not recognized by compiler for now.
InterpreterContext context = getInterpreterContext();
String st = "a='hello'\na.";
List<InterpreterCompletion> completions = interpreter.completion(st, st.length(), context);
assertEquals(0, completions.size());
// define `a` first
context = getInterpreterContext();
st = "a='hello'";
InterpreterResult result = interpreter.interpret(st, context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
// now we can get the completion for `a.`
context = getInterpreterContext();
st = "a.";
completions = interpreter.completion(st, st.length(), context);
// it is different for python2 and python3 and may even different for different minor version
// so only verify it is larger than 20
assertTrue(completions.size() > 20);
context = getInterpreterContext();
st = "a.co";
completions = interpreter.completion(st, st.length(), context);
assertEquals(1, completions.size());
assertEquals("count", completions.get(0).getValue());
// cursor is in the middle of code
context = getInterpreterContext();
st = "a.co\b='hello";
completions = interpreter.completion(st, 4, context);
assertEquals(1, completions.size());
assertEquals("count", completions.get(0).getValue());
}
@Test
public void testZeppelinContext() throws InterpreterException, InterruptedException, IOException {
// TextBox
InterpreterContext context = getInterpreterContext();
InterpreterResult result = interpreter.interpret("z.input(name='text_1', defaultValue='value_1')", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
List<InterpreterResultMessage> interpreterResultMessages = context.out.toInterpreterResultMessage();
assertTrue(interpreterResultMessages.get(0).getData().contains("'value_1'"));
assertEquals(1, context.getGui().getForms().size());
assertTrue(context.getGui().getForms().get("text_1") instanceof TextBox);
TextBox textbox = (TextBox) context.getGui().getForms().get("text_1");
assertEquals("text_1", textbox.getName());
assertEquals("value_1", textbox.getDefaultValue());
// Select
context = getInterpreterContext();
result = interpreter.interpret("z.select(name='select_1', options=[('value_1', 'name_1'), ('value_2', 'name_2')])", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
assertEquals(1, context.getGui().getForms().size());
assertTrue(context.getGui().getForms().get("select_1") instanceof Select);
Select select = (Select) context.getGui().getForms().get("select_1");
assertEquals("select_1", select.getName());
assertEquals(2, select.getOptions().length);
assertEquals("name_1", select.getOptions()[0].getDisplayName());
assertEquals("value_1", select.getOptions()[0].getValue());
// CheckBox
context = getInterpreterContext();
result = interpreter.interpret("z.checkbox(name='checkbox_1', options=[('value_1', 'name_1'), ('value_2', 'name_2')])", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
assertEquals(1, context.getGui().getForms().size());
assertTrue(context.getGui().getForms().get("checkbox_1") instanceof CheckBox);
CheckBox checkbox = (CheckBox) context.getGui().getForms().get("checkbox_1");
assertEquals("checkbox_1", checkbox.getName());
assertEquals(2, checkbox.getOptions().length);
assertEquals("name_1", checkbox.getOptions()[0].getDisplayName());
assertEquals("value_1", checkbox.getOptions()[0].getValue());
// Pandas DataFrame
context = getInterpreterContext();
result = interpreter.interpret("import pandas as pd\ndf = pd.DataFrame({'id':[1,2,3], 'name':['a','b','c']})\nz.show(df)", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals(InterpreterResult.Type.TABLE, interpreterResultMessages.get(0).getType());
assertEquals("id\tname\n1\ta\n2\tb\n3\tc\n", interpreterResultMessages.get(0).getData());
context = getInterpreterContext();
result = interpreter.interpret("import pandas as pd\ndf = pd.DataFrame({'id':[1,2,3,4], 'name':['a','b','c', 'd']})\nz.show(df)", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(2, interpreterResultMessages.size());
assertEquals(InterpreterResult.Type.TABLE, interpreterResultMessages.get(0).getType());
assertEquals("id\tname\n1\ta\n2\tb\n3\tc\n", interpreterResultMessages.get(0).getData());
assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(1).getType());
assertEquals("<font color=red>Results are limited by 3.</font>\n", interpreterResultMessages.get(1).getData());
// z.show(matplotlib)
context = getInterpreterContext();
result = interpreter.interpret("import matplotlib.pyplot as plt\ndata=[1,1,2,3,4]\nplt.figure()\nplt.plot(data)\nz.show(plt)", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(0).getType());
// clear output
context = getInterpreterContext();
result = interpreter.interpret("import time\nprint(\"Hello\")\ntime.sleep(0.5)\nz.getInterpreterContext().out().clear()\nprint(\"world\")\n", context);
assertEquals("%text world\n", context.out.getCurrentOutput().toString());
}
@Test
public void testRedefinitionZeppelinContext() throws InterpreterException {
String redefinitionCode = "z = 1\n";
String restoreCode = "z = __zeppelin__\n";
String validCode = "z.input(\"test\")\n";
assertEquals(InterpreterResult.Code.SUCCESS, interpreter.interpret(validCode, getInterpreterContext()).code());
assertEquals(InterpreterResult.Code.SUCCESS, interpreter.interpret(redefinitionCode, getInterpreterContext()).code());
assertEquals(InterpreterResult.Code.ERROR, interpreter.interpret(validCode, getInterpreterContext()).code());
assertEquals(InterpreterResult.Code.SUCCESS, interpreter.interpret(restoreCode, getInterpreterContext()).code());
assertEquals(InterpreterResult.Code.SUCCESS, interpreter.interpret(validCode, getInterpreterContext()).code());
}
protected InterpreterContext getInterpreterContext() {
return new InterpreterContext(
"noteId",
"paragraphId",
"replName",
"paragraphTitle",
"paragraphText",
new AuthenticationInfo(),
new HashMap<String, Object>(),
new GUI(),
new GUI(),
null,
null,
null,
new InterpreterOutput(null));
}
protected InterpreterContext getInterpreterContext(RemoteEventClient mockRemoteEventClient) {
InterpreterContext context = getInterpreterContext();
context.setClient(mockRemoteEventClient);
return context;
}
}

View file

@ -17,288 +17,64 @@
package org.apache.zeppelin.python;
import org.apache.zeppelin.display.GUI;
import org.apache.zeppelin.display.ui.CheckBox;
import org.apache.zeppelin.display.ui.Select;
import org.apache.zeppelin.display.ui.TextBox;
import org.apache.zeppelin.interpreter.Interpreter;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.InterpreterException;
import org.apache.zeppelin.interpreter.InterpreterGroup;
import org.apache.zeppelin.interpreter.InterpreterOutput;
import org.apache.zeppelin.interpreter.InterpreterOutputListener;
import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.interpreter.InterpreterResultMessage;
import org.apache.zeppelin.interpreter.InterpreterResultMessageOutput;
import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
import org.apache.zeppelin.user.AuthenticationInfo;
import org.junit.After;
import org.junit.Before;
import org.apache.zeppelin.interpreter.LazyOpenInterpreter;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.CopyOnWriteArrayList;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.mockito.Mockito.mock;
public class IPythonInterpreterTest {
public class IPythonInterpreterTest extends BasePythonInterpreterTest {
private static final Logger LOGGER = LoggerFactory.getLogger(IPythonInterpreterTest.class);
private IPythonInterpreter interpreter;
public void startInterpreter(Properties properties) throws InterpreterException {
interpreter = new IPythonInterpreter(properties);
InterpreterGroup mockInterpreterGroup = mock(InterpreterGroup.class);
interpreter.setInterpreterGroup(mockInterpreterGroup);
protected Properties initIntpProperties() {
Properties properties = new Properties();
properties.setProperty("zeppelin.python.maxResult", "3");
properties.setProperty("zeppelin.python.gatewayserver_address", "127.0.0.1");
return properties;
}
protected void startInterpreter(Properties properties) throws InterpreterException {
interpreter = new LazyOpenInterpreter(new IPythonInterpreter(properties));
intpGroup = new InterpreterGroup();
intpGroup.put("session_1", new ArrayList<Interpreter>());
intpGroup.get("session_1").add(interpreter);
interpreter.setInterpreterGroup(intpGroup);
interpreter.open();
}
@After
public void close() throws InterpreterException {
interpreter.close();
@Override
public void setUp() throws InterpreterException {
Properties properties = initIntpProperties();
startInterpreter(properties);
}
@Test
public void testIPython() throws IOException, InterruptedException, InterpreterException {
Properties properties = new Properties();
properties.setProperty("zeppelin.python.maxResult", "3");
startInterpreter(properties);
testInterpreter(interpreter);
@Override
public void tearDown() throws InterpreterException {
intpGroup.close();
}
@Test
public void testGrpcFrameSize() throws InterpreterException, IOException {
Properties properties = new Properties();
properties.setProperty("zeppelin.ipython.grpc.message_size", "200");
startInterpreter(properties);
// to make this test can run under both python2 and python3
InterpreterResult result = interpreter.interpret("from __future__ import print_function", getInterpreterContext());
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
InterpreterContext context = getInterpreterContext();
result = interpreter.interpret("print('1'*300)", context);
assertEquals(InterpreterResult.Code.ERROR, result.code());
List<InterpreterResultMessage> interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertTrue(interpreterResultMessages.get(0).getData().contains("Frame size 304 exceeds maximum: 200"));
// next call continue work
result = interpreter.interpret("print(1)", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
close();
// increase framesize to make it work
properties.setProperty("zeppelin.ipython.grpc.message_size", "500");
startInterpreter(properties);
// to make this test can run under both python2 and python3
result = interpreter.interpret("from __future__ import print_function", getInterpreterContext());
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
context = getInterpreterContext();
result = interpreter.interpret("print('1'*300)", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
}
public static void testInterpreter(final Interpreter interpreter) throws IOException, InterruptedException, InterpreterException {
// to make this test can run under both python2 and python3
InterpreterResult result = interpreter.interpret("from __future__ import print_function", getInterpreterContext());
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
InterpreterContext context = getInterpreterContext();
result = interpreter.interpret("import sys\nprint(sys.version[0])", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
Thread.sleep(100);
List<InterpreterResultMessage> interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
boolean isPython2 = interpreterResultMessages.get(0).getData().equals("2\n");
// single output without print
context = getInterpreterContext();
result = interpreter.interpret("'hello world'", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals("'hello world'", interpreterResultMessages.get(0).getData());
// unicode
context = getInterpreterContext();
result = interpreter.interpret("print(u'你好')", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals("你好\n", interpreterResultMessages.get(0).getData());
// only the last statement is printed
context = getInterpreterContext();
result = interpreter.interpret("'hello world'\n'hello world2'", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals("'hello world2'", interpreterResultMessages.get(0).getData());
// single output
context = getInterpreterContext();
result = interpreter.interpret("print('hello world')", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals("hello world\n", interpreterResultMessages.get(0).getData());
// multiple output
context = getInterpreterContext();
result = interpreter.interpret("print('hello world')\nprint('hello world2')", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals("hello world\nhello world2\n", interpreterResultMessages.get(0).getData());
// assignment
context = getInterpreterContext();
result = interpreter.interpret("abc=1",context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(0, interpreterResultMessages.size());
// if block
context = getInterpreterContext();
result = interpreter.interpret("if abc > 0:\n\tprint('True')\nelse:\n\tprint('False')", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals("True\n", interpreterResultMessages.get(0).getData());
// for loop
context = getInterpreterContext();
result = interpreter.interpret("for i in range(3):\n\tprint(i)", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals("0\n1\n2\n", interpreterResultMessages.get(0).getData());
// syntax error
context = getInterpreterContext();
result = interpreter.interpret("print(unknown)", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.ERROR, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertTrue(interpreterResultMessages.get(0).getData().contains("name 'unknown' is not defined"));
// raise runtime exception
context = getInterpreterContext();
result = interpreter.interpret("1/0", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.ERROR, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertTrue(interpreterResultMessages.get(0).getData().contains("ZeroDivisionError"));
// ZEPPELIN-1133
context = getInterpreterContext();
result = interpreter.interpret("def greet(name):\n" +
" print('Hello', name)\n" +
"greet('Jack')", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals("Hello Jack\n",interpreterResultMessages.get(0).getData());
// ZEPPELIN-1114
context = getInterpreterContext();
result = interpreter.interpret("print('there is no Error: ok')", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals("there is no Error: ok\n", interpreterResultMessages.get(0).getData());
// completion
context = getInterpreterContext();
List<InterpreterCompletion> completions = interpreter.completion("ab", 2, context);
assertEquals(2, completions.size());
assertEquals("abc", completions.get(0).getValue());
assertEquals("abs", completions.get(1).getValue());
context = getInterpreterContext();
interpreter.interpret("import sys", context);
completions = interpreter.completion("sys.", 4, context);
assertFalse(completions.isEmpty());
context = getInterpreterContext();
completions = interpreter.completion("sys.std", 7, context);
for (InterpreterCompletion completion : completions) {
System.out.println(completion.getValue());
}
assertEquals(3, completions.size());
assertEquals("stderr", completions.get(0).getValue());
assertEquals("stdin", completions.get(1).getValue());
assertEquals("stdout", completions.get(2).getValue());
// there's no completion for 'a.' because it is not recognized by compiler for now.
context = getInterpreterContext();
String st = "a='hello'\na.";
completions = interpreter.completion(st, st.length(), context);
assertEquals(0, completions.size());
// define `a` first
context = getInterpreterContext();
st = "a='hello'";
result = interpreter.interpret(st, context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(0, interpreterResultMessages.size());
// now we can get the completion for `a.`
context = getInterpreterContext();
st = "a.";
completions = interpreter.completion(st, st.length(), context);
// it is different for python2 and python3 and may even different for different minor version
// so only verify it is larger than 20
assertTrue(completions.size() > 20);
context = getInterpreterContext();
st = "a.co";
completions = interpreter.completion(st, st.length(), context);
assertEquals(1, completions.size());
assertEquals("count", completions.get(0).getValue());
// cursor is in the middle of code
context = getInterpreterContext();
st = "a.co\b='hello";
completions = interpreter.completion(st, 4, context);
assertEquals(1, completions.size());
assertEquals("count", completions.get(0).getValue());
public void testIPythonAdvancedFeatures() throws InterpreterException, InterruptedException, IOException {
// ipython help
context = getInterpreterContext();
result = interpreter.interpret("range?", context);
InterpreterContext context = getInterpreterContext();
InterpreterResult result = interpreter.interpret("range?", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
List<InterpreterResultMessage> interpreterResultMessages = context.out.toInterpreterResultMessage();
assertTrue(interpreterResultMessages.get(0).getData().contains("range(stop)"));
// timeit
@ -331,13 +107,16 @@ public class IPythonInterpreterTest {
assertEquals(InterpreterResult.Code.ERROR, result.code());
interpreterResultMessages = context2.out.toInterpreterResultMessage();
assertTrue(interpreterResultMessages.get(0).getData().contains("KeyboardInterrupt"));
}
@Test
public void testIPythonPlotting() throws InterpreterException, InterruptedException, IOException {
// matplotlib
context = getInterpreterContext();
result = interpreter.interpret("%matplotlib inline\nimport matplotlib.pyplot as plt\ndata=[1,1,2,3,4]\nplt.figure()\nplt.plot(data)", context);
InterpreterContext context = getInterpreterContext();
InterpreterResult result = interpreter.interpret("%matplotlib inline\nimport matplotlib.pyplot as plt\ndata=[1,1,2,3,4]\nplt.figure()\nplt.plot(data)", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
List<InterpreterResultMessage> interpreterResultMessages = context.out.toInterpreterResultMessage();
// the order of IMAGE and TEXT is not determined
// check there must be one IMAGE output
boolean hasImageOutput = false;
@ -411,94 +190,44 @@ public class IPythonInterpreterTest {
}
}
assertTrue("No Image Output", hasImageOutput);
}
// ZeppelinContext
@Test
public void testGrpcFrameSize() throws InterpreterException, IOException {
tearDown();
// TextBox
context = getInterpreterContext();
result = interpreter.interpret("z.input(name='text_1', defaultValue='value_1')", context);
Thread.sleep(100);
Properties properties = initIntpProperties();
properties.setProperty("zeppelin.ipython.grpc.message_size", "3000");
startInterpreter(properties);
// to make this test can run under both python2 and python3
InterpreterResult result = interpreter.interpret("from __future__ import print_function", getInterpreterContext());
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertTrue(interpreterResultMessages.get(0).getData().contains("'value_1'"));
assertEquals(1, context.getGui().getForms().size());
assertTrue(context.getGui().getForms().get("text_1") instanceof TextBox);
TextBox textbox = (TextBox) context.getGui().getForms().get("text_1");
assertEquals("text_1", textbox.getName());
assertEquals("value_1", textbox.getDefaultValue());
// Select
context = getInterpreterContext();
result = interpreter.interpret("z.select(name='select_1', options=[('value_1', 'name_1'), ('value_2', 'name_2')])", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
assertEquals(1, context.getGui().getForms().size());
assertTrue(context.getGui().getForms().get("select_1") instanceof Select);
Select select = (Select) context.getGui().getForms().get("select_1");
assertEquals("select_1", select.getName());
assertEquals(2, select.getOptions().length);
assertEquals("name_1", select.getOptions()[0].getDisplayName());
assertEquals("value_1", select.getOptions()[0].getValue());
// CheckBox
context = getInterpreterContext();
result = interpreter.interpret("z.checkbox(name='checkbox_1', options=[('value_1', 'name_1'), ('value_2', 'name_2')])", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
assertEquals(1, context.getGui().getForms().size());
assertTrue(context.getGui().getForms().get("checkbox_1") instanceof CheckBox);
CheckBox checkbox = (CheckBox) context.getGui().getForms().get("checkbox_1");
assertEquals("checkbox_1", checkbox.getName());
assertEquals(2, checkbox.getOptions().length);
assertEquals("name_1", checkbox.getOptions()[0].getDisplayName());
assertEquals("value_1", checkbox.getOptions()[0].getValue());
// Pandas DataFrame
context = getInterpreterContext();
result = interpreter.interpret("import pandas as pd\ndf = pd.DataFrame({'id':[1,2,3], 'name':['a','b','c']})\nz.show(df)", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
InterpreterContext context = getInterpreterContext();
result = interpreter.interpret("print('1'*3000)", context);
assertEquals(InterpreterResult.Code.ERROR, result.code());
List<InterpreterResultMessage> interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals(InterpreterResult.Type.TABLE, interpreterResultMessages.get(0).getType());
assertEquals("id\tname\n1\ta\n2\tb\n3\tc\n", interpreterResultMessages.get(0).getData());
assertTrue(interpreterResultMessages.get(0).getData().contains("exceeds maximum: 3000"));
context = getInterpreterContext();
result = interpreter.interpret("import pandas as pd\ndf = pd.DataFrame({'id':[1,2,3,4], 'name':['a','b','c', 'd']})\nz.show(df)", context);
// next call continue work
result = interpreter.interpret("print(1)", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(2, interpreterResultMessages.size());
assertEquals(InterpreterResult.Type.TABLE, interpreterResultMessages.get(0).getType());
assertEquals("id\tname\n1\ta\n2\tb\n3\tc\n", interpreterResultMessages.get(0).getData());
assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(1).getType());
assertEquals("<font color=red>Results are limited by 3.</font>\n", interpreterResultMessages.get(1).getData());
// z.show(matplotlib)
context = getInterpreterContext();
result = interpreter.interpret("import matplotlib.pyplot as plt\ndata=[1,1,2,3,4]\nplt.figure()\nplt.plot(data)\nz.show(plt)", context);
tearDown();
// increase framesize to make it work
properties.setProperty("zeppelin.ipython.grpc.message_size", "5000");
startInterpreter(properties);
// to make this test can run under both python2 and python3
result = interpreter.interpret("from __future__ import print_function", getInterpreterContext());
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(2, interpreterResultMessages.size());
assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(0).getType());
assertEquals(InterpreterResult.Type.IMG, interpreterResultMessages.get(1).getType());
// clear output
context = getInterpreterContext();
result = interpreter.interpret("import time\nprint(\"Hello\")\ntime.sleep(0.5)\nz.getInterpreterContext().out().clear()\nprint(\"world\")\n", context);
assertEquals("%text world\n", context.out.getCurrentOutput().toString());
result = interpreter.interpret("print('1'*3000)", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
}
private static InterpreterContext getInterpreterContext() {
return new InterpreterContext(
"noteId",
"paragraphId",
"replName",
"paragraphTitle",
"paragraphText",
new AuthenticationInfo(),
new HashMap<String, Object>(),
new GUI(),
new GUI(),
null,
null,
null,
new InterpreterOutput(null));
}
}

View file

@ -39,7 +39,9 @@ public class PythonCondaInterpreterTest {
@Before
public void setUp() throws InterpreterException {
conda = spy(new PythonCondaInterpreter(new Properties()));
when(conda.getClassName()).thenReturn(PythonCondaInterpreter.class.getName());
python = mock(PythonInterpreter.class);
when(python.getClassName()).thenReturn(PythonInterpreter.class.getName());
InterpreterGroup group = new InterpreterGroup();
group.put("note", Arrays.asList(python, conda));
@ -79,7 +81,7 @@ public class PythonCondaInterpreterTest {
conda.interpret("activate " + envname, context);
verify(python, times(1)).open();
verify(python, times(1)).close();
verify(python).setPythonCommand("/path1/bin/python");
verify(python).setPythonExec("/path1/bin/python");
assertTrue(envname.equals(conda.getCurrentCondaEnvName()));
}
@ -89,7 +91,7 @@ public class PythonCondaInterpreterTest {
conda.interpret("deactivate", context);
verify(python, times(1)).open();
verify(python, times(1)).close();
verify(python).setPythonCommand("python");
verify(python).setPythonExec("python");
assertTrue(conda.getCurrentCondaEnvName().isEmpty());
}

View file

@ -17,24 +17,27 @@
package org.apache.zeppelin.python;
import org.apache.zeppelin.display.GUI;
import org.apache.zeppelin.interpreter.*;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.InterpreterException;
import org.apache.zeppelin.interpreter.InterpreterGroup;
import org.apache.zeppelin.interpreter.InterpreterOutput;
import org.apache.zeppelin.user.AuthenticationInfo;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;
import java.io.IOException;
import java.net.Inet4Address;
import java.net.UnknownHostException;
import java.io.File;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Properties;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
public class PythonDockerInterpreterTest {
private PythonDockerInterpreter docker;
@ -52,7 +55,7 @@ public class PythonDockerInterpreterTest {
doReturn(true).when(docker).pull(any(InterpreterOutput.class), anyString());
doReturn(python).when(docker).getPythonInterpreter();
doReturn("/scriptpath/zeppelin_python.py").when(python).getScriptPath();
doReturn(new File("/scriptpath")).when(python).getPythonWorkDir();
docker.open();
}
@ -64,7 +67,7 @@ public class PythonDockerInterpreterTest {
verify(python, times(1)).open();
verify(python, times(1)).close();
verify(docker, times(1)).pull(any(InterpreterOutput.class), anyString());
verify(python).setPythonCommand(Mockito.matches("docker run -i --rm -v.*"));
verify(python).setPythonExec(Mockito.matches("docker run -i --rm -v.*"));
}
@Test
@ -73,7 +76,7 @@ public class PythonDockerInterpreterTest {
docker.interpret("deactivate", context);
verify(python, times(1)).open();
verify(python, times(1)).close();
verify(python).setPythonCommand(null);
verify(python).setPythonExec(null);
}
private InterpreterContext getInterpreterContext() {

View file

@ -17,130 +17,91 @@
package org.apache.zeppelin.python;
import static org.apache.zeppelin.python.PythonInterpreter.DEFAULT_ZEPPELIN_PYTHON;
import static org.apache.zeppelin.python.PythonInterpreter.MAX_RESULT;
import static org.apache.zeppelin.python.PythonInterpreter.ZEPPELIN_PYTHON;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.Properties;
import org.apache.commons.exec.environment.EnvironmentUtils;
import org.apache.zeppelin.display.AngularObjectRegistry;
import org.apache.zeppelin.display.GUI;
import org.apache.zeppelin.interpreter.Interpreter;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.InterpreterContextRunner;
import org.apache.zeppelin.interpreter.InterpreterException;
import org.apache.zeppelin.interpreter.InterpreterGroup;
import org.apache.zeppelin.interpreter.InterpreterOutput;
import org.apache.zeppelin.interpreter.InterpreterOutputListener;
import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.interpreter.InterpreterResultMessageOutput;
import org.apache.zeppelin.resource.LocalResourcePool;
import org.apache.zeppelin.user.AuthenticationInfo;
import org.junit.After;
import org.junit.Before;
import org.apache.zeppelin.interpreter.LazyOpenInterpreter;
import org.junit.Test;
public class PythonInterpreterTest implements InterpreterOutputListener {
PythonInterpreter pythonInterpreter = null;
String cmdHistory;
private InterpreterContext context;
InterpreterOutput out;
import java.io.IOException;
import java.util.LinkedList;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
public static Properties getPythonTestProperties() {
Properties p = new Properties();
p.setProperty(ZEPPELIN_PYTHON, DEFAULT_ZEPPELIN_PYTHON);
p.setProperty(MAX_RESULT, "1000");
p.setProperty("zeppelin.python.useIPython", "false");
return p;
}
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
@Before
public void beforeTest() throws IOException, InterpreterException {
cmdHistory = "";
public class PythonInterpreterTest extends BasePythonInterpreterTest {
// python interpreter
pythonInterpreter = new PythonInterpreter(getPythonTestProperties());
@Override
public void setUp() throws InterpreterException {
// create interpreter group
InterpreterGroup group = new InterpreterGroup();
group.put("note", new LinkedList<Interpreter>());
group.get("note").add(pythonInterpreter);
pythonInterpreter.setInterpreterGroup(group);
intpGroup = new InterpreterGroup();
out = new InterpreterOutput(this);
Properties properties = new Properties();
properties.setProperty("zeppelin.python.maxResult", "3");
properties.setProperty("zeppelin.python.useIPython", "false");
properties.setProperty("zeppelin.python.gatewayserver_address", "127.0.0.1");
context = new InterpreterContext("note", "id", null, "title", "text",
new AuthenticationInfo(),
new HashMap<String, Object>(),
new GUI(),
new GUI(),
new AngularObjectRegistry(group.getId(), null),
new LocalResourcePool("id"),
new LinkedList<InterpreterContextRunner>(),
out);
InterpreterContext.set(context);
pythonInterpreter.open();
}
interpreter = new LazyOpenInterpreter(new PythonInterpreter(properties));
intpGroup.put("note", new LinkedList<Interpreter>());
intpGroup.get("note").add(interpreter);
interpreter.setInterpreterGroup(intpGroup);
@After
public void afterTest() throws IOException, InterpreterException {
pythonInterpreter.close();
}
@Test
public void testInterpret() throws InterruptedException, IOException, InterpreterException {
InterpreterResult result = pythonInterpreter.interpret("print (\"hi\")", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
}
@Test
public void testInterpretInvalidSyntax() throws IOException, InterpreterException {
InterpreterResult result = pythonInterpreter.interpret("for x in range(0,3): print (\"hi\")\n", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
assertTrue(new String(out.getOutputAt(0).toByteArray()).contains("hi\nhi\nhi"));
}
@Test
public void testRedefinitionZeppelinContext() throws InterpreterException {
String pyRedefinitionCode = "z = 1\n";
String pyRestoreCode = "z = __zeppelin__\n";
String pyValidCode = "z.input(\"test\")\n";
assertEquals(InterpreterResult.Code.SUCCESS, pythonInterpreter.interpret(pyValidCode, context).code());
assertEquals(InterpreterResult.Code.SUCCESS, pythonInterpreter.interpret(pyRedefinitionCode, context).code());
assertEquals(InterpreterResult.Code.ERROR, pythonInterpreter.interpret(pyValidCode, context).code());
assertEquals(InterpreterResult.Code.SUCCESS, pythonInterpreter.interpret(pyRestoreCode, context).code());
assertEquals(InterpreterResult.Code.SUCCESS, pythonInterpreter.interpret(pyValidCode, context).code());
}
@Test
public void testOutputClear() throws InterpreterException {
InterpreterResult result = pythonInterpreter.interpret("print(\"Hello\")\nz.getInterpreterContext().out().clear()\nprint(\"world\")\n", context);
assertEquals("%text world\n", out.getCurrentOutput().toString());
InterpreterContext.set(getInterpreterContext());
interpreter.open();
}
@Override
public void onUpdateAll(InterpreterOutput out) {
public void tearDown() throws InterpreterException {
intpGroup.close();
}
@Override
public void onAppend(int index, InterpreterResultMessageOutput out, byte[] line) {
public void testCodeCompletion() throws InterpreterException, IOException, InterruptedException {
super.testCodeCompletion();
//TODO(zjffdu) PythonInterpreter doesn't support this kind of code completion for now.
// completion
// InterpreterContext context = getInterpreterContext();
// List<InterpreterCompletion> completions = interpreter.completion("ab", 2, context);
// assertEquals(2, completions.size());
// assertEquals("abc", completions.get(0).getValue());
// assertEquals("abs", completions.get(1).getValue());
}
@Override
public void onUpdate(int index, InterpreterResultMessageOutput out) {
private class infinityPythonJob implements Runnable {
@Override
public void run() {
String code = "import time\nwhile True:\n time.sleep(1)" ;
InterpreterResult ret = null;
try {
ret = interpreter.interpret(code, getInterpreterContext());
} catch (InterpreterException e) {
e.printStackTrace();
}
assertNotNull(ret);
Pattern expectedMessage = Pattern.compile("KeyboardInterrupt");
Matcher m = expectedMessage.matcher(ret.message().toString());
assertTrue(m.find());
}
}
@Test
public void testCancelIntp() throws InterruptedException, InterpreterException {
assertEquals(InterpreterResult.Code.SUCCESS,
interpreter.interpret("a = 1\n", getInterpreterContext()).code());
Thread t = new Thread(new infinityPythonJob());
t.start();
Thread.sleep(5000);
interpreter.cancel(getInterpreterContext());
assertTrue(t.isAlive());
t.join(2000);
assertFalse(t.isAlive());
}
}

View file

@ -15,18 +15,13 @@
# limitations under the License.
#
# Direct log messages to stdout
log4j.appender.stdout=org.apache.log4j.ConsoleAppender
log4j.appender.stdout.Target=System.out
log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
log4j.appender.stdout.layout.ConversionPattern=%d{ABSOLUTE} %5p %c:%L - %m%n
#log4j.appender.stdout.layout.ConversionPattern=
#%5p [%t] (%F:%L) - %m%n
#%-4r [%t] %-5p %c %x - %m%n
#
# Root logger option
log4j.rootLogger=INFO, stdout
log4j.logger.org.apache.zeppelin.python.IPythonInterpreter=DEBUG
log4j.logger.org.apache.zeppelin.python.IPythonClient=DEBUG
log4j.logger.org.apache.zeppelin.python=DEBUG
# Direct log messages to stdout
log4j.appender.stdout=org.apache.log4j.ConsoleAppender
log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
log4j.appender.stdout.layout.ConversionPattern=%5p [%d] ({%t} %F[%M]:%L) - %m%n
log4j.logger.org.apache.zeppelin.python=DEBUG

View file

@ -441,14 +441,14 @@
<configuration>
<forkCount>1</forkCount>
<reuseForks>false</reuseForks>
<argLine>-Xmx1024m -XX:MaxPermSize=256m</argLine>
<argLine>-Xmx1536m -XX:MaxPermSize=256m</argLine>
<excludes>
<exclude>**/SparkRInterpreterTest.java</exclude>
<exclude>${pyspark.test.exclude}</exclude>
<exclude>${tests.to.exclude}</exclude>
</excludes>
<environmentVariables>
<PYTHONPATH>${project.build.directory}/../../../interpreter/spark/pyspark/pyspark.zip:${project.build.directory}/../../../interpreter/lib/python/:${project.build.directory}/../../../interpreter/spark/pyspark/py4j-${py4j.version}-src.zip:.</PYTHONPATH>
<PYTHONPATH>${project.build.directory}/../../../interpreter/spark/pyspark/pyspark.zip:${project.build.directory}/../../../interpreter/spark/pyspark/py4j-${py4j.version}-src.zip</PYTHONPATH>
<ZEPPELIN_HOME>${basedir}/../../</ZEPPELIN_HOME>
</environmentVariables>
</configuration>

View file

@ -27,6 +27,7 @@ import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.interpreter.LazyOpenInterpreter;
import org.apache.zeppelin.interpreter.WrappedInterpreter;
import org.apache.zeppelin.python.IPythonInterpreter;
import org.apache.zeppelin.python.PythonInterpreter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -49,8 +50,8 @@ public class IPySparkInterpreter extends IPythonInterpreter {
@Override
public void open() throws InterpreterException {
setProperty("zeppelin.python",
PySparkInterpreter.getPythonExec(getProperties()));
PySparkInterpreter pySparkInterpreter = getPySparkInterpreter();
setProperty("zeppelin.python", pySparkInterpreter.getPythonExec());
sparkInterpreter = getSparkInterpreter();
SparkConf conf = sparkInterpreter.getSparkContext().getConf();
// only set PYTHONPATH in embedded, local or yarn-client mode.
@ -94,6 +95,16 @@ public class IPySparkInterpreter extends IPythonInterpreter {
return spark;
}
private PySparkInterpreter getPySparkInterpreter() throws InterpreterException {
PySparkInterpreter pySpark = null;
Interpreter p = getInterpreterInTheSameSessionByClassName(PySparkInterpreter.class.getName());
while (p instanceof WrappedInterpreter) {
p = ((WrappedInterpreter) p).getInnerInterpreter();
}
pySpark = (PySparkInterpreter) p;
return pySpark;
}
@Override
public BaseZeppelinContext buildZeppelinContext() {
return sparkInterpreter.getZeppelinContext();
@ -117,6 +128,7 @@ public class IPySparkInterpreter extends IPythonInterpreter {
@Override
public void close() throws InterpreterException {
LOGGER.info("Close IPySparkInterpreter");
super.close();
if (sparkInterpreter != null) {
sparkInterpreter.close();

View file

@ -56,7 +56,7 @@ import java.util.Properties;
*/
public class NewSparkInterpreter extends AbstractSparkInterpreter {
private static final Logger LOGGER = LoggerFactory.getLogger(SparkInterpreter.class);
private static final Logger LOGGER = LoggerFactory.getLogger(NewSparkInterpreter.class);
private BaseSparkScalaInterpreter innerInterpreter;
private Map<String, String> innerInterpreterClassMap = new HashMap<>();
@ -177,7 +177,10 @@ public class NewSparkInterpreter extends AbstractSparkInterpreter {
@Override
public void close() {
LOGGER.info("Close SparkInterpreter");
innerInterpreter.close();
if (innerInterpreter != null) {
innerInterpreter.close();
innerInterpreter = null;
}
}
@Override

View file

@ -30,6 +30,7 @@ import org.apache.commons.lang.StringUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SQLContext;
import org.apache.zeppelin.interpreter.BaseZeppelinContext;
import org.apache.zeppelin.interpreter.Interpreter;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.InterpreterException;
@ -44,6 +45,8 @@ import org.apache.zeppelin.interpreter.WrappedInterpreter;
import org.apache.zeppelin.interpreter.remote.RemoteInterpreterUtils;
import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
import org.apache.zeppelin.interpreter.util.InterpreterOutputStream;
import org.apache.zeppelin.python.IPythonInterpreter;
import org.apache.zeppelin.python.PythonInterpreter;
import org.apache.zeppelin.spark.dep.SparkDependencyContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -68,56 +71,23 @@ import java.util.Properties;
* features compared to IPySparkInterpreter, but requires less prerequisites than
* IPySparkInterpreter, only python is required.
*/
public class PySparkInterpreter extends Interpreter implements ExecuteResultHandler {
private static final Logger LOGGER = LoggerFactory.getLogger(PySparkInterpreter.class);
private static final int MAX_TIMEOUT_SEC = 10;
public class PySparkInterpreter extends PythonInterpreter {
private static Logger LOGGER = LoggerFactory.getLogger(PySparkInterpreter.class);
private GatewayServer gatewayServer;
private DefaultExecutor executor;
// used to forward output from python process to InterpreterOutput
private InterpreterOutputStream outputStream;
private String scriptPath;
private boolean pythonscriptRunning = false;
private long pythonPid = -1;
private IPySparkInterpreter iPySparkInterpreter;
private SparkInterpreter sparkInterpreter;
public PySparkInterpreter(Properties property) {
super(property);
this.useBuiltinPy4j = false;
}
@Override
public void open() throws InterpreterException {
// try IPySparkInterpreter first
iPySparkInterpreter = getIPySparkInterpreter();
if (getProperty("zeppelin.pyspark.useIPython", "true").equals("true") &&
StringUtils.isEmpty(
iPySparkInterpreter.checkIPythonPrerequisite(getPythonExec(getProperties())))) {
try {
iPySparkInterpreter.open();
LOGGER.info("IPython is available, Use IPySparkInterpreter to replace PySparkInterpreter");
return;
} catch (Exception e) {
iPySparkInterpreter = null;
LOGGER.warn("Fail to open IPySparkInterpreter", e);
}
}
setProperty("zeppelin.python.useIPython", getProperty("zeppelin.pyspark.useIPython", "true"));
// reset iPySparkInterpreter to null as it is not available
iPySparkInterpreter = null;
LOGGER.info("IPython is not available, use the native PySparkInterpreter\n");
// Add matplotlib display hook
InterpreterGroup intpGroup = getInterpreterGroup();
if (intpGroup != null && intpGroup.getInterpreterHookRegistry() != null) {
try {
// just for unit test I believe (zjffdu)
registerHook(HookType.POST_EXEC_DEV.getName(), "__zeppelin__._displayhook()");
} catch (InvalidHookException e) {
throw new InterpreterException(e);
}
}
// create SparkInterpreter in JVM side TODO(zjffdu) move to SparkInterpreter
DepInterpreter depInterpreter = getDepInterpreter();
// load libraries from Dependency Interpreter
URL [] urls = new URL[0];
List<URL> urlList = new LinkedList<>();
@ -159,126 +129,61 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
ClassLoader oldCl = Thread.currentThread().getContextClassLoader();
try {
URLClassLoader newCl = new URLClassLoader(urls, oldCl);
LOGGER.info("urls:" + urls);
for (URL url : urls) {
LOGGER.info("url:" + url);
}
Thread.currentThread().setContextClassLoader(newCl);
// create Python Process and JVM gateway
super.open();
// must create spark interpreter after ClassLoader is set, otherwise the additional jars
// can not be loaded by spark repl.
this.sparkInterpreter = getSparkInterpreter();
createGatewayServerAndStartScript();
} catch (IOException e) {
LOGGER.error("Fail to open PySparkInterpreter", e);
throw new InterpreterException("Fail to open PySparkInterpreter", e);
} finally {
Thread.currentThread().setContextClassLoader(oldCl);
}
}
private void createGatewayServerAndStartScript() throws IOException {
// start gateway server in JVM side
int port = RemoteInterpreterUtils.findRandomAvailablePortOnAllLocalInterfaces();
gatewayServer = new GatewayServer(this, port);
gatewayServer.start();
// launch python process to connect to the gateway server in JVM side
createPythonScript();
String pythonExec = getPythonExec(getProperties());
LOGGER.info("PythonExec: " + pythonExec);
CommandLine cmd = CommandLine.parse(pythonExec);
cmd.addArgument(scriptPath, false);
cmd.addArgument(Integer.toString(port), false);
cmd.addArgument(Integer.toString(sparkInterpreter.getSparkVersion().toNumber()), false);
executor = new DefaultExecutor();
outputStream = new InterpreterOutputStream(LOGGER);
PumpStreamHandler streamHandler = new PumpStreamHandler(outputStream);
executor.setStreamHandler(streamHandler);
executor.setWatchdog(new ExecuteWatchdog(ExecuteWatchdog.INFINITE_TIMEOUT));
Map<String, String> env = setupPySparkEnv();
executor.execute(cmd, env, this);
pythonscriptRunning = true;
}
private void createPythonScript() throws IOException {
FileOutputStream pysparkScriptOutput = null;
FileOutputStream zeppelinContextOutput = null;
try {
// copy zeppelin_pyspark.py
File scriptFile = File.createTempFile("zeppelin_pyspark-", ".py");
this.scriptPath = scriptFile.getAbsolutePath();
pysparkScriptOutput = new FileOutputStream(scriptFile);
IOUtils.copy(
getClass().getClassLoader().getResourceAsStream("python/zeppelin_pyspark.py"),
pysparkScriptOutput);
// copy zeppelin_context.py to the same folder of zeppelin_pyspark.py
zeppelinContextOutput = new FileOutputStream(scriptFile.getParent() + "/zeppelin_context.py");
IOUtils.copy(
getClass().getClassLoader().getResourceAsStream("python/zeppelin_context.py"),
zeppelinContextOutput);
LOGGER.info("PySpark script {} {} is created",
scriptPath, scriptFile.getParent() + "/zeppelin_context.py");
} finally {
if (pysparkScriptOutput != null) {
try {
pysparkScriptOutput.close();
} catch (IOException e) {
// ignore
}
}
if (zeppelinContextOutput != null) {
try {
zeppelinContextOutput.close();
} catch (IOException e) {
// ignore
}
if (!useIPython()) {
// Initialize Spark in Python Process
try {
bootstrapInterpreter("python/zeppelin_pyspark.py");
} catch (IOException e) {
throw new InterpreterException("Fail to bootstrap pyspark", e);
}
}
}
private Map<String, String> setupPySparkEnv() throws IOException {
Map<String, String> env = EnvironmentUtils.getProcEnvironment();
// only set PYTHONPATH in local or yarn-client mode.
// yarn-cluster will setup PYTHONPATH automatically.
SparkConf conf = null;
try {
conf = getSparkConf();
} catch (InterpreterException e) {
throw new IOException(e);
}
if (!conf.get("spark.submit.deployMode", "client").equals("cluster")) {
if (!env.containsKey("PYTHONPATH")) {
env.put("PYTHONPATH", PythonUtils.sparkPythonPath());
} else {
env.put("PYTHONPATH", PythonUtils.sparkPythonPath() + ":" + env.get("PYTHONPATH"));
}
@Override
public void close() throws InterpreterException {
super.close();
if (sparkInterpreter != null) {
sparkInterpreter.close();
}
}
// get additional class paths when using SPARK_SUBMIT and not using YARN-CLIENT
// also, add all packages to PYTHONPATH since there might be transitive dependencies
if (SparkInterpreter.useSparkSubmit() &&
!sparkInterpreter.isYarnMode()) {
String sparkSubmitJars = conf.get("spark.jars").replace(",", ":");
if (!StringUtils.isEmpty(sparkSubmitJars)) {
env.put("PYTHONPATH", env.get("PYTHONPATH") + ":" + sparkSubmitJars);
}
}
@Override
protected BaseZeppelinContext createZeppelinContext() {
return sparkInterpreter.getZeppelinContext();
}
// set PYSPARK_PYTHON
if (conf.contains("spark.pyspark.python")) {
env.put("PYSPARK_PYTHON", conf.get("spark.pyspark.python"));
}
LOGGER.info("PYTHONPATH: " + env.get("PYTHONPATH"));
return env;
@Override
public InterpreterResult interpret(String st, InterpreterContext context)
throws InterpreterException {
sparkInterpreter.populateSparkWebUrl(context);
return super.interpret(st, context);
}
@Override
protected void preCallPython(InterpreterContext context) {
String jobGroup = Utils.buildJobGroupId(context);
String jobDesc = "Started by: " + Utils.getUserName(context.getAuthenticationInfo());
callPython(new PythonInterpretRequest(
String.format("if 'sc' in locals():\n\tsc.setJobGroup('%s', '%s')", jobGroup, jobDesc),
false));
}
// Run python shell
// Choose python in the order of
// PYSPARK_DRIVER_PYTHON > PYSPARK_PYTHON > zeppelin.pyspark.python
public static String getPythonExec(Properties properties) {
String pythonExec = properties.getProperty("zeppelin.pyspark.python", "python");
@Override
protected String getPythonExec() {
String pythonExec = getProperty("zeppelin.pyspark.python", "python");
if (System.getenv("PYSPARK_PYTHON") != null) {
pythonExec = System.getenv("PYSPARK_PYTHON");
}
@ -289,344 +194,16 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
}
@Override
public void close() throws InterpreterException {
if (iPySparkInterpreter != null) {
iPySparkInterpreter.close();
return;
protected IPythonInterpreter getIPythonInterpreter() {
IPySparkInterpreter iPython = null;
Interpreter p = getInterpreterInTheSameSessionByClassName(IPySparkInterpreter.class.getName());
while (p instanceof WrappedInterpreter) {
p = ((WrappedInterpreter) p).getInnerInterpreter();
}
executor.getWatchdog().destroyProcess();
gatewayServer.shutdown();
iPython = (IPySparkInterpreter) p;
return iPython;
}
private PythonInterpretRequest pythonInterpretRequest = null;
private Integer statementSetNotifier = new Integer(0);
private String statementOutput = null;
private boolean statementError = false;
private Integer statementFinishedNotifier = new Integer(0);
/**
* Request send to Python Daemon
*/
public class PythonInterpretRequest {
public String statements;
public String jobGroup;
public String jobDescription;
public boolean isForCompletion;
public PythonInterpretRequest(String statements, String jobGroup,
String jobDescription, boolean isForCompletion) {
this.statements = statements;
this.jobGroup = jobGroup;
this.jobDescription = jobDescription;
this.isForCompletion = isForCompletion;
}
public String statements() {
return statements;
}
public String jobGroup() {
return jobGroup;
}
public String jobDescription() {
return jobDescription;
}
public boolean isForCompletion() {
return isForCompletion;
}
}
// called by Python Process
public PythonInterpretRequest getStatements() {
synchronized (statementSetNotifier) {
while (pythonInterpretRequest == null) {
try {
statementSetNotifier.wait(1000);
} catch (InterruptedException e) {
}
}
PythonInterpretRequest req = pythonInterpretRequest;
pythonInterpretRequest = null;
return req;
}
}
// called by Python Process
public void setStatementsFinished(String out, boolean error) {
synchronized (statementFinishedNotifier) {
LOGGER.debug("Setting python statement output: " + out + ", error: " + error);
statementOutput = out;
statementError = error;
statementFinishedNotifier.notify();
}
}
private boolean pythonScriptInitialized = false;
private Integer pythonScriptInitializeNotifier = new Integer(0);
// called by Python Process
public void onPythonScriptInitialized(long pid) {
pythonPid = pid;
synchronized (pythonScriptInitializeNotifier) {
LOGGER.debug("onPythonScriptInitialized is called");
pythonScriptInitialized = true;
pythonScriptInitializeNotifier.notifyAll();
}
}
// called by Python Process
public void appendOutput(String message) throws IOException {
LOGGER.debug("Output from python process: " + message);
outputStream.getInterpreterOutput().write(message);
}
@Override
public InterpreterResult interpret(String st, InterpreterContext context)
throws InterpreterException {
if (iPySparkInterpreter != null) {
return iPySparkInterpreter.interpret(st, context);
}
if (sparkInterpreter.isUnsupportedSparkVersion()) {
return new InterpreterResult(Code.ERROR, "Spark "
+ sparkInterpreter.getSparkVersion().toString() + " is not supported");
}
sparkInterpreter.populateSparkWebUrl(context);
if (!pythonscriptRunning) {
return new InterpreterResult(Code.ERROR, "python process not running "
+ outputStream.toString());
}
outputStream.setInterpreterOutput(context.out);
synchronized (pythonScriptInitializeNotifier) {
long startTime = System.currentTimeMillis();
while (pythonScriptInitialized == false
&& pythonscriptRunning
&& System.currentTimeMillis() - startTime < MAX_TIMEOUT_SEC * 1000) {
try {
LOGGER.info("Wait for PythonScript running");
pythonScriptInitializeNotifier.wait(1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
List<InterpreterResultMessage> errorMessage;
try {
context.out.flush();
errorMessage = context.out.toInterpreterResultMessage();
} catch (IOException e) {
throw new InterpreterException(e);
}
if (pythonscriptRunning == false) {
// python script failed to initialize and terminated
errorMessage.add(new InterpreterResultMessage(
InterpreterResult.Type.TEXT, "Failed to start PySpark"));
return new InterpreterResult(Code.ERROR, errorMessage);
}
if (pythonScriptInitialized == false) {
// timeout. didn't get initialized message
errorMessage.add(new InterpreterResultMessage(
InterpreterResult.Type.TEXT, "Failed to initialize PySpark"));
return new InterpreterResult(Code.ERROR, errorMessage);
}
//TODO(zjffdu) remove this as PySpark is supported starting from spark 1.2s
if (!sparkInterpreter.getSparkVersion().isPysparkSupported()) {
errorMessage.add(new InterpreterResultMessage(
InterpreterResult.Type.TEXT,
"pyspark " + sparkInterpreter.getSparkContext().version() + " is not supported"));
return new InterpreterResult(Code.ERROR, errorMessage);
}
String jobGroup = Utils.buildJobGroupId(context);
String jobDesc = "Started by: " + Utils.getUserName(context.getAuthenticationInfo());
SparkZeppelinContext z = sparkInterpreter.getZeppelinContext();
z.setInterpreterContext(context);
z.setGui(context.getGui());
z.setNoteGui(context.getNoteGui());
InterpreterContext.set(context);
pythonInterpretRequest = new PythonInterpretRequest(st, jobGroup, jobDesc, false);
statementOutput = null;
synchronized (statementSetNotifier) {
statementSetNotifier.notify();
}
synchronized (statementFinishedNotifier) {
while (statementOutput == null) {
try {
statementFinishedNotifier.wait(1000);
} catch (InterruptedException e) {
}
}
}
if (statementError) {
return new InterpreterResult(Code.ERROR, statementOutput);
} else {
try {
context.out.flush();
} catch (IOException e) {
throw new InterpreterException(e);
}
return new InterpreterResult(Code.SUCCESS);
}
}
public void interrupt() throws IOException, InterpreterException {
if (pythonPid > -1) {
LOGGER.info("Sending SIGINT signal to PID : " + pythonPid);
Runtime.getRuntime().exec("kill -SIGINT " + pythonPid);
} else {
LOGGER.warn("Non UNIX/Linux system, close the interpreter");
close();
}
}
@Override
public void cancel(InterpreterContext context) throws InterpreterException {
if (iPySparkInterpreter != null) {
iPySparkInterpreter.cancel(context);
return;
}
SparkInterpreter sparkInterpreter = getSparkInterpreter();
sparkInterpreter.cancel(context);
try {
interrupt();
} catch (IOException e) {
LOGGER.error("Error", e);
}
}
@Override
public FormType getFormType() {
return FormType.NATIVE;
}
@Override
public int getProgress(InterpreterContext context) throws InterpreterException {
if (iPySparkInterpreter != null) {
return iPySparkInterpreter.getProgress(context);
}
SparkInterpreter sparkInterpreter = getSparkInterpreter();
return sparkInterpreter.getProgress(context);
}
@Override
public List<InterpreterCompletion> completion(String buf, int cursor,
InterpreterContext interpreterContext)
throws InterpreterException {
if (iPySparkInterpreter != null) {
return iPySparkInterpreter.completion(buf, cursor, interpreterContext);
}
if (buf.length() < cursor) {
cursor = buf.length();
}
String completionString = getCompletionTargetString(buf, cursor);
String completionCommand = "completion.getCompletion('" + completionString + "')";
LOGGER.debug("completionCommand: " + completionCommand);
//start code for completion
if (sparkInterpreter.isUnsupportedSparkVersion() || pythonscriptRunning == false) {
return new LinkedList<>();
}
pythonInterpretRequest = new PythonInterpretRequest(completionCommand, "", "", true);
statementOutput = null;
synchronized (statementSetNotifier) {
statementSetNotifier.notify();
}
String[] completionList = null;
synchronized (statementFinishedNotifier) {
long startTime = System.currentTimeMillis();
while (statementOutput == null
&& pythonscriptRunning) {
try {
if (System.currentTimeMillis() - startTime > MAX_TIMEOUT_SEC * 1000) {
LOGGER.error("pyspark completion didn't have response for {}sec.", MAX_TIMEOUT_SEC);
break;
}
statementFinishedNotifier.wait(1000);
} catch (InterruptedException e) {
// not working
LOGGER.info("wait drop");
return new LinkedList<>();
}
}
if (statementError) {
return new LinkedList<>();
}
Gson gson = new Gson();
completionList = gson.fromJson(statementOutput, String[].class);
}
//end code for completion
if (completionList == null) {
return new LinkedList<>();
}
List<InterpreterCompletion> results = new LinkedList<>();
for (String name: completionList) {
results.add(new InterpreterCompletion(name, name, StringUtils.EMPTY));
LOGGER.debug("completion: " + name);
}
return results;
}
private String getCompletionTargetString(String text, int cursor) {
String[] completionSeqCharaters = {" ", "\n", "\t"};
int completionEndPosition = cursor;
int completionStartPosition = cursor;
int indexOfReverseSeqPostion = cursor;
String resultCompletionText = "";
String completionScriptText = "";
try {
completionScriptText = text.substring(0, cursor);
}
catch (Exception e) {
LOGGER.error(e.toString());
return null;
}
completionEndPosition = completionScriptText.length();
String tempReverseCompletionText = new StringBuilder(completionScriptText).reverse().toString();
for (String seqCharacter : completionSeqCharaters) {
indexOfReverseSeqPostion = tempReverseCompletionText.indexOf(seqCharacter);
if (indexOfReverseSeqPostion < completionStartPosition && indexOfReverseSeqPostion > 0) {
completionStartPosition = indexOfReverseSeqPostion;
}
}
if (completionStartPosition == completionEndPosition) {
completionStartPosition = 0;
}
else
{
completionStartPosition = completionEndPosition - completionStartPosition;
}
resultCompletionText = completionScriptText.substring(
completionStartPosition , completionEndPosition);
return resultCompletionText;
}
private SparkInterpreter getSparkInterpreter() throws InterpreterException {
LazyOpenInterpreter lazy = null;
SparkInterpreter spark = null;
@ -646,63 +223,45 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
return spark;
}
private IPySparkInterpreter getIPySparkInterpreter() {
LazyOpenInterpreter lazy = null;
IPySparkInterpreter iPySpark = null;
Interpreter p = getInterpreterInTheSameSessionByClassName(IPySparkInterpreter.class.getName());
while (p instanceof WrappedInterpreter) {
if (p instanceof LazyOpenInterpreter) {
lazy = (LazyOpenInterpreter) p;
}
p = ((WrappedInterpreter) p).getInnerInterpreter();
}
iPySpark = (IPySparkInterpreter) p;
return iPySpark;
}
public SparkZeppelinContext getZeppelinContext() throws InterpreterException {
SparkInterpreter sparkIntp = getSparkInterpreter();
if (sparkIntp != null) {
return getSparkInterpreter().getZeppelinContext();
public SparkZeppelinContext getZeppelinContext() {
if (sparkInterpreter != null) {
return sparkInterpreter.getZeppelinContext();
} else {
return null;
}
}
public JavaSparkContext getJavaSparkContext() throws InterpreterException {
SparkInterpreter intp = getSparkInterpreter();
if (intp == null) {
public JavaSparkContext getJavaSparkContext() {
if (sparkInterpreter == null) {
return null;
} else {
return new JavaSparkContext(intp.getSparkContext());
return new JavaSparkContext(sparkInterpreter.getSparkContext());
}
}
public Object getSparkSession() throws InterpreterException {
SparkInterpreter intp = getSparkInterpreter();
if (intp == null) {
public Object getSparkSession() {
if (sparkInterpreter == null) {
return null;
} else {
return intp.getSparkSession();
return sparkInterpreter.getSparkSession();
}
}
public SparkConf getSparkConf() throws InterpreterException {
public SparkConf getSparkConf() {
JavaSparkContext sc = getJavaSparkContext();
if (sc == null) {
return null;
} else {
return getJavaSparkContext().getConf();
return sc.getConf();
}
}
public SQLContext getSQLContext() throws InterpreterException {
SparkInterpreter intp = getSparkInterpreter();
if (intp == null) {
public SQLContext getSQLContext() {
if (sparkInterpreter == null) {
return null;
} else {
return intp.getSQLContext();
return sparkInterpreter.getSQLContext();
}
}
@ -718,21 +277,7 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
return (DepInterpreter) p;
}
@Override
public void onProcessComplete(int exitValue) {
pythonscriptRunning = false;
LOGGER.info("python process terminated. exit code " + exitValue);
}
@Override
public void onProcessFailed(ExecuteException e) {
pythonscriptRunning = false;
LOGGER.error("python process failed", e);
}
// Called by Python Process, used for debugging purpose
public void logPythonOutput(String message) {
LOGGER.debug("Python Process Output: " + message);
public boolean isSpark2() {
return sparkInterpreter.getSparkVersion().newerThanEquals(SparkVersion.SPARK_2_0_0);
}
}

View file

@ -15,150 +15,43 @@
# limitations under the License.
#
import os, sys, getopt, traceback, json, re
from py4j.java_gateway import java_import, JavaGateway, GatewayClient
from py4j.protocol import Py4JJavaError
from py4j.java_gateway import java_import
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
import ast
import warnings
# for back compatibility
from pyspark.sql import SQLContext, HiveContext, Row
class Logger(object):
def __init__(self):
pass
def write(self, message):
intp.appendOutput(message)
def reset(self):
pass
def flush(self):
pass
class SparkVersion(object):
SPARK_1_4_0 = 10400
SPARK_1_3_0 = 10300
SPARK_2_0_0 = 20000
def __init__(self, versionNumber):
self.version = versionNumber
def isAutoConvertEnabled(self):
return self.version >= self.SPARK_1_4_0
def isImportAllPackageUnderSparkSql(self):
return self.version >= self.SPARK_1_3_0
def isSpark2(self):
return self.version >= self.SPARK_2_0_0
class PySparkCompletion:
def __init__(self, interpreterObject):
self.interpreterObject = interpreterObject
def getGlobalCompletion(self, text_value):
completions = [completion for completion in list(globals().keys()) if completion.startswith(text_value)]
return completions
def getMethodCompletion(self, objName, methodName):
execResult = locals()
try:
exec("{} = dir({})".format("objectDefList", objName), globals(), execResult)
except:
return None
else:
objectDefList = execResult['objectDefList']
return [completion for completion in execResult['objectDefList'] if completion.startswith(methodName)]
def getCompletion(self, text_value):
if text_value == None:
return None
dotPos = text_value.find(".")
if dotPos == -1:
objName = text_value
completionList = self.getGlobalCompletion(objName)
else:
objName = text_value[:dotPos]
methodName = text_value[dotPos + 1:]
completionList = self.getMethodCompletion(objName, methodName)
if len(completionList) <= 0:
self.interpreterObject.setStatementsFinished("", False)
else:
result = json.dumps(list(filter(lambda x : not re.match("^__.*", x), list(completionList))))
self.interpreterObject.setStatementsFinished(result, False)
client = GatewayClient(port=int(sys.argv[1]))
sparkVersion = SparkVersion(int(sys.argv[2]))
if sparkVersion.isSpark2():
intp = gateway.entry_point
isSpark2 = intp.isSpark2()
if isSpark2:
from pyspark.sql import SparkSession
else:
from pyspark.sql import SchemaRDD
if sparkVersion.isAutoConvertEnabled():
gateway = JavaGateway(client, auto_convert = True)
else:
gateway = JavaGateway(client)
jsc = intp.getJavaSparkContext()
java_import(gateway.jvm, "org.apache.spark.SparkEnv")
java_import(gateway.jvm, "org.apache.spark.SparkConf")
java_import(gateway.jvm, "org.apache.spark.api.java.*")
java_import(gateway.jvm, "org.apache.spark.api.python.*")
java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
intp = gateway.entry_point
output = Logger()
sys.stdout = output
sys.stderr = output
jsc = intp.getJavaSparkContext()
if sparkVersion.isImportAllPackageUnderSparkSql():
java_import(gateway.jvm, "org.apache.spark.sql.*")
java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
else:
java_import(gateway.jvm, "org.apache.spark.sql.SQLContext")
java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext")
java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext")
java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext")
java_import(gateway.jvm, "org.apache.spark.sql.*")
java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
java_import(gateway.jvm, "scala.Tuple2")
_zcUserQueryNameSpace = {}
jconf = intp.getSparkConf()
conf = SparkConf(_jvm = gateway.jvm, _jconf = jconf)
sc = _zsc_ = SparkContext(jsc=jsc, gateway=gateway, conf=conf)
_zcUserQueryNameSpace["_zsc_"] = _zsc_
_zcUserQueryNameSpace["sc"] = sc
if sparkVersion.isSpark2():
if isSpark2:
spark = __zSpark__ = SparkSession(sc, intp.getSparkSession())
sqlc = __zSqlc__ = __zSpark__._wrapped
_zcUserQueryNameSpace["sqlc"] = sqlc
_zcUserQueryNameSpace["__zSqlc__"] = __zSqlc__
_zcUserQueryNameSpace["spark"] = spark
_zcUserQueryNameSpace["__zSpark__"] = __zSpark__
else:
sqlc = __zSqlc__ = SQLContext(sparkContext=sc, sqlContext=intp.getSQLContext())
_zcUserQueryNameSpace["sqlc"] = sqlc
_zcUserQueryNameSpace["__zSqlc__"] = sqlc
sqlContext = __zSqlc__
_zcUserQueryNameSpace["sqlContext"] = sqlContext
completion = __zeppelin_completion__ = PySparkCompletion(intp)
_zcUserQueryNameSpace["completion"] = completion
_zcUserQueryNameSpace["__zeppelin_completion__"] = __zeppelin_completion__
from zeppelin_context import PyZeppelinContext
@ -176,92 +69,4 @@ class PySparkZeppelinContext(PyZeppelinContext):
super(PySparkZeppelinContext, self).show(obj)
z = __zeppelin__ = PySparkZeppelinContext(intp.getZeppelinContext(), gateway)
__zeppelin__._setup_matplotlib()
_zcUserQueryNameSpace["z"] = z
_zcUserQueryNameSpace["__zeppelin__"] = __zeppelin__
intp.onPythonScriptInitialized(os.getpid())
while True :
req = intp.getStatements()
try:
stmts = req.statements().split("\n")
jobGroup = req.jobGroup()
jobDesc = req.jobDescription()
isForCompletion = req.isForCompletion()
# Get post-execute hooks
try:
global_hook = intp.getHook('post_exec_dev')
except:
global_hook = None
try:
user_hook = __zeppelin__.getHook('post_exec')
except:
user_hook = None
nhooks = 0
if not isForCompletion:
for hook in (global_hook, user_hook):
if hook:
nhooks += 1
if stmts:
# use exec mode to compile the statements except the last statement,
# so that the last statement's evaluation will be printed to stdout
sc.setJobGroup(jobGroup, jobDesc)
code = compile('\n'.join(stmts), '<stdin>', 'exec', ast.PyCF_ONLY_AST, 1)
to_run_hooks = []
if (nhooks > 0):
to_run_hooks = code.body[-nhooks:]
to_run_exec, to_run_single = (code.body[:-(nhooks + 1)],
[code.body[-(nhooks + 1)]])
try:
for node in to_run_exec:
mod = ast.Module([node])
code = compile(mod, '<stdin>', 'exec')
exec(code, _zcUserQueryNameSpace)
for node in to_run_single:
mod = ast.Interactive([node])
code = compile(mod, '<stdin>', 'single')
exec(code, _zcUserQueryNameSpace)
for node in to_run_hooks:
mod = ast.Module([node])
code = compile(mod, '<stdin>', 'exec')
exec(code, _zcUserQueryNameSpace)
if not isForCompletion:
# only call it when it is not for code completion. code completion will call it in
# PySparkCompletion.getCompletion
intp.setStatementsFinished("", False)
except Py4JJavaError:
# raise it to outside try except
raise
except:
if not isForCompletion:
exception = traceback.format_exc()
m = re.search("File \"<stdin>\", line (\d+).*", exception)
if m:
line_no = int(m.group(1))
intp.setStatementsFinished(
"Fail to execute line {}: {}\n".format(line_no, stmts[line_no - 1]) + exception, True)
else:
intp.setStatementsFinished(exception, True)
else:
intp.setStatementsFinished("", False)
except Py4JJavaError:
excInnerError = traceback.format_exc() # format_tb() does not return the inner exception
innerErrorStart = excInnerError.find("Py4JJavaError:")
if innerErrorStart > -1:
excInnerError = excInnerError[innerErrorStart:]
intp.setStatementsFinished(excInnerError + str(sys.exc_info()), True)
except:
intp.setStatementsFinished(traceback.format_exc(), True)
output.reset()

View file

@ -27,18 +27,16 @@ import org.apache.zeppelin.interpreter.InterpreterGroup;
import org.apache.zeppelin.interpreter.InterpreterOutput;
import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.interpreter.InterpreterResultMessage;
import org.apache.zeppelin.interpreter.LazyOpenInterpreter;
import org.apache.zeppelin.interpreter.remote.RemoteEventClient;
import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
import org.apache.zeppelin.python.IPythonInterpreterTest;
import org.apache.zeppelin.user.AuthenticationInfo;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Properties;
@ -46,65 +44,72 @@ import java.util.Properties;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
public class IPySparkInterpreterTest {
public class IPySparkInterpreterTest extends IPythonInterpreterTest {
private IPySparkInterpreter iPySparkInterpreter;
private InterpreterGroup intpGroup;
private RemoteEventClient mockRemoteEventClient = mock(RemoteEventClient.class);
@Before
public void setup() throws InterpreterException {
@Override
protected Properties initIntpProperties() {
Properties p = new Properties();
p.setProperty("spark.master", "local[4]");
p.setProperty("master", "local[4]");
p.setProperty("spark.submit.deployMode", "client");
p.setProperty("spark.app.name", "Zeppelin Test");
p.setProperty("zeppelin.spark.useHiveContext", "true");
p.setProperty("zeppelin.spark.useHiveContext", "false");
p.setProperty("zeppelin.spark.maxResult", "3");
p.setProperty("zeppelin.spark.importImplicit", "true");
p.setProperty("zeppelin.spark.useNew", "true");
p.setProperty("zeppelin.pyspark.python", "python");
p.setProperty("zeppelin.dep.localrepo", Files.createTempDir().getAbsolutePath());
p.setProperty("zeppelin.python.gatewayserver_address", "127.0.0.1");
return p;
}
@Override
protected void startInterpreter(Properties properties) throws InterpreterException {
intpGroup = new InterpreterGroup();
intpGroup.put("session_1", new LinkedList<Interpreter>());
intpGroup.put("session_1", new ArrayList<Interpreter>());
SparkInterpreter sparkInterpreter = new SparkInterpreter(p);
LazyOpenInterpreter sparkInterpreter = new LazyOpenInterpreter(
new SparkInterpreter(properties));
intpGroup.get("session_1").add(sparkInterpreter);
sparkInterpreter.setInterpreterGroup(intpGroup);
sparkInterpreter.open();
sparkInterpreter.getZeppelinContext().setEventClient(mockRemoteEventClient);
iPySparkInterpreter = new IPySparkInterpreter(p);
intpGroup.get("session_1").add(iPySparkInterpreter);
iPySparkInterpreter.setInterpreterGroup(intpGroup);
iPySparkInterpreter.open();
sparkInterpreter.getZeppelinContext().setEventClient(mockRemoteEventClient);
LazyOpenInterpreter pySparkInterpreter =
new LazyOpenInterpreter(new PySparkInterpreter(properties));
intpGroup.get("session_1").add(pySparkInterpreter);
pySparkInterpreter.setInterpreterGroup(intpGroup);
interpreter = new LazyOpenInterpreter(new IPySparkInterpreter(properties));
intpGroup.get("session_1").add(interpreter);
interpreter.setInterpreterGroup(intpGroup);
interpreter.open();
}
@After
@Override
public void tearDown() throws InterpreterException {
if (iPySparkInterpreter != null) {
iPySparkInterpreter.close();
}
intpGroup.close();
interpreter = null;
intpGroup = null;
}
@Test
public void testBasics() throws InterruptedException, IOException, InterpreterException {
// all the ipython test should pass too.
IPythonInterpreterTest.testInterpreter(iPySparkInterpreter);
testPySpark(iPySparkInterpreter, mockRemoteEventClient);
public void testIPySpark() throws InterruptedException, InterpreterException, IOException {
testPySpark(interpreter, mockRemoteEventClient);
}
public static void testPySpark(final Interpreter interpreter, RemoteEventClient mockRemoteEventClient)
throws InterpreterException, IOException, InterruptedException {
reset(mockRemoteEventClient);
// rdd
InterpreterContext context = getInterpreterContext(mockRemoteEventClient);
InterpreterContext context = createInterpreterContext(mockRemoteEventClient);
InterpreterResult result = interpreter.interpret("sc.version", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
@ -112,17 +117,17 @@ public class IPySparkInterpreterTest {
// spark url is sent
verify(mockRemoteEventClient).onMetaInfosReceived(any(Map.class));
context = getInterpreterContext(mockRemoteEventClient);
context = createInterpreterContext(mockRemoteEventClient);
result = interpreter.interpret("sc.range(1,10).sum()", context);
Thread.sleep(100);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
List<InterpreterResultMessage> interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals("45", interpreterResultMessages.get(0).getData().trim());
// spark job url is sent
verify(mockRemoteEventClient).onParaInfosReceived(any(String.class), any(String.class), any(Map.class));
// verify(mockRemoteEventClient).onParaInfosReceived(any(String.class), any(String.class), any(Map.class));
// spark sql
context = getInterpreterContext(mockRemoteEventClient);
context = createInterpreterContext(mockRemoteEventClient);
if (!isSpark2(sparkVersion)) {
result = interpreter.interpret("df = sqlContext.createDataFrame([(1,'a'),(2,'b')])\ndf.show()", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
@ -135,7 +140,7 @@ public class IPySparkInterpreterTest {
"| 2| b|\n" +
"+---+---+", interpreterResultMessages.get(0).getData().trim());
context = getInterpreterContext(mockRemoteEventClient);
context = createInterpreterContext(mockRemoteEventClient);
result = interpreter.interpret("z.show(df)", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
@ -155,7 +160,7 @@ public class IPySparkInterpreterTest {
"| 2| b|\n" +
"+---+---+", interpreterResultMessages.get(0).getData().trim());
context = getInterpreterContext(mockRemoteEventClient);
context = createInterpreterContext(mockRemoteEventClient);
result = interpreter.interpret("z.show(df)", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
@ -166,7 +171,7 @@ public class IPySparkInterpreterTest {
}
// cancel
if (interpreter instanceof IPySparkInterpreter) {
final InterpreterContext context2 = getInterpreterContext(mockRemoteEventClient);
final InterpreterContext context2 = createInterpreterContext(mockRemoteEventClient);
Thread thread = new Thread() {
@Override
@ -196,24 +201,24 @@ public class IPySparkInterpreterTest {
}
// completions
List<InterpreterCompletion> completions = interpreter.completion("sc.ran", 6, getInterpreterContext(mockRemoteEventClient));
List<InterpreterCompletion> completions = interpreter.completion("sc.ran", 6, createInterpreterContext(mockRemoteEventClient));
assertEquals(1, completions.size());
assertEquals("range", completions.get(0).getValue());
completions = interpreter.completion("sc.", 3, getInterpreterContext(mockRemoteEventClient));
completions = interpreter.completion("sc.", 3, createInterpreterContext(mockRemoteEventClient));
assertTrue(completions.size() > 0);
completions.contains(new InterpreterCompletion("range", "range", ""));
completions = interpreter.completion("1+1\nsc.", 7, getInterpreterContext(mockRemoteEventClient));
completions = interpreter.completion("1+1\nsc.", 7, createInterpreterContext(mockRemoteEventClient));
assertTrue(completions.size() > 0);
completions.contains(new InterpreterCompletion("range", "range", ""));
completions = interpreter.completion("s", 1, getInterpreterContext(mockRemoteEventClient));
completions = interpreter.completion("s", 1, createInterpreterContext(mockRemoteEventClient));
assertTrue(completions.size() > 0);
completions.contains(new InterpreterCompletion("sc", "sc", ""));
// pyspark streaming
context = getInterpreterContext(mockRemoteEventClient);
context = createInterpreterContext(mockRemoteEventClient);
result = interpreter.interpret(
"from pyspark.streaming import StreamingContext\n" +
"import time\n" +
@ -239,7 +244,7 @@ public class IPySparkInterpreterTest {
return sparkVersion.startsWith("'2.") || sparkVersion.startsWith("u'2.");
}
private static InterpreterContext getInterpreterContext(RemoteEventClient mockRemoteEventClient) {
private static InterpreterContext createInterpreterContext(RemoteEventClient mockRemoteEventClient) {
InterpreterContext context = new InterpreterContext(
"noteId",
"paragraphId",

View file

@ -127,7 +127,7 @@ public class OldSparkInterpreterTest {
new LocalResourcePool("id"),
new LinkedList<InterpreterContextRunner>(),
new InterpreterOutput(null)) {
@Override
public RemoteEventClientWrapper getClient() {
return remoteEventClientWrapper;
@ -192,7 +192,7 @@ public class OldSparkInterpreterTest {
public void testEndWithComment() throws InterpreterException {
assertEquals(InterpreterResult.Code.SUCCESS, repl.interpret("val c=1\n//comment", context).code());
}
@Test
public void testCreateDataFrame() throws InterpreterException {
if (getSparkVersionNumber(repl) >= 13) {

View file

@ -17,154 +17,73 @@
package org.apache.zeppelin.spark;
import org.apache.zeppelin.display.AngularObjectRegistry;
import org.apache.zeppelin.display.GUI;
import org.apache.zeppelin.interpreter.*;
import com.google.common.io.Files;
import org.apache.zeppelin.interpreter.Interpreter;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.InterpreterException;
import org.apache.zeppelin.interpreter.InterpreterGroup;
import org.apache.zeppelin.interpreter.LazyOpenInterpreter;
import org.apache.zeppelin.interpreter.remote.RemoteEventClient;
import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
import org.apache.zeppelin.resource.LocalResourcePool;
import org.apache.zeppelin.user.AuthenticationInfo;
import org.junit.*;
import org.junit.rules.TemporaryFolder;
import org.junit.runners.MethodSorters;
import org.apache.zeppelin.python.PythonInterpreterTest;
import org.junit.Test;
import java.io.IOException;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static org.junit.Assert.*;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@FixMethodOrder(MethodSorters.NAME_ASCENDING)
public class PySparkInterpreterTest {
public class PySparkInterpreterTest extends PythonInterpreterTest {
@ClassRule
public static TemporaryFolder tmpDir = new TemporaryFolder();
static SparkInterpreter sparkInterpreter;
static PySparkInterpreter pySparkInterpreter;
static InterpreterGroup intpGroup;
static InterpreterContext context;
private RemoteEventClient mockRemoteEventClient = mock(RemoteEventClient.class);
private static Properties getPySparkTestProperties() throws IOException {
Properties p = new Properties();
p.setProperty("spark.master", "local");
p.setProperty("spark.app.name", "Zeppelin Test");
p.setProperty("zeppelin.spark.useHiveContext", "true");
p.setProperty("zeppelin.spark.maxResult", "1000");
p.setProperty("zeppelin.spark.importImplicit", "true");
p.setProperty("zeppelin.pyspark.python", "python");
p.setProperty("zeppelin.dep.localrepo", tmpDir.newFolder().getAbsolutePath());
p.setProperty("zeppelin.pyspark.useIPython", "false");
p.setProperty("zeppelin.spark.test", "true");
return p;
}
@Override
public void setUp() throws InterpreterException {
Properties properties = new Properties();
properties.setProperty("spark.master", "local");
properties.setProperty("spark.app.name", "Zeppelin Test");
properties.setProperty("zeppelin.spark.useHiveContext", "false");
properties.setProperty("zeppelin.spark.maxResult", "3");
properties.setProperty("zeppelin.spark.importImplicit", "true");
properties.setProperty("zeppelin.pyspark.python", "python");
properties.setProperty("zeppelin.dep.localrepo", Files.createTempDir().getAbsolutePath());
properties.setProperty("zeppelin.pyspark.useIPython", "false");
properties.setProperty("zeppelin.spark.useNew", "true");
properties.setProperty("zeppelin.spark.test", "true");
properties.setProperty("zeppelin.python.gatewayserver_address", "127.0.0.1");
/**
* Get spark version number as a numerical value.
* eg. 1.1.x => 11, 1.2.x => 12, 1.3.x => 13 ...
*/
public static int getSparkVersionNumber() {
if (sparkInterpreter == null) {
return 0;
}
String[] split = sparkInterpreter.getSparkContext().version().split("\\.");
int version = Integer.parseInt(split[0]) * 10 + Integer.parseInt(split[1]);
return version;
}
@BeforeClass
public static void setUp() throws Exception {
InterpreterContext.set(getInterpreterContext(mockRemoteEventClient));
// create interpreter group
intpGroup = new InterpreterGroup();
intpGroup.put("note", new LinkedList<Interpreter>());
context = new InterpreterContext("note", "id", null, "title", "text",
new AuthenticationInfo(),
new HashMap<String, Object>(),
new GUI(),
new GUI(),
new AngularObjectRegistry(intpGroup.getId(), null),
new LocalResourcePool("id"),
new LinkedList<InterpreterContextRunner>(),
new InterpreterOutput(null));
InterpreterContext.set(context);
sparkInterpreter = new SparkInterpreter(getPySparkTestProperties());
LazyOpenInterpreter sparkInterpreter =
new LazyOpenInterpreter(new SparkInterpreter(properties));
intpGroup.get("note").add(sparkInterpreter);
sparkInterpreter.setInterpreterGroup(intpGroup);
sparkInterpreter.open();
pySparkInterpreter = new PySparkInterpreter(getPySparkTestProperties());
intpGroup.get("note").add(pySparkInterpreter);
pySparkInterpreter.setInterpreterGroup(intpGroup);
pySparkInterpreter.open();
LazyOpenInterpreter iPySparkInterpreter =
new LazyOpenInterpreter(new IPySparkInterpreter(properties));
intpGroup.get("note").add(iPySparkInterpreter);
iPySparkInterpreter.setInterpreterGroup(intpGroup);
interpreter = new LazyOpenInterpreter(new PySparkInterpreter(properties));
intpGroup.get("note").add(interpreter);
interpreter.setInterpreterGroup(intpGroup);
interpreter.open();
}
@AfterClass
public static void tearDown() throws InterpreterException {
pySparkInterpreter.close();
sparkInterpreter.close();
@Override
public void tearDown() throws InterpreterException {
intpGroup.close();
intpGroup = null;
interpreter = null;
}
@Test
public void testBasicIntp() throws InterpreterException, InterruptedException, IOException {
IPySparkInterpreterTest.testPySpark(pySparkInterpreter, mockRemoteEventClient);
public void testPySpark() throws InterruptedException, InterpreterException, IOException {
IPySparkInterpreterTest.testPySpark(interpreter, mockRemoteEventClient);
}
@Test
public void testRedefinitionZeppelinContext() throws InterpreterException {
if (getSparkVersionNumber() > 11) {
String redefinitionCode = "z = 1\n";
String restoreCode = "z = __zeppelin__\n";
String validCode = "z.input(\"test\")\n";
assertEquals(InterpreterResult.Code.SUCCESS, pySparkInterpreter.interpret(validCode, context).code());
assertEquals(InterpreterResult.Code.SUCCESS, pySparkInterpreter.interpret(redefinitionCode, context).code());
assertEquals(InterpreterResult.Code.ERROR, pySparkInterpreter.interpret(validCode, context).code());
assertEquals(InterpreterResult.Code.SUCCESS, pySparkInterpreter.interpret(restoreCode, context).code());
assertEquals(InterpreterResult.Code.SUCCESS, pySparkInterpreter.interpret(validCode, context).code());
}
}
private class infinityPythonJob implements Runnable {
@Override
public void run() {
String code = "import time\nwhile True:\n time.sleep(1)" ;
InterpreterResult ret = null;
try {
ret = pySparkInterpreter.interpret(code, context);
} catch (InterpreterException e) {
e.printStackTrace();
}
assertNotNull(ret);
Pattern expectedMessage = Pattern.compile("KeyboardInterrupt");
Matcher m = expectedMessage.matcher(ret.message().toString());
assertTrue(m.find());
}
}
@Test
public void testCancelIntp() throws InterruptedException, InterpreterException {
if (getSparkVersionNumber() > 11) {
assertEquals(InterpreterResult.Code.SUCCESS,
pySparkInterpreter.interpret("a = 1\n", context).code());
Thread t = new Thread(new infinityPythonJob());
t.start();
Thread.sleep(5000);
pySparkInterpreter.cancel(context);
assertTrue(t.isAlive());
t.join(2000);
assertFalse(t.isAlive());
}
}
}

View file

@ -26,6 +26,8 @@ import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.interpreter.LazyOpenInterpreter;
import org.apache.zeppelin.interpreter.remote.RemoteEventClient;
import org.apache.zeppelin.user.AuthenticationInfo;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
@ -47,8 +49,8 @@ public class SparkRInterpreterTest {
private SparkInterpreter sparkInterpreter;
private RemoteEventClient mockRemoteEventClient = mock(RemoteEventClient.class);
@Test
public void testSparkRInterpreter() throws InterpreterException, InterruptedException {
@Before
public void setUp() throws InterpreterException {
Properties properties = new Properties();
properties.setProperty("spark.master", "local");
properties.setProperty("spark.app.name", "test");
@ -69,6 +71,16 @@ public class SparkRInterpreterTest {
sparkRInterpreter.open();
sparkInterpreter.getZeppelinContext().setEventClient(mockRemoteEventClient);
}
@After
public void tearDown() throws InterpreterException {
sparkInterpreter.close();
}
@Test
public void testSparkRInterpreter() throws InterpreterException, InterruptedException {
InterpreterResult result = sparkRInterpreter.interpret("1+1", getInterpreterContext());
assertEquals(InterpreterResult.Code.SUCCESS, result.code());

View file

@ -43,9 +43,9 @@ log4j.logger.DataNucleus.Datastore=ERROR
# Log all JDBC parameters
log4j.logger.org.hibernate.type=ALL
log4j.logger.org.apache.zeppelin.interpreter=DEBUG
log4j.logger.org.apache.zeppelin.spark=DEBUG
log4j.logger.org.apache.zeppelin.interpreter=WARN
log4j.logger.org.apache.zeppelin.spark=INFO
log4j.logger.org.apache.zeppelin.python=DEBUG
log4j.logger.org.apache.spark.repl.Main=INFO
log4j.logger.org.apache.spark.repl.Main=WARN

View file

@ -161,4 +161,17 @@ public class InterpreterGroup {
public int hashCode() {
return id != null ? id.hashCode() : 0;
}
public void close() {
for (List<Interpreter> session : sessions.values()) {
for (Interpreter interpreter : session) {
try {
interpreter.close();
} catch (InterpreterException e) {
LOGGER.warn("Fail to close interpreter: " + interpreter.getClassName(), e);
}
}
}
sessions.clear();
}
}

View file

@ -31,9 +31,11 @@ import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.zeppelin.conf.ZeppelinConfiguration;
import org.apache.zeppelin.display.AngularObject;
@ -54,8 +56,16 @@ import org.apache.zeppelin.user.AuthenticationInfo;
*/
@RunWith(value = Parameterized.class)
public class ZeppelinSparkClusterTest extends AbstractTestRestApi {
private static final Logger LOGGER = LoggerFactory.getLogger(ZeppelinSparkClusterTest.class);
//This is for only run setupSparkInterpreter one time for each spark version, otherwise
//each test method will run setupSparkInterpreter which will cost a long time and may cause travis
//ci timeout.
//TODO(zjffdu) remove this after we upgrade it to junit 4.13 (ZEPPELIN-3341)
private static Set<String> verifiedSparkVersions = new HashSet<>();
private String sparkVersion;
private AuthenticationInfo anonymous = new AuthenticationInfo("anonymous");
@ -63,8 +73,11 @@ public class ZeppelinSparkClusterTest extends AbstractTestRestApi {
this.sparkVersion = sparkVersion;
LOGGER.info("Testing SparkVersion: " + sparkVersion);
String sparkHome = SparkDownloadUtils.downloadSpark(sparkVersion);
setupSparkInterpreter(sparkHome);
verifySparkVersionNumber();
if (!verifiedSparkVersions.contains(sparkVersion)) {
verifiedSparkVersions.add(sparkVersion);
setupSparkInterpreter(sparkHome);
verifySparkVersionNumber();
}
}
@Parameterized.Parameters

View file

@ -520,7 +520,8 @@ public class InterpreterSetting {
Map<String, InterpreterProperty> iProperties = (Map<String, InterpreterProperty>) properties;
for (Map.Entry<String, InterpreterProperty> entry : iProperties.entrySet()) {
if (entry.getValue().getValue() != null) {
jProperties.setProperty(entry.getKey().trim(), entry.getValue().getValue().toString().trim());
jProperties.setProperty(entry.getKey().trim(),
entry.getValue().getValue().toString().trim());
}
}