Use reflection not to use import org.apache.spark.scheduler.Stage

This commit is contained in:
Lee moon soo 2015-08-21 18:36:44 -07:00
parent c3d96c18ff
commit 9e812e7ef9
2 changed files with 46 additions and 161 deletions

View file

@ -29,6 +29,7 @@ import java.net.URLClassLoader;
import java.util.*;
import com.google.common.base.Joiner;
import org.apache.spark.HttpServer;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
@ -40,7 +41,6 @@ import org.apache.spark.repl.SparkJLineCompletion;
import org.apache.spark.scheduler.ActiveJob;
import org.apache.spark.scheduler.DAGScheduler;
import org.apache.spark.scheduler.Pool;
import org.apache.spark.scheduler.Stage;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.ui.jobs.JobProgressListener;
import org.apache.zeppelin.interpreter.Interpreter;
@ -67,6 +67,7 @@ import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.JavaConversions;
import scala.collection.JavaConverters;
import scala.collection.Seq;
import scala.collection.mutable.HashMap;
import scala.collection.mutable.HashSet;
import scala.tools.nsc.Settings;
@ -671,18 +672,26 @@ public class SparkInterpreter extends Interpreter {
if (jobGroup.equals(g)) {
int[] progressInfo = null;
if (sc.version().startsWith("1.0")) {
progressInfo = getProgressFromStage_1_0x(sparkListener, job.finalStage());
} else if (sc.version().startsWith("1.1")) {
progressInfo = getProgressFromStage_1_1x(sparkListener, job.finalStage());
} else if (sc.version().startsWith("1.2")) {
progressInfo = getProgressFromStage_1_1x(sparkListener, job.finalStage());
} else if (sc.version().startsWith("1.3")) {
progressInfo = getProgressFromStage_1_1x(sparkListener, job.finalStage());
} else if (sc.version().startsWith("1.4")) {
progressInfo = getProgressFromStage_1_1x(sparkListener, job.finalStage());
} else {
continue;
try {
Object finalStage = job.getClass().getMethod("finalStage").invoke(job);
if (sc.version().startsWith("1.0")) {
progressInfo = getProgressFromStage_1_0x(sparkListener, finalStage);
} else if (sc.version().startsWith("1.1")) {
progressInfo = getProgressFromStage_1_1x(sparkListener, finalStage);
} else if (sc.version().startsWith("1.2")) {
progressInfo = getProgressFromStage_1_1x(sparkListener, finalStage);
} else if (sc.version().startsWith("1.3")) {
progressInfo = getProgressFromStage_1_1x(sparkListener, finalStage);
} else if (sc.version().startsWith("1.4")) {
progressInfo = getProgressFromStage_1_1x(sparkListener, finalStage);
} else {
continue;
}
} catch (IllegalAccessException | IllegalArgumentException
| InvocationTargetException | NoSuchMethodException
| SecurityException e) {
logger.error("Can't get progress info", e);
return 0;
}
totalTasks += progressInfo[0];
completedTasks += progressInfo[1];
@ -695,33 +704,27 @@ public class SparkInterpreter extends Interpreter {
return completedTasks * 100 / totalTasks;
}
private int[] getProgressFromStage_1_0x(JobProgressListener sparkListener, Stage stage) {
int numTasks = stage.numTasks();
private int[] getProgressFromStage_1_0x(JobProgressListener sparkListener, Object stage)
throws IllegalAccessException, IllegalArgumentException,
InvocationTargetException, NoSuchMethodException, SecurityException {
int numTasks = (int) stage.getClass().getMethod("numTasks").invoke(stage);
int completedTasks = 0;
Method method;
int id = (int) stage.getClass().getMethod("id").invoke(stage);
Object completedTaskInfo = null;
try {
method = sparkListener.getClass().getMethod("stageIdToTasksComplete");
completedTaskInfo =
JavaConversions.asJavaMap((HashMap<Object, Object>) method.invoke(sparkListener)).get(
stage.id());
} catch (NoSuchMethodException | SecurityException e) {
logger.error("Error while getting progress", e);
} catch (IllegalAccessException e) {
logger.error("Error while getting progress", e);
} catch (IllegalArgumentException e) {
logger.error("Error while getting progress", e);
} catch (InvocationTargetException e) {
logger.error("Error while getting progress", e);
}
completedTaskInfo = JavaConversions.asJavaMap(
(HashMap<Object, Object>) sparkListener.getClass()
.getMethod("stageIdToTasksComplete").invoke(sparkListener)).get(id);
if (completedTaskInfo != null) {
completedTasks += (int) completedTaskInfo;
}
List<Stage> parents = JavaConversions.asJavaList(stage.parents());
List<Object> parents = JavaConversions.asJavaList((Seq<Object>) stage.getClass()
.getMethod("parents").invoke(stage));
if (parents != null) {
for (Stage s : parents) {
for (Object s : parents) {
int[] p = getProgressFromStage_1_0x(sparkListener, s);
numTasks += p[0];
completedTasks += p[1];
@ -731,9 +734,12 @@ public class SparkInterpreter extends Interpreter {
return new int[] {numTasks, completedTasks};
}
private int[] getProgressFromStage_1_1x(JobProgressListener sparkListener, Stage stage) {
int numTasks = stage.numTasks();
private int[] getProgressFromStage_1_1x(JobProgressListener sparkListener, Object stage)
throws IllegalAccessException, IllegalArgumentException,
InvocationTargetException, NoSuchMethodException, SecurityException {
int numTasks = (int) stage.getClass().getMethod("numTasks").invoke(stage);
int completedTasks = 0;
int id = (int) stage.getClass().getMethod("id").invoke(stage);
try {
Method stageIdToData = sparkListener.getClass().getMethod("stageIdToData");
@ -747,7 +753,7 @@ public class SparkInterpreter extends Interpreter {
Set<Tuple2<Object, Object>> keys =
JavaConverters.asJavaSetConverter(stageIdData.keySet()).asJava();
for (Tuple2<Object, Object> k : keys) {
if (stage.id() == (int) k._1()) {
if (id == (int) k._1()) {
Object uiData = stageIdData.get(k).get();
completedTasks += (int) numCompletedTasks.invoke(uiData);
}
@ -756,9 +762,10 @@ public class SparkInterpreter extends Interpreter {
logger.error("Error on getting progress information", e);
}
List<Stage> parents = JavaConversions.asJavaList(stage.parents());
List<Object> parents = JavaConversions.asJavaList((Seq<Object>) stage.getClass()
.getMethod("parents").invoke(stage));
if (parents != null) {
for (Stage s : parents) {
for (Object s : parents) {
int[] p = getProgressFromStage_1_1x(sparkListener, s);
numTasks += p[0];
completedTasks += p[1];

View file

@ -17,25 +17,17 @@
package org.apache.zeppelin.spark;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.spark.SparkContext;
import org.apache.spark.scheduler.ActiveJob;
import org.apache.spark.scheduler.DAGScheduler;
import org.apache.spark.scheduler.Stage;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.ui.jobs.JobProgressListener;
import org.apache.zeppelin.interpreter.Interpreter;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.InterpreterException;
import org.apache.zeppelin.interpreter.InterpreterPropertyBuilder;
import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.interpreter.InterpreterUtils;
import org.apache.zeppelin.interpreter.InterpreterResult.Code;
import org.apache.zeppelin.interpreter.LazyOpenInterpreter;
import org.apache.zeppelin.interpreter.WrappedInterpreter;
@ -44,13 +36,6 @@ import org.apache.zeppelin.scheduler.SchedulerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.JavaConversions;
import scala.collection.JavaConverters;
import scala.collection.mutable.HashMap;
import scala.collection.mutable.HashSet;
/**
* Spark SQL interpreter for Zeppelin.
*
@ -151,117 +136,10 @@ public class SparkSqlInterpreter extends Interpreter {
@Override
public int getProgress(InterpreterContext context) {
String jobGroup = getJobGroup(context);
SQLContext sqlc = getSparkInterpreter().getSQLContext();
SparkContext sc = sqlc.sparkContext();
JobProgressListener sparkListener = getSparkInterpreter().getJobProgressListener();
int completedTasks = 0;
int totalTasks = 0;
DAGScheduler scheduler = sc.dagScheduler();
HashSet<ActiveJob> jobs = scheduler.activeJobs();
Iterator<ActiveJob> it = jobs.iterator();
while (it.hasNext()) {
ActiveJob job = it.next();
String g = (String) job.properties().get("spark.jobGroup.id");
if (jobGroup.equals(g)) {
int[] progressInfo = null;
if (sc.version().startsWith("1.0")) {
progressInfo = getProgressFromStage_1_0x(sparkListener, job.finalStage());
} else if (sc.version().startsWith("1.1")) {
progressInfo = getProgressFromStage_1_1x(sparkListener, job.finalStage());
} else if (sc.version().startsWith("1.2")) {
progressInfo = getProgressFromStage_1_1x(sparkListener, job.finalStage());
} else if (sc.version().startsWith("1.3")) {
progressInfo = getProgressFromStage_1_1x(sparkListener, job.finalStage());
} else if (sc.version().startsWith("1.4")) {
progressInfo = getProgressFromStage_1_1x(sparkListener, job.finalStage());
} else {
logger.warn("Spark {} getting progress information not supported" + sc.version());
continue;
}
totalTasks += progressInfo[0];
completedTasks += progressInfo[1];
}
}
if (totalTasks == 0) {
return 0;
}
return completedTasks * 100 / totalTasks;
SparkInterpreter sparkInterpreter = getSparkInterpreter();
return sparkInterpreter.getProgress(context);
}
private int[] getProgressFromStage_1_0x(JobProgressListener sparkListener, Stage stage) {
int numTasks = stage.numTasks();
int completedTasks = 0;
Method method;
Object completedTaskInfo = null;
try {
method = sparkListener.getClass().getMethod("stageIdToTasksComplete");
completedTaskInfo =
JavaConversions.asJavaMap((HashMap<Object, Object>) method.invoke(sparkListener)).get(
stage.id());
} catch (NoSuchMethodException | SecurityException e) {
logger.error("Error while getting progress", e);
} catch (IllegalAccessException e) {
logger.error("Error while getting progress", e);
} catch (IllegalArgumentException e) {
logger.error("Error while getting progress", e);
} catch (InvocationTargetException e) {
logger.error("Error while getting progress", e);
}
if (completedTaskInfo != null) {
completedTasks += (int) completedTaskInfo;
}
List<Stage> parents = JavaConversions.asJavaList(stage.parents());
if (parents != null) {
for (Stage s : parents) {
int[] p = getProgressFromStage_1_0x(sparkListener, s);
numTasks += p[0];
completedTasks += p[1];
}
}
return new int[] {numTasks, completedTasks};
}
private int[] getProgressFromStage_1_1x(JobProgressListener sparkListener, Stage stage) {
int numTasks = stage.numTasks();
int completedTasks = 0;
try {
Method stageIdToData = sparkListener.getClass().getMethod("stageIdToData");
HashMap<Tuple2<Object, Object>, Object> stageIdData =
(HashMap<Tuple2<Object, Object>, Object>) stageIdToData.invoke(sparkListener);
Class<?> stageUIDataClass =
this.getClass().forName("org.apache.spark.ui.jobs.UIData$StageUIData");
Method numCompletedTasks = stageUIDataClass.getMethod("numCompleteTasks");
Set<Tuple2<Object, Object>> keys =
JavaConverters.asJavaSetConverter(stageIdData.keySet()).asJava();
for (Tuple2<Object, Object> k : keys) {
if (stage.id() == (int) k._1()) {
Object uiData = stageIdData.get(k).get();
completedTasks += (int) numCompletedTasks.invoke(uiData);
}
}
} catch (Exception e) {
logger.error("Error on getting progress information", e);
}
List<Stage> parents = JavaConversions.asJavaList(stage.parents());
if (parents != null) {
for (Stage s : parents) {
int[] p = getProgressFromStage_1_1x(sparkListener, s);
numTasks += p[0];
completedTasks += p[1];
}
}
return new int[] {numTasks, completedTasks};
}
@Override
public Scheduler getScheduler() {