Implemented user-defined hook registry system for spark/pyspark interpreters

This commit is contained in:
Alex Goodman 2016-09-28 17:28:50 -07:00
parent 8fad936744
commit 07cac65e99
6 changed files with 177 additions and 6 deletions

View file

@ -49,6 +49,7 @@ import org.apache.spark.ui.jobs.JobProgressListener;
import org.apache.zeppelin.interpreter.Interpreter;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.InterpreterException;
import org.apache.zeppelin.interpreter.InterpreterHookRegistry;
import org.apache.zeppelin.interpreter.InterpreterProperty;
import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.interpreter.InterpreterResult.Code;
@ -101,6 +102,7 @@ public class SparkInterpreter extends Interpreter {
private SparkConf conf;
private static SparkContext sc;
private static SQLContext sqlc;
private static InterpreterHookRegistry hooks;
private static SparkEnv env;
private static Object sparkSession; // spark 2.x
private static JobProgressListener sparkListener;
@ -778,8 +780,10 @@ public class SparkInterpreter extends Interpreter {
sqlc = getSQLContext();
dep = getDependencyResolver();
hooks = getInterpreterGroup().getInterpreterHookRegistry();
z = new ZeppelinContext(sc, sqlc, null, dep,
z = new ZeppelinContext(sc, sqlc, null, dep, hooks,
Integer.parseInt(getProperty("zeppelin.spark.maxResult")));
interpret("@transient val _binder = new java.util.HashMap[String, Object]()");

View file

@ -28,11 +28,14 @@ import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
import org.apache.spark.SparkContext;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.catalyst.expressions.Attribute;
import org.apache.zeppelin.annotation.ZeppelinApi;
import org.apache.zeppelin.annotation.Experimental;
import org.apache.zeppelin.display.AngularObject;
import org.apache.zeppelin.display.AngularObjectRegistry;
import org.apache.zeppelin.display.AngularObjectWatcher;
@ -41,6 +44,7 @@ import org.apache.zeppelin.display.Input.ParamOption;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.InterpreterContextRunner;
import org.apache.zeppelin.interpreter.InterpreterException;
import org.apache.zeppelin.interpreter.InterpreterHookRegistry;
import org.apache.zeppelin.spark.dep.SparkDependencyResolver;
import org.apache.zeppelin.resource.Resource;
import org.apache.zeppelin.resource.ResourcePool;
@ -53,19 +57,33 @@ import scala.Unit;
* Spark context for zeppelin.
*/
public class ZeppelinContext {
// Map interpreter class name (to be used by hook registry) from
// given replName in parapgraph
private static final Map<String, String> interpreterClassMap;
static {
interpreterClassMap = new HashMap<String, String>();
interpreterClassMap.put("spark", "org.apache.zeppelin.spark.SparkInterpreter");
interpreterClassMap.put("sql", "org.apache.zeppelin.spark.SparkSqlInterpreter");
interpreterClassMap.put("dep", "org.apache.zeppelin.spark.DepInterpreter");
interpreterClassMap.put("pyspark", "org.apache.zeppelin.spark.PySparkInterpreter");
}
private SparkDependencyResolver dep;
private InterpreterContext interpreterContext;
private int maxResult;
private List<Class> supportedClasses;
private InterpreterHookRegistry hooks;
public ZeppelinContext(SparkContext sc, SQLContext sql,
InterpreterContext interpreterContext,
SparkDependencyResolver dep,
InterpreterHookRegistry hooks,
int maxResult) {
this.sc = sc;
this.sqlContext = sql;
this.interpreterContext = interpreterContext;
this.dep = dep;
this.hooks = hooks;
this.maxResult = maxResult;
this.supportedClasses = new ArrayList<>();
try {
@ -697,6 +715,84 @@ public class ZeppelinContext {
registry.remove(name, noteId, null);
}
/**
* Get the interpreter class name from repl name entered in paragraph
* @param replName
*/
public String getClassNameFromReplName(String replName) {
if (replName.contains("spark.")) {
replName = replName.replace("spark.", "");
}
return interpreterClassMap.get(replName);
}
/**
* General function to register hook event
* @param event The type of event to hook to (pre_exec, post_exec)
* @param cmd The code to be executed by the interpreter on given event
* @param replName Name of the interpreter
*/
@Experimental
public void registerHook(String event, String cmd, String replName) {
String noteId = interpreterContext.getNoteId();
String className = getClassNameFromReplName(replName);
hooks.register(noteId, className, event, cmd);
}
/**
* registerHook() wrapper for current repl
* @param event The type of event to hook to (pre_exec, post_exec)
* @param cmd The code to be executed by the interpreter on given event
*/
@Experimental
public void registerHook(String event, String cmd) {
String replName = interpreterContext.getRequiredReplName();
registerHook(event, cmd, replName);
}
/**
* Get the hook code
* @param event The type of event to hook to (pre_exec, post_exec)
* @param replName Name of the interpreter
*/
@Experimental
public String getHook(String event, String replName) {
String noteId = interpreterContext.getNoteId();
String className = getClassNameFromReplName(replName);
return hooks.get(noteId, className, event);
}
/**
* getHook() wrapper for current repl
* @param event The type of event to hook to (pre_exec, post_exec)
*/
@Experimental
public String getHook(String event) {
String replName = interpreterContext.getRequiredReplName();
return getHook(event, replName);
}
/**
* Unbind code from given hook event
* @param event The type of event to hook to (pre_exec, post_exec)
* @param replName Name of the interpreter
*/
@Experimental
public void unregisterHook(String event, String replName) {
String noteId = interpreterContext.getNoteId();
String className = getClassNameFromReplName(replName);
hooks.unregister(noteId, className, event);
}
/**
* unregisterHook() wrapper for current repl
* @param event The type of event to hook to (pre_exec, post_exec)
*/
@Experimental
public void unregisterHook(String event) {
String replName = interpreterContext.getRequiredReplName();
unregisterHook(event, replName);
}
/**
* Add object into resource pool

View file

@ -80,16 +80,16 @@ class PyZeppelinContext(dict):
def get(self, key):
return self.__getitem__(key)
def input(self, name, defaultValue = ""):
def input(self, name, defaultValue=""):
return self.z.input(name, defaultValue)
def select(self, name, options, defaultValue = ""):
def select(self, name, options, defaultValue=""):
# auto_convert to ArrayList doesn't match the method signature on JVM side
tuples = list(map(lambda items: self.__tupleToScalaTuple2(items), options))
iterables = gateway.jvm.scala.collection.JavaConversions.collectionAsScalaIterable(tuples)
return self.z.select(name, defaultValue, iterables)
def checkbox(self, name, options, defaultChecked = None):
def checkbox(self, name, options, defaultChecked=None):
if defaultChecked is None:
defaultChecked = list(map(lambda items: items[0], options))
optionTuples = list(map(lambda items: self.__tupleToScalaTuple2(items), options))
@ -99,6 +99,23 @@ class PyZeppelinContext(dict):
checkedIterables = self.z.checkbox(name, defaultCheckedIterables, optionIterables)
return gateway.jvm.scala.collection.JavaConversions.asJavaCollection(checkedIterables)
def registerHook(self, event, cmd, replName=None):
if replName is None:
self.z.registerHook(event, cmd)
else:
self.z.registerHook(event, cmd, replName)
def unregisterHook(self, event, replName=None):
if replName is None:
self.z.unregisterHook(event)
else:
self.z.unregisterHook(event, replName)
def getHook(self, event, replName=None):
if replName is None:
return self.z.getHook(event)
return self.z.getHook(event, replName)
def __tupleToScalaTuple2(self, tuple):
if (len(tuple) == 2):
return gateway.jvm.scala.Tuple2(tuple[0], tuple[1])

View file

@ -124,4 +124,28 @@ public class InterpreterContext {
return runners;
}
public String getRequiredReplName() {
if (paragraphText == null) {
return null;
}
// get script head
int scriptHeadIndex = 0;
for (int i = 0; i < paragraphText.length(); i++) {
char ch = paragraphText.charAt(i);
if (Character.isWhitespace(ch) || ch == '(') {
scriptHeadIndex = i;
break;
}
}
if (scriptHeadIndex == 0) {
return null;
}
String head = paragraphText.substring(0, scriptHeadIndex);
if (head.startsWith("%")) {
return head.substring(1);
} else {
return null;
}
}
}

View file

@ -27,7 +27,7 @@ public interface InterpreterHookListener {
public void onPreExecute(String script);
/**
* Prepends pre-execute hook code to the script that will be interpreted
* Appends post-execute hook code to the script that will be interpreted
*/
public void onPostExecute(String script);
}

View file

@ -147,4 +147,34 @@ public class LazyOpenInterpreter
public void setClassloaderUrls(URL [] urls) {
intp.setClassloaderUrls(urls);
}
@Override
public void registerHook(String noteId, String event, String cmd) {
intp.registerHook(noteId, event, cmd);
}
@Override
public void registerHook(String event, String cmd) {
intp.registerHook(event, cmd);
}
@Override
public String getHook(String noteId, String event) {
return intp.getHook(noteId, event);
}
@Override
public String getHook(String event) {
return intp.getHook(event);
}
@Override
public void unregisterHook(String noteId, String event) {
intp.unregisterHook(noteId, event);
}
@Override
public void unregisterHook(String event) {
intp.unregisterHook(event);
}
}