Make python docker interpreter work using py4j

This commit is contained in:
Lee moon soo 2017-03-16 23:10:09 -07:00
parent 8a016c934d
commit 9fcf1446e3
4 changed files with 353 additions and 34 deletions

View file

@ -0,0 +1,200 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.zeppelin.python;
import org.apache.zeppelin.interpreter.*;
import org.apache.zeppelin.scheduler.Scheduler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.*;
import java.nio.file.Paths;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
* Helps run python interpreter on a docker container
*/
public class PythonDockerInterpreter extends Interpreter {
Logger logger = LoggerFactory.getLogger(PythonDockerInterpreter.class);
Pattern activatePattern = Pattern.compile("activate\\s*(.*)");
Pattern deactivatePattern = Pattern.compile("deactivate");
Pattern helpPattern = Pattern.compile("help");
private File zeppelinHome;
public PythonDockerInterpreter(Properties property) {
super(property);
}
@Override
public void open() {
if (System.getenv("ZEPPELIN_HOME") != null) {
zeppelinHome = new File(System.getenv("ZEPPELIN_HOME"));
} else {
zeppelinHome = Paths.get("..").toAbsolutePath().toFile();
}
}
@Override
public void close() {
}
@Override
public InterpreterResult interpret(String st, InterpreterContext context) {
File pythonScript = new File(getPythonInterpreter().getScriptPath());
InterpreterOutput out = context.out;
Matcher activateMatcher = activatePattern.matcher(st);
Matcher deactivateMatcher = deactivatePattern.matcher(st);
Matcher helpMatcher = helpPattern.matcher(st);
if (st == null || st.isEmpty() || helpMatcher.matches()) {
printUsage(out);
return new InterpreterResult(InterpreterResult.Code.SUCCESS);
} else if (activateMatcher.matches()) {
String image = activateMatcher.group(1);
pull(out, image);
// mount pythonscript dir
String mountPythonScript = "-v " +
pythonScript.getParentFile().getAbsolutePath() +
":/_zeppelin_tmp ";
// mount zeppelin dir
String mountPy4j = "-v " +
zeppelinHome.getAbsolutePath() +
":/_zeppelin ";
// set PYTHONPATH
String pythonPath = ":/_zeppelin/" + PythonInterpreter.ZEPPELIN_PY4JPATH + ":" +
":/_zeppelin/" + PythonInterpreter.ZEPPELIN_PYTHON_LIBS;
setPythonCommand("docker run -i --rm " +
mountPythonScript +
mountPy4j +
"-e PYTHONPATH=\"" + pythonPath + "\" " +
image +
" python /_zeppelin_tmp/" + pythonScript.getName());
restartPythonProcess();
out.clear();
return new InterpreterResult(InterpreterResult.Code.SUCCESS, "\"" + image + "\" activated");
} else if (deactivateMatcher.matches()) {
setPythonCommand(null);
restartPythonProcess();
return new InterpreterResult(InterpreterResult.Code.SUCCESS, "Deactivated");
} else {
return new InterpreterResult(InterpreterResult.Code.ERROR, "Not supported command: " + st);
}
}
public void setPythonCommand(String cmd) {
PythonInterpreter python = getPythonInterpreter();
python.setPythonCommand(cmd);
}
private void printUsage(InterpreterOutput out) {
try {
out.setType(InterpreterResult.Type.HTML);
out.writeResource("output_templates/docker_usage.html");
} catch (IOException e) {
logger.error("Can't print usage", e);
}
}
@Override
public void cancel(InterpreterContext context) {
}
@Override
public FormType getFormType() {
return FormType.NONE;
}
@Override
public int getProgress(InterpreterContext context) {
return 0;
}
/**
* Use python interpreter's scheduler.
* To make sure %python.docker paragraph and %python paragraph runs sequentially
*/
@Override
public Scheduler getScheduler() {
PythonInterpreter pythonInterpreter = getPythonInterpreter();
if (pythonInterpreter != null) {
return pythonInterpreter.getScheduler();
} else {
return null;
}
}
private void restartPythonProcess() {
PythonInterpreter python = getPythonInterpreter();
python.close();
python.open();
}
protected PythonInterpreter getPythonInterpreter() {
LazyOpenInterpreter lazy = null;
PythonInterpreter python = null;
Interpreter p = getInterpreterInTheSameSessionByClassName(PythonInterpreter.class.getName());
while (p instanceof WrappedInterpreter) {
if (p instanceof LazyOpenInterpreter) {
lazy = (LazyOpenInterpreter) p;
}
p = ((WrappedInterpreter) p).getInnerInterpreter();
}
python = (PythonInterpreter) p;
if (lazy != null) {
lazy.open();
}
return python;
}
public boolean pull(InterpreterOutput out, String image) {
int exit = 0;
try {
exit = runCommand(out, "docker", "pull", image);
} catch (IOException | InterruptedException e) {
logger.error(e.getMessage(), e);
throw new InterpreterException(e);
}
return exit == 0;
}
protected int runCommand(InterpreterOutput out, String... command)
throws IOException, InterruptedException {
ProcessBuilder builder = new ProcessBuilder(command);
builder.redirectErrorStream(true);
Process process = builder.start();
InputStream stdout = process.getInputStream();
BufferedReader br = new BufferedReader(new InputStreamReader(stdout));
String line;
while ((line = br.readLine()) != null) {
out.write(line + "\n");
}
int r = process.waitFor(); // Let the process finish.
return r;
}
}

View file

@ -27,9 +27,7 @@ import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.net.ServerSocket;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.*;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collection;
@ -59,6 +57,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import py4j.GatewayServer;
import py4j.commands.Command;
/**
* Python interpreter for Zeppelin.
@ -78,7 +77,7 @@ public class PythonInterpreter extends Interpreter implements ExecuteResultHandl
private String py4jLibPath;
private String pythonLibPath;
private String pythonCommand = DEFAULT_ZEPPELIN_PYTHON;
private String pythonCommand;
private GatewayServer gatewayServer;
private DefaultExecutor executor;
@ -95,11 +94,10 @@ public class PythonInterpreter extends Interpreter implements ExecuteResultHandl
Integer statementSetNotifier = new Integer(0);
public PythonInterpreter(Properties property) {
super(property);
try {
File scriptFile = File.createTempFile("zeppelin_python-", ".py");
File scriptFile = File.createTempFile("zeppelin_python-", ".py", new File("/tmp"));
scriptPath = scriptFile.getAbsolutePath();
} catch (IOException e) {
throw new InterpreterException(e);
@ -128,6 +126,10 @@ public class PythonInterpreter extends Interpreter implements ExecuteResultHandl
logger.info("File {} created", scriptPath);
}
/**
 * Exposes the location of the python script file of this interpreter.
 *
 * @return the value of the {@code scriptPath} field
 */
public String getScriptPath() {
  return this.scriptPath;
}
private void copyFile(File out, String sourceFile) {
ClassLoader classLoader = getClass().getClassLoader();
try {
@ -141,7 +143,7 @@ public class PythonInterpreter extends Interpreter implements ExecuteResultHandl
}
}
private void createGatewayServerAndStartScript() {
private void createGatewayServerAndStartScript() throws UnknownHostException {
createPythonScript();
if (System.getenv("ZEPPELIN_HOME") != null) {
py4jLibPath = System.getenv("ZEPPELIN_HOME") + File.separator + ZEPPELIN_PY4JPATH;
@ -153,13 +155,28 @@ public class PythonInterpreter extends Interpreter implements ExecuteResultHandl
}
port = findRandomOpenPortOnAllLocalInterfaces();
gatewayServer = new GatewayServer(this, port);
gatewayServer = new GatewayServer(this,
port,
GatewayServer.DEFAULT_PYTHON_PORT,
InetAddress.getByName("0.0.0.0"),
InetAddress.getByName("0.0.0.0"),
GatewayServer.DEFAULT_CONNECT_TIMEOUT,
GatewayServer.DEFAULT_READ_TIMEOUT,
(List) null);
gatewayServer.start();
// Run python shell
CommandLine cmd = CommandLine.parse(getPythonCommand());
cmd.addArgument(scriptPath, false);
String pythonCmd = getPythonCommand();
CommandLine cmd = CommandLine.parse(pythonCmd);
if (!pythonCmd.endsWith(".py")) {
// PythonDockerInterpreter set pythoncmd with script
cmd.addArgument(getScriptPath(), false);
}
cmd.addArgument(Integer.toString(port), false);
cmd.addArgument(getLocalIp(), false);
executor = new DefaultExecutor();
outputStream = new InterpreterOutputStream(logger);
PipedOutputStream ps = new PipedOutputStream();
@ -185,6 +202,7 @@ public class PythonInterpreter extends Interpreter implements ExecuteResultHandl
py4jLibPath + File.pathSeparator + pythonLibPath);
}
logger.info("cmd = {}", cmd.toString());
executor.execute(cmd, env, this);
pythonscriptRunning = true;
} catch (IOException e) {
@ -207,7 +225,11 @@ public class PythonInterpreter extends Interpreter implements ExecuteResultHandl
registerHook(HookType.POST_EXEC_DEV, "z._displayhook()");
}
// Add matplotlib display hook
createGatewayServerAndStartScript();
try {
createGatewayServerAndStartScript();
} catch (UnknownHostException e) {
throw new InterpreterException(e);
}
}
@Override
@ -244,25 +266,18 @@ public class PythonInterpreter extends Interpreter implements ExecuteResultHandl
*/
public class PythonInterpretRequest {
public String statements;
public String jobGroup;
public PythonInterpretRequest(String statements, String jobGroup) {
public PythonInterpretRequest(String statements) {
this.statements = statements;
this.jobGroup = jobGroup;
}
public String statements() {
return statements;
}
public String jobGroup() {
return jobGroup;
}
}
public PythonInterpretRequest getStatements() {
synchronized (statementSetNotifier) {
while (pythonInterpretRequest == null && pythonscriptRunning && pythonScriptInitialized) {
try {
statementSetNotifier.wait(1000);
@ -350,7 +365,7 @@ public class PythonInterpreter extends Interpreter implements ExecuteResultHandl
return new InterpreterResult(Code.ERROR, errorMessage);
}
pythonInterpretRequest = new PythonInterpretRequest(cmd, null);
pythonInterpretRequest = new PythonInterpretRequest(cmd);
statementOutput = null;
synchronized (statementSetNotifier) {
@ -420,16 +435,17 @@ public class PythonInterpreter extends Interpreter implements ExecuteResultHandl
return null;
}
public void setPythonPath(String pythonPath) {
this.pythonPath = pythonPath;
}
/**
 * Logs and stores the command line used to start the python process.
 *
 * @param cmd command line to remember; the value is stored as given
 */
public void setPythonCommand(String cmd) {
  logger.info("Set Python Command : {}", cmd);
  this.pythonCommand = cmd;
}
public String getPythonCommand() {
return pythonCommand;
if (pythonCommand == null) {
return DEFAULT_ZEPPELIN_PYTHON;
} else {
return pythonCommand;
}
}
private Job getRunningJob(String paragraphId) {
@ -462,8 +478,14 @@ public class PythonInterpreter extends Interpreter implements ExecuteResultHandl
return context.getGui();
}
public Integer getPy4jPort() {
return port;
/**
 * Determines the textual IP address of the local host.
 *
 * <p>{@code getLocalHost()} is the static method inherited from {@code InetAddress},
 * so calling it through {@code Inet4Address} does not restrict the result to IPv4.
 *
 * @return the local host address, or {@code "127.0.0.1"} when it cannot be resolved
 */
String getLocalIp() {
  String localIp;
  try {
    localIp = Inet4Address.getLocalHost().getHostAddress();
  } catch (UnknownHostException e) {
    logger.error("can't get local IP", e);
    // fall back to loopback address
    localIp = "127.0.0.1";
  }
  return localIp;
}
private int findRandomOpenPortOnAllLocalInterfaces() {

View file

@ -18,7 +18,7 @@
import os, sys, getopt, traceback, json, re
from py4j.java_gateway import java_import, JavaGateway, GatewayClient
from py4j.protocol import Py4JJavaError
from py4j.protocol import Py4JJavaError, Py4JNetworkError
import warnings
import ast
import traceback
@ -175,11 +175,11 @@ def handler_stop_signals(sig, frame):
signal.signal(signal.SIGINT, handler_stop_signals)
output = Logger()
sys.stdout = output
sys.stderr = output
host = "127.0.0.1"
if len(sys.argv) >= 3:
host = sys.argv[2]
client = GatewayClient(port=int(sys.argv[1]))
client = GatewayClient(address=host, port=int(sys.argv[1]))
#gateway = JavaGateway(client, auto_convert = True)
gateway = JavaGateway(client)
@ -190,11 +190,17 @@ intp.onPythonScriptInitialized(os.getpid())
z = PyZeppelinContext()
z._setup_matplotlib()
output = Logger()
sys.stdout = output
#sys.stderr = output
while True :
req = intp.getStatements()
if req == None:
break
try:
stmts = req.statements().split("\n")
jobGroup = req.jobGroup()
final_code = []
# Get post-execute hooks
@ -227,7 +233,6 @@ while True :
if final_code:
# use exec mode to compile the statements except the last statement,
# so that the last statement's evaluation will be printed to stdout
#sc.setJobGroup(jobGroup, "Zeppelin")
code = compile('\n'.join(final_code), '<stdin>', 'exec', ast.PyCF_ONLY_AST, 1)
to_run_hooks = []
@ -262,6 +267,9 @@ while True :
if innerErrorStart > -1:
excInnerError = excInnerError[innerErrorStart:]
intp.setStatementsFinished(excInnerError + str(sys.exc_info()), True)
except Py4JNetworkError:
# lost connection from gateway server. exit
sys.exit(1)
except:
intp.setStatementsFinished(traceback.format_exc(), True)

View file

@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.zeppelin.python;
import org.apache.zeppelin.display.GUI;
import org.apache.zeppelin.interpreter.*;
import org.apache.zeppelin.user.AuthenticationInfo;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.net.Inet4Address;
import java.net.UnknownHostException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Properties;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.*;
/**
 * Unit tests for {@link PythonDockerInterpreter}. The python interpreter is mocked and
 * {@code docker pull} is stubbed out, so no docker installation is needed.
 */
public class PythonDockerInterpreterTest {
  private PythonDockerInterpreter docker;
  private PythonInterpreter python;

  @Before
  public void setUp() {
    docker = spy(new PythonDockerInterpreter(new Properties()));
    python = mock(PythonInterpreter.class);

    InterpreterGroup group = new InterpreterGroup();
    group.put("note", Arrays.asList(python, docker));
    python.setInterpreterGroup(group);
    docker.setInterpreterGroup(group);

    doReturn(true).when(docker).pull(any(InterpreterOutput.class), anyString());
    doReturn(python).when(docker).getPythonInterpreter();
    // interpret() builds the mount point and script name from this path;
    // an unstubbed mock would return null here
    doReturn("/scriptpath/zeppelin_python.py").when(python).getScriptPath();
    // open() resolves the zeppelin home directory that gets mounted into the container
    docker.open();
  }

  @Test
  public void testActivateEnv() {
    InterpreterContext context = getInterpreterContext();
    docker.interpret("activate env", context);
    verify(python, times(1)).open();
    verify(python, times(1)).close();
    verify(docker, times(1)).pull(any(InterpreterOutput.class), anyString());
    // mount paths and PYTHONPATH depend on the environment, so only the fixed
    // prefix, the image name and the in-container script path are pinned
    verify(python).setPythonCommand(
        matches("docker run -i --rm -v .* env python /_zeppelin_tmp/zeppelin_python\\.py"));
  }

  @Test
  public void testDeactivate() {
    InterpreterContext context = getInterpreterContext();
    docker.interpret("deactivate", context);
    verify(python, times(1)).open();
    verify(python, times(1)).close();
    verify(python).setPythonCommand(null);
  }

  /** Builds a minimal interpreter context whose output is not attached to any listener. */
  private InterpreterContext getInterpreterContext() {
    return new InterpreterContext(
        "noteId",
        "paragraphId",
        "replName",
        "paragraphTitle",
        "paragraphText",
        new AuthenticationInfo(),
        new HashMap<String, Object>(),
        new GUI(),
        null,
        null,
        null,
        new InterpreterOutput(null));
  }
}