fix unit test

This commit is contained in:
Jeff Zhang 2021-08-07 20:53:15 +08:00
parent 46c4f08059
commit df963616a3

View file

@ -25,6 +25,7 @@ import org.apache.zeppelin.interpreter.InterpreterGroup;
import org.apache.zeppelin.interpreter.InterpreterOutput;
import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.interpreter.InterpreterResult.Type;
import org.apache.zeppelin.interpreter.remote.RemoteInterpreterEventClient;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@ -41,6 +42,7 @@ import java.util.Properties;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
/**
* In order for this test to work, test env must have installed:
@ -77,7 +79,6 @@ public class PythonInterpreterPandasSqlTest {
private IPythonInterpreter ipythonInterpreter;
private InterpreterContext context;
private InterpreterOutput out;
public PythonInterpreterPandasSqlTest(boolean useIPython) {
this.useIPython = useIPython;
@ -94,10 +95,7 @@ public class PythonInterpreterPandasSqlTest {
intpGroup = new InterpreterGroup();
out = new InterpreterOutput();
context = InterpreterContext.builder()
.setInterpreterOut(out)
.build();
context = getInterpreterContext();
InterpreterContext.set(context);
pythonInterpreter = new PythonInterpreter(p);
@ -135,12 +133,17 @@ public class PythonInterpreterPandasSqlTest {
@Test
public void errorMessageIfDependenciesNotInstalled() throws InterpreterException {
InterpreterResult ret;
ret = pandasSqlInterpreter.interpret("SELECT * from something", context);
context = getInterpreterContext();
InterpreterResult ret = pandasSqlInterpreter.interpret("SELECT * from something", context);
assertNotNull(ret);
assertEquals(ret.message().get(0).getData(), InterpreterResult.Code.ERROR, ret.code());
assertTrue(ret.message().get(0).getData().contains("no such table: something"));
assertEquals(context.out.toString(), InterpreterResult.Code.ERROR, ret.code());
if (useIPython) {
assertTrue(context.out.toString(),
context.out.toString().contains("no such table: something"));
} else {
assertTrue(ret.toString(), ret.toString().contains("no such table: something"));
}
}
@Test
@ -154,15 +157,15 @@ public class PythonInterpreterPandasSqlTest {
assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code());
//when
context = getInterpreterContext();
ret = pandasSqlInterpreter.interpret("select name, age from df2 where age < 40", context);
//then
assertEquals(new String(out.getOutputAt(1).toByteArray()),
InterpreterResult.Code.SUCCESS, ret.code());
assertEquals(new String(out.getOutputAt(1).toByteArray()), Type.TABLE,
out.getOutputAt(1).getType());
assertTrue(new String(out.getOutputAt(1).toByteArray()).indexOf("moon\t33") > 0);
assertTrue(new String(out.getOutputAt(1).toByteArray()).indexOf("park\t34") > 0);
assertEquals(context.out.toString(), InterpreterResult.Code.SUCCESS, ret.code());
assertEquals(context.out.toString(), Type.TABLE,
context.out.toInterpreterResultMessage().get(0).getType());
assertTrue(context.out.toString().indexOf("moon\t33") > 0);
assertTrue(context.out.toString().indexOf("park\t34") > 0);
assertEquals(InterpreterResult.Code.SUCCESS,
pandasSqlInterpreter.interpret(
@ -184,12 +187,11 @@ public class PythonInterpreterPandasSqlTest {
ret = pandasSqlInterpreter.interpret("select name, age from df2 where age < 40", context);
//then
assertEquals(new String(out.getOutputAt(1).toByteArray()),
InterpreterResult.Code.SUCCESS, ret.code());
assertEquals(new String(out.getOutputAt(1).toByteArray()), Type.TABLE,
out.getOutputAt(1).getType());
assertTrue(new String(out.getOutputAt(1).toByteArray()).indexOf("moon\t33") > 0);
assertTrue(new String(out.getOutputAt(1).toByteArray()).indexOf("park\t34") > 0);
assertEquals(context.out.toString(), InterpreterResult.Code.SUCCESS, ret.code());
assertEquals(context.out.toString(), Type.TABLE,
context.out.toInterpreterResultMessage().get(1).getType());
assertTrue(context.out.toString().indexOf("moon\t33") > 0);
assertTrue(context.out.toString().indexOf("park\t34") > 0);
assertEquals(InterpreterResult.Code.SUCCESS,
pandasSqlInterpreter.interpret(
@ -200,11 +202,12 @@ public class PythonInterpreterPandasSqlTest {
@Test
public void badSqlSyntaxFails() throws InterpreterException {
//when
context = getInterpreterContext();
InterpreterResult ret = pandasSqlInterpreter.interpret("select wrong syntax", context);
//then
assertNotNull("Interpreter returned 'null'", ret);
assertEquals(ret.toString(), InterpreterResult.Code.ERROR, ret.code());
assertEquals(context.out.toString(), InterpreterResult.Code.ERROR, ret.code());
}
@Test
@ -222,15 +225,24 @@ public class PythonInterpreterPandasSqlTest {
assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code());
// when
context = getInterpreterContext();
ret = pythonInterpreter.interpret("z.show(df1, show_index=True)", context);
// then
assertEquals(new String(out.getOutputAt(0).toByteArray()),
InterpreterResult.Code.SUCCESS, ret.code());
assertEquals(new String(out.getOutputAt(1).toByteArray()),
Type.TABLE, out.getOutputAt(1).getType());
assertTrue(new String(out.getOutputAt(1).toByteArray()).contains("index_name"));
assertTrue(new String(out.getOutputAt(1).toByteArray()).contains("nan"));
assertTrue(new String(out.getOutputAt(1).toByteArray()).contains("6.7"));
assertEquals(context.out.toString(), InterpreterResult.Code.SUCCESS, ret.code());
assertEquals(context.out.toString(), Type.TABLE,
context.out.toInterpreterResultMessage().get(0).getType());
assertTrue(context.out.toString().contains("index_name"));
assertTrue(context.out.toString().contains("nan"));
assertTrue(context.out.toString().contains("6.7"));
}
private InterpreterContext getInterpreterContext() {
return InterpreterContext.builder()
.setNoteId("noteId")
.setParagraphId("paragraphId")
.setInterpreterOut(new InterpreterOutput())
.setIntpEventClient(mock(RemoteInterpreterEventClient.class))
.build();
}
}