[ZEPPELIN-4611]. Fetching rows with newline character (\n) breaks entire table

This commit is contained in:
Jeff Zhang 2020-02-13 15:33:34 +08:00
parent 2e6422c5df
commit d12eea9430
9 changed files with 138 additions and 21 deletions

View file

@ -34,6 +34,7 @@ import org.apache.hadoop.security.alias.CredentialProvider;
import org.apache.hadoop.security.alias.CredentialProviderFactory;
import org.apache.zeppelin.interpreter.BaseZeppelinContext;
import org.apache.zeppelin.interpreter.util.SqlSplitter;
import org.apache.zeppelin.tabledata.TableDataUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -566,9 +567,11 @@ public class JDBCInterpreter extends KerberosInterpreter {
msg.append(TAB);
}
if (StringUtils.isNotEmpty(md.getColumnLabel(i))) {
msg.append(removeTablePrefix(replaceReservedChars(md.getColumnLabel(i))));
msg.append(removeTablePrefix(replaceReservedChars(
TableDataUtils.normalizeColumn(md.getColumnLabel(i)))));
} else {
msg.append(removeTablePrefix(replaceReservedChars(md.getColumnName(i))));
msg.append(removeTablePrefix(replaceReservedChars(
TableDataUtils.normalizeColumn(md.getColumnName(i)))));
}
}
msg.append(NEWLINE);
@ -588,7 +591,7 @@ public class JDBCInterpreter extends KerberosInterpreter {
} else {
resultValue = resultSet.getString(i);
}
msg.append(replaceReservedChars(resultValue));
msg.append(replaceReservedChars(TableDataUtils.normalizeColumn(resultValue)));
if (i != md.getColumnCount()) {
msg.append(TAB);
}

View file

@ -185,7 +185,10 @@ class PyZeppelinContext(object):
self.show_dataframe(p, **kwargs)
else:
print(str(p))
def normalizeColumn(self, column):
return column.replace("\t", " ").replace("\r\n", " ").replace("\n", " ")
def show_dataframe(self, df, show_index=False, **kwargs):
"""Pretty prints DF using Table Display System
"""
@ -193,11 +196,11 @@ class PyZeppelinContext(object):
header_buf = StringIO("")
if show_index:
idx_name = str(df.index.name) if df.index.name is not None else ""
header_buf.write(idx_name + "\t")
header_buf.write(str(df.columns[0]))
header_buf.write(self.normalizeColumn(idx_name) + "\t")
header_buf.write(self.normalizeColumn(str(df.columns[0])))
for col in df.columns[1:]:
header_buf.write("\t")
header_buf.write(str(col))
header_buf.write(self.normalizeColumn(str(col)))
header_buf.write("\n")
body_buf = StringIO("")
@ -208,10 +211,10 @@ class PyZeppelinContext(object):
if show_index:
body_buf.write("%html <strong>{}</strong>".format(idx))
body_buf.write("\t")
body_buf.write(str(row[0]))
body_buf.write(self.normalizeColumn(str(row[0])))
for cell in row[1:]:
body_buf.write("\t")
body_buf.write(str(cell))
body_buf.write(self.normalizeColumn(str(cell)))
# don't print '\n' after the last row
if idx != (rowNumber - 1):
body_buf.write("\n")

View file

@ -301,13 +301,14 @@ public abstract class BasePythonInterpreterTest extends ConcurrentTestCase {
// Pandas DataFrame
context = getInterpreterContext();
result = interpreter.interpret("import pandas as pd\n" +
"df = pd.DataFrame({'id':[1,2,3], 'name':['a','b','c']})\nz.show(df)", context);
"df = pd.DataFrame({'id':[1,2,3], 'name':['a\ta','b\\nb','c\\r\\nc']})\nz.show(df)",
context);
assertEquals(context.out.toInterpreterResultMessage().toString(),
InterpreterResult.Code.SUCCESS, result.code());
interpreterResultMessages = context.out.toInterpreterResultMessage();
assertEquals(1, interpreterResultMessages.size());
assertEquals(InterpreterResult.Type.TABLE, interpreterResultMessages.get(0).getType());
assertEquals("id\tname\n1\ta\n2\tb\n3\tc\n", interpreterResultMessages.get(0).getData());
assertEquals("id\tname\n1\ta a\n2\tb b\n3\tc c\n", interpreterResultMessages.get(0).getData());
context = getInterpreterContext();
result = interpreter.interpret("import pandas as pd\n" +

View file

@ -89,14 +89,17 @@ public class SparkSqlInterpreterTest {
@Test
public void test() throws InterpreterException {
sparkInterpreter.interpret("case class Test(name:String, age:Int)", context);
sparkInterpreter.interpret("val test = sc.parallelize(Seq(Test(\"moon\", 33), Test(\"jobs\", 51), Test(\"gates\", 51), Test(\"park\", 34)))", context);
sparkInterpreter.interpret("test.toDF.registerTempTable(\"test\")", context);
InterpreterResult result = sparkInterpreter.interpret("case class Test(name:String, age:Int)", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
result = sparkInterpreter.interpret("val test = sc.parallelize(Seq(Test(\"moon\\t1\", 33), Test(\"jobs\", 51), Test(\"gates\", 51), Test(\"park\\n1\", 34)))", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
result = sparkInterpreter.interpret("test.toDF.registerTempTable(\"test\")", context);
assertEquals(InterpreterResult.Code.SUCCESS, result.code());
InterpreterResult ret = sqlInterpreter.interpret("select name, age from test where age < 40", context);
assertEquals(InterpreterResult.Code.SUCCESS, ret.code());
assertEquals(Type.TABLE, ret.message().get(0).getType());
assertEquals("name\tage\nmoon\t33\npark\t34\n", ret.message().get(0).getData());
assertEquals("name\tage\nmoon 1\t33\npark 1\t34\n", ret.message().get(0).getData());
ret = sqlInterpreter.interpret("select wrong syntax", context);
assertEquals(InterpreterResult.Code.ERROR, ret.code());

View file

@ -29,6 +29,7 @@ import org.apache.spark.sql.types.StructType;
import org.apache.spark.ui.jobs.JobProgressListener;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.ResultMessages;
import org.apache.zeppelin.tabledata.TableDataUtils;
import java.util.ArrayList;
import java.util.List;
@ -71,7 +72,7 @@ public class Spark1Shims extends SparkShims {
List<Row> rows = df.takeAsList(maxResult + 1);
StringBuilder msg = new StringBuilder();
msg.append("\n%table ");
msg.append(StringUtils.join(columns, "\t"));
msg.append(StringUtils.join(TableDataUtils.normalizeColumns(columns), "\t"));
msg.append("\n");
boolean isLargerThanMaxResult = rows.size() > maxResult;
if (isLargerThanMaxResult) {
@ -79,7 +80,7 @@ public class Spark1Shims extends SparkShims {
}
for (Row row : rows) {
for (int i = 0; i < row.size(); ++i) {
msg.append(row.get(i));
msg.append(TableDataUtils.normalizeColumn(row.get(i)));
if (i != row.size() - 1) {
msg.append("\t");
}

View file

@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow;
import org.apache.spark.sql.types.StructType;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.ResultMessages;
import org.apache.zeppelin.tabledata.TableDataUtils;
import java.util.ArrayList;
import java.util.List;
@ -72,7 +73,7 @@ public class Spark2Shims extends SparkShims {
List<Row> rows = df.takeAsList(maxResult + 1);
StringBuilder msg = new StringBuilder();
msg.append("\n%table ");
msg.append(StringUtils.join(columns, "\t"));
msg.append(StringUtils.join(TableDataUtils.normalizeColumns(columns), "\t"));
msg.append("\n");
boolean isLargerThanMaxResult = rows.size() > maxResult;
if (isLargerThanMaxResult) {
@ -80,7 +81,7 @@ public class Spark2Shims extends SparkShims {
}
for (Row row : rows) {
for (int i = 0; i < row.size(); ++i) {
msg.append(row.get(i));
msg.append(TableDataUtils.normalizeColumn(row.get(i)));
if (i != row.size() -1) {
msg.append("\t");
}

View file

@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow;
import org.apache.spark.sql.types.StructType;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.ResultMessages;
import org.apache.zeppelin.tabledata.TableDataUtils;
import java.util.ArrayList;
import java.util.List;
@ -72,7 +73,7 @@ public class Spark3Shims extends SparkShims {
List<Row> rows = df.takeAsList(maxResult + 1);
StringBuilder msg = new StringBuilder();
msg.append("%table ");
msg.append(StringUtils.join(columns, "\t"));
msg.append(StringUtils.join(TableDataUtils.normalizeColumns(columns), "\t"));
msg.append("\n");
boolean isLargerThanMaxResult = rows.size() > maxResult;
if (isLargerThanMaxResult) {
@ -80,7 +81,7 @@ public class Spark3Shims extends SparkShims {
}
for (Row row : rows) {
for (int i = 0; i < row.size(); ++i) {
msg.append(row.get(i));
msg.append(TableDataUtils.normalizeColumn(row.get(i)));
if (i != row.size() -1) {
msg.append("\t");
}

View file

@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.zeppelin.tabledata;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
public class TableDataUtils {
/**
* Replace '\t','\r\n','\n' which represent field delimiter and row delimiter with while space.
* @param column
* @column
*/
public static String normalizeColumn(String column) {
if (column == null) {
return "null";
}
return column.replace("\t", " ").replace("\r\n", " ").replace("\n", " ");
}
/**
* Convert obj to String first, convert it to empty string it is null.
* @param obj
* @column
*/
public static String normalizeColumn(Object obj) {
return normalizeColumn(obj == null ? "null" : obj.toString());
}
public static List<String> normalizeColumns(List<Object> columns) {
return columns.stream()
.map(TableDataUtils::normalizeColumn)
.collect(Collectors.toList());
}
public static List<String> normalizeColumns(Object[] columns) {
return Arrays.stream(columns)
.map(TableDataUtils::normalizeColumn)
.collect(Collectors.toList());
}
}

View file

@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.zeppelin.tabledata;
import com.google.common.collect.Lists;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
public class TableDataUtilsTest {
@Test
public void testColumn() {
assertEquals("hello world", TableDataUtils.normalizeColumn("hello\tworld"));
assertEquals("hello world", TableDataUtils.normalizeColumn("hello\nworld"));
assertEquals("hello world", TableDataUtils.normalizeColumn("hello\r\nworld"));
assertEquals("hello world", TableDataUtils.normalizeColumn("hello\t\nworld"));
assertEquals("null", TableDataUtils.normalizeColumn(null));
}
@Test
public void testColumns() {
assertEquals(Lists.newArrayList("hello world", "hello world"),
TableDataUtils.normalizeColumns(new Object[]{"hello\tworld", "hello\nworld"}));
assertEquals(Lists.newArrayList("hello world", "null"),
TableDataUtils.normalizeColumns(new String[]{"hello\tworld", null}));
}
}