[ZEPPELIN-2297] refactoring search completions

This commit is contained in:
Tinkoff DWH 2017-03-24 14:35:08 +05:00
parent 7b5835d3e1
commit 2b58cc5dcf
4 changed files with 152 additions and 67 deletions

View file

@ -4,15 +4,6 @@ package org.apache.zeppelin.jdbc;
* This source file is based on code taken from SQLLine 1.0.2 See SQLLine notice in LICENSE
*/
import jline.console.completer.ArgumentCompleter.ArgumentList;
import jline.console.completer.ArgumentCompleter.WhitespaceArgumentDelimiter;
import org.apache.zeppelin.completer.CompletionType;
import org.apache.zeppelin.completer.StringsCompleter;
import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
@ -20,9 +11,27 @@ import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.StringTokenizer;
import java.util.TreeSet;
import java.util.regex.Pattern;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.math.NumberUtils;
import org.apache.zeppelin.completer.CompletionType;
import org.apache.zeppelin.completer.StringsCompleter;
import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import jline.console.completer.ArgumentCompleter.ArgumentList;
import jline.console.completer.ArgumentCompleter.WhitespaceArgumentDelimiter;
import static org.apache.commons.lang.StringUtils.isBlank;
/**
@ -75,19 +84,35 @@ public class SqlCompleter {
// white spaces.
ArgumentList argumentList = sqlDelimiter.delimit(buffer, cursor);
String beforeCursorBuffer = buffer.substring(0,
Math.min(cursor, buffer.length())).toUpperCase();
Pattern whitespaceEndPatter = Pattern.compile("\\s$");
String cursorArgument = null;
int argumentPosition = 0;
if (buffer.length() == 0 || whitespaceEndPatter.matcher(buffer).find()) {
argumentPosition = buffer.length() - 1;
} else {
cursorArgument = argumentList.getCursorArgument();
argumentPosition = argumentList.getArgumentPosition();
}
// check what sql is and where cursor is to allow column completion or not
boolean isColumnAllowed = true;
if (beforeCursorBuffer.contains("SELECT ") && beforeCursorBuffer.contains(" FROM ")
&& !beforeCursorBuffer.contains(" WHERE "))
isColumnAllowed = false;
if (buffer.length() > 0) {
String beforeCursorBuffer = buffer.substring(0,
Math.min(cursor, buffer.length())).toUpperCase();
// check what sql is and where cursor is to allow column completion or not
if (beforeCursorBuffer.contains("SELECT ") && beforeCursorBuffer.contains(" FROM ")
&& !beforeCursorBuffer.contains(" WHERE "))
isColumnAllowed = false;
}
int complete = completeName(argumentList.getCursorArgument(),
argumentList.getArgumentPosition(), candidates,
int complete = completeName(cursorArgument, argumentPosition, candidates,
findAliasesInSQL(argumentList.getArguments()), isColumnAllowed);
if (candidates.size() == 1) {
InterpreterCompletion interpreterCompletion = candidates.get(0);
interpreterCompletion.setName(interpreterCompletion.getName() + " ");
interpreterCompletion.setValue(interpreterCompletion.getValue() + " ");
candidates.set(0, interpreterCompletion);
}
logger.debug("complete:" + complete + ", size:" + candidates.size());
return complete;
@ -406,8 +431,18 @@ public class SqlCompleter {
*/
private int completeTable(String schema, String buffer, int cursor,
List<CharSequence> candidates) {
if (schema == null) {
int res = -1;
Set<CharSequence> candidatesSet = new HashSet<>();
for (StringsCompleter stringsCompleter : tablesCompleters.values()) {
int resTable = stringsCompleter.complete(buffer, cursor, candidatesSet);
res = Math.max(res, resTable);
}
candidates.addAll(candidatesSet);
return res;
}
// Wrong schema
if (!tablesCompleters.containsKey(schema))
if (!tablesCompleters.containsKey(schema) && schema != null)
return -1;
else
return tablesCompleters.get(schema).complete(buffer, cursor, candidates);
@ -420,12 +455,23 @@ public class SqlCompleter {
*/
private int completeColumn(String schema, String table, String buffer, int cursor,
List<CharSequence> candidates) {
// Completes a column name, either across every known table (no schema and no table given)
// or within the single table identified by schema + "." + table.
// Returns the value reported by the underlying StringsCompleter(s), or -1 when the
// schema/table pair is not registered.
if (table == null && schema == null) {
// Start from -1 and keep the largest value returned by any of the column completers.
int res = -1;
// All completers write into one shared Set, so a column name offered by several
// completers ends up in candidates only once.
Set<CharSequence> candidatesSet = new HashSet<>();
for (StringsCompleter stringsCompleter : columnsCompleters.values()) {
int resColumn = stringsCompleter.complete(buffer, cursor, candidatesSet);
res = Math.max(res, resColumn);
}
candidates.addAll(candidatesSet);
return res;
}
// Wrong schema or wrong table
// columnsCompleters is looked up with the key schema + "." + table.
// NOTE(review): this is a diff rendering without +/- markers; the duplicated condition
// line and the "else" / "} else {" pair below appear to be the before/after forms of the
// same statement (they differ only in the added braces) - confirm against the real file.
if (!tablesCompleters.containsKey(schema) ||
!columnsCompleters.containsKey(schema + "." + table))
!columnsCompleters.containsKey(schema + "." + table)) {
return -1;
else
} else {
return columnsCompleters.get(schema + "." + table).complete(buffer, cursor, candidates);
}
}
/**
@ -439,30 +485,39 @@ public class SqlCompleter {
public int completeName(String buffer, int cursor, List<InterpreterCompletion> candidates,
Map<String, String> aliases, boolean isColumnAllowed) {
if (buffer == null) buffer = "";
// no need to process after first point after cursor
int nextPointPos = buffer.indexOf('.', cursor);
if (nextPointPos != -1) {
buffer = buffer.substring(0, nextPointPos);
}
// points divide the name to the schema, table and column - find them
int pointPos1 = buffer.indexOf('.');
int pointPos2 = buffer.indexOf('.', pointPos1 + 1);
int pointPos1 = -1;
int pointPos2 = -1;
if (StringUtils.isNotEmpty(buffer)) {
if (buffer.length() > cursor) {
buffer = buffer.substring(0, cursor + 1);
}
pointPos1 = buffer.indexOf('.');
pointPos2 = buffer.indexOf('.', pointPos1 + 1);
}
// find schema and table name if they are
String schema;
String table;
String column;
if (pointPos1 == -1) { // process only schema or keyword case
schema = buffer;
if (pointPos1 == -1) { // process all
List<CharSequence> keywordsCandidates = new ArrayList();
int keywordsRes = completeKeyword(buffer, cursor, keywordsCandidates);
List<CharSequence> schemaCandidates = new ArrayList<>();
int schemaRes = completeSchema(schema, cursor, schemaCandidates);
List<CharSequence> tableCandidates = new ArrayList<>();
List<CharSequence> columnCandidates = new ArrayList<>();
int keywordsRes = completeKeyword(buffer, cursor, keywordsCandidates);
int schemaRes = completeSchema(buffer, cursor, schemaCandidates);
int tableRes = completeTable(null, buffer, cursor, tableCandidates);
int columnRes = -1;
if (isColumnAllowed) {
columnRes = completeColumn(null, null, buffer, cursor, columnCandidates);
}
addCompletions(candidates, keywordsCandidates, CompletionType.keyword.name());
addCompletions(candidates, schemaCandidates, CompletionType.schema.name());
return Math.max(keywordsRes, schemaRes);
addCompletions(candidates, tableCandidates, CompletionType.table.name());
addCompletions(candidates, columnCandidates, CompletionType.column.name());
return NumberUtils.max(new int[]{keywordsRes, schemaRes, tableRes, columnRes});
} else {
schema = buffer.substring(0, pointPos1);
if (aliases.containsKey(schema)) { // process alias case
@ -472,8 +527,8 @@ public class SqlCompleter {
table = alias.substring(pointPos + 1);
column = buffer.substring(pointPos1 + 1);
} else if (pointPos2 == -1) { // process schema.table case
table = buffer.substring(pointPos1 + 1);
List<CharSequence> tableCandidates = new ArrayList();
table = buffer.substring(pointPos1 + 1);
int tableRes = completeTable(schema, table, cursor - pointPos1 - 1, tableCandidates);
addCompletions(candidates, tableCandidates, CompletionType.table.name());
return tableRes;
@ -484,15 +539,15 @@ public class SqlCompleter {
}
// here in case of column
if (isColumnAllowed) {
if (table != null && isColumnAllowed) {
List<CharSequence> columnCandidates = new ArrayList();
int columnRes = completeColumn(schema, table, column, cursor - pointPos2 - 1,
columnCandidates);
addCompletions(candidates, columnCandidates, CompletionType.column.name());
return columnRes;
} else {
return -1;
}
return -1;
}
// test purpose only

View file

@ -294,7 +294,7 @@ public class JDBCInterpreterTest extends BasicJDBCTestCaseAdapter {
jdbcInterpreter.interpret("", interpreterContext);
List<InterpreterCompletion> completionList = jdbcInterpreter.completion("sel", 1);
List<InterpreterCompletion> completionList = jdbcInterpreter.completion("sel", 3);
InterpreterCompletion correctCompletionKeyword = new InterpreterCompletion("select ", "select ", CompletionType.keyword.name());

View file

@ -23,6 +23,7 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang.StringUtils;
import org.apache.zeppelin.completer.CompletionType;
import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
import org.junit.Assert;
@ -82,7 +83,11 @@ public class SqlCompleterTest {
}
}
private void expectedCompletions(String buffer, int cursor, Set<InterpreterCompletion> expected) {
private void expectedCompletions(String buffer, int cursor,
Set<InterpreterCompletion> expected) {
if (StringUtils.isNotEmpty(buffer) && buffer.length() > cursor) {
buffer = buffer.substring(0, cursor + 1);
}
List<InterpreterCompletion> candidates = new ArrayList<>();
@ -93,7 +98,7 @@ public class SqlCompleterTest {
logger.info(explain);
Assert.assertEquals("Buffer [" + buffer.replace(" ", ".") + "] and Cursor[" + cursor + "] "
+ explain, expected, newHashSet(candidates));
+ explain, expected, newHashSet(candidates));
}
private String explain(String buffer, int cursor, List<InterpreterCompletion> candidates) {
@ -133,7 +138,7 @@ public class SqlCompleterTest {
private CompleterTester tester;
private ArgumentCompleter.WhitespaceArgumentDelimiter delimiter =
new ArgumentCompleter.WhitespaceArgumentDelimiter();
new ArgumentCompleter.WhitespaceArgumentDelimiter();
private SqlCompleter sqlCompleter = new SqlCompleter();
@ -189,7 +194,7 @@ public class SqlCompleterTest {
}
@Test
public void testFindAliasesInSQL_Simple(){
public void testFindAliasesInSQL_Simple() {
String sql = "select * from prod_emart.financial_account a";
Map<String, String> res = sqlCompleter.findAliasesInSQL(delimiter.delimit(sql, 0).getArguments());
assertEquals(1, res.size());
@ -197,7 +202,7 @@ public class SqlCompleterTest {
}
@Test
public void testFindAliasesInSQL_Two(){
public void testFindAliasesInSQL_Two() {
String sql = "select * from prod_dds.financial_account a, prod_dds.customer b";
Map<String, String> res = sqlCompleter.findAliasesInSQL(sqlCompleter.getSqlDelimiter().delimit(sql, 0).getArguments());
assertEquals(2, res.size());
@ -206,7 +211,7 @@ public class SqlCompleterTest {
}
@Test
public void testFindAliasesInSQL_WrongTables(){
public void testFindAliasesInSQL_WrongTables() {
String sql = "select * from prod_ddsxx.financial_account a, prod_dds.customerxx b";
Map<String, String> res = sqlCompleter.findAliasesInSQL(sqlCompleter.getSqlDelimiter().delimit(sql, 0).getArguments());
assertEquals(0, res.size());
@ -218,8 +223,8 @@ public class SqlCompleterTest {
int cursor = 0;
List<InterpreterCompletion> candidates = new ArrayList<>();
Map<String, String> aliases = new HashMap<>();
sqlCompleter.completeName(buffer, cursor, candidates, aliases, false);
assertEquals(9, candidates.size());
sqlCompleter.completeName(buffer, cursor, candidates, aliases, true);
assertEquals(17, candidates.size());
assertTrue(candidates.contains(new InterpreterCompletion("prod_dds", "prod_dds", CompletionType.schema.name())));
assertTrue(candidates.contains(new InterpreterCompletion("prod_emart", "prod_emart", CompletionType.schema.name())));
assertTrue(candidates.contains(new InterpreterCompletion("SUM", "SUM", CompletionType.keyword.name())));
@ -229,6 +234,14 @@ public class SqlCompleterTest {
assertTrue(candidates.contains(new InterpreterCompletion("ORDER", "ORDER", CompletionType.keyword.name())));
assertTrue(candidates.contains(new InterpreterCompletion("LIMIT", "LIMIT", CompletionType.keyword.name())));
assertTrue(candidates.contains(new InterpreterCompletion("FROM", "FROM", CompletionType.keyword.name())));
assertTrue(candidates.contains(new InterpreterCompletion("financial_account", "financial_account", CompletionType.table.name())));
assertTrue(candidates.contains(new InterpreterCompletion("customer", "customer", CompletionType.table.name())));
assertTrue(candidates.contains(new InterpreterCompletion("account_id", "account_id", CompletionType.column.name())));
assertTrue(candidates.contains(new InterpreterCompletion("customer_rk", "customer_rk", CompletionType.column.name())));
assertTrue(candidates.contains(new InterpreterCompletion("account_rk", "account_rk", CompletionType.column.name())));
assertTrue(candidates.contains(new InterpreterCompletion("name", "name", CompletionType.column.name())));
assertTrue(candidates.contains(new InterpreterCompletion("birth_dt", "birth_dt", CompletionType.column.name())));
assertTrue(candidates.contains(new InterpreterCompletion("balance_amt", "balance_amt", CompletionType.column.name())));
}
@Test
@ -251,7 +264,7 @@ public class SqlCompleterTest {
Map<String, String> aliases = new HashMap<>();
sqlCompleter.completeName(buffer, cursor, candidates, aliases, false);
assertEquals(1, candidates.size());
assertTrue(candidates.contains(new InterpreterCompletion("financial_account ", "financial_account ", CompletionType.table.name())));
assertTrue(candidates.contains(new InterpreterCompletion("financial_account", "financial_account", CompletionType.table.name())));
}
@Test
@ -295,15 +308,15 @@ public class SqlCompleterTest {
/**
 * Completion on a "schema.table" argument: ranges that fall on the schema name expect the
 * single schema completion "prod_emart ", and the range after the dot expects the single
 * table completion "financial_account ". Both expected values end with a space.
 *
 * NOTE(review): CompleterTester is not shown here; from(..).to(..) is presumably the range
 * of cursor positions that is exercised - confirm. The two prod_emart lines differ only in
 * that range (15-24 vs 19-23) and look like the before/after forms of one line in this diff.
 */
@Test
public void testSchemaAndTable() {
String buffer = "select * from prod_emart.fi";
tester.buffer(buffer).from(15).to(24).expect(newHashSet(new InterpreterCompletion("prod_emart ", "prod_emart ", CompletionType.schema.name()))).test();
tester.buffer(buffer).from(19).to(23).expect(newHashSet(new InterpreterCompletion("prod_emart ", "prod_emart ", CompletionType.schema.name()))).test();
tester.buffer(buffer).from(25).to(27).expect(newHashSet(new InterpreterCompletion("financial_account ", "financial_account ", CompletionType.table.name()))).test();
}
@Test
public void testEdges() {
String buffer = " ORDER ";
tester.buffer(buffer).from(0).to(7).expect(newHashSet(new InterpreterCompletion("ORDER ", "ORDER ", CompletionType.keyword.name()))).test();
tester.buffer(buffer).from(8).to(15).expect(newHashSet(
tester.buffer(buffer).from(2).to(6).expect(newHashSet(new InterpreterCompletion("ORDER ", "ORDER ", CompletionType.keyword.name()))).test();
tester.buffer(buffer).from(0).to(1).expect(newHashSet(
new InterpreterCompletion("ORDER", "ORDER", CompletionType.keyword.name()),
new InterpreterCompletion("SUBCLASS_ORIGIN", "SUBCLASS_ORIGIN", CompletionType.keyword.name()),
new InterpreterCompletion("SUBSTRING", "SUBSTRING", CompletionType.keyword.name()),
@ -312,29 +325,37 @@ public class SqlCompleterTest {
new InterpreterCompletion("SUM", "SUM", CompletionType.keyword.name()),
new InterpreterCompletion("prod_dds", "prod_dds", CompletionType.schema.name()),
new InterpreterCompletion("SELECT", "SELECT", CompletionType.keyword.name()),
new InterpreterCompletion("FROM", "FROM", CompletionType.keyword.name())
new InterpreterCompletion("FROM", "FROM", CompletionType.keyword.name()),
new InterpreterCompletion("financial_account", "financial_account", CompletionType.table.name()),
new InterpreterCompletion("customer", "customer", CompletionType.table.name()),
new InterpreterCompletion("account_rk", "account_rk", CompletionType.column.name()),
new InterpreterCompletion("account_id", "account_id", CompletionType.column.name()),
new InterpreterCompletion("customer_rk", "customer_rk", CompletionType.column.name()),
new InterpreterCompletion("name", "name", CompletionType.column.name()),
new InterpreterCompletion("birth_dt", "birth_dt", CompletionType.column.name()),
new InterpreterCompletion("balance_amt", "balance_amt", CompletionType.column.name())
)).test();
}
/**
 * Three partial keywords separated by single spaces; each range expects exactly one
 * keyword completion (SELECT, FROM, LIMIT), written with a trailing space.
 *
 * NOTE(review): the SELECT and FROM expectations each appear twice with different ranges
 * (0-4 vs 1-3, 5-8 vs 6-7); these look like the before/after forms of the same line in
 * this diff rendering - confirm which pair is current.
 */
@Test
public void testMultipleWords() {
String buffer = "SELE FRO LIM";
tester.buffer(buffer).from(0).to(4).expect(newHashSet(new InterpreterCompletion("SELECT ", "SELECT ", CompletionType.keyword.name()))).test();
tester.buffer(buffer).from(5).to(8).expect(newHashSet(new InterpreterCompletion("FROM ", "FROM ", CompletionType.keyword.name()))).test();
tester.buffer(buffer).from(1).to(3).expect(newHashSet(new InterpreterCompletion("SELECT ", "SELECT ", CompletionType.keyword.name()))).test();
tester.buffer(buffer).from(6).to(7).expect(newHashSet(new InterpreterCompletion("FROM ", "FROM ", CompletionType.keyword.name()))).test();
tester.buffer(buffer).from(9).to(12).expect(newHashSet(new InterpreterCompletion("LIMIT ", "LIMIT ", CompletionType.keyword.name()))).test();
}
/**
 * Same keyword expectations as the single-line case, but the buffer contains leading
 * whitespace and newline characters between the partial keywords.
 *
 * NOTE(review): each expectation appears twice with different ranges (0-7 vs 4-6,
 * 8-11 vs 9-11); these look like the before/after forms of the same line in this diff
 * rendering - confirm which pair is current.
 */
@Test
public void testMultiLineBuffer() {
String buffer = " \n SELE\nFRO";
tester.buffer(buffer).from(0).to(7).expect(newHashSet(new InterpreterCompletion("SELECT ", "SELECT ", CompletionType.keyword.name()))).test();
tester.buffer(buffer).from(8).to(11).expect(newHashSet(new InterpreterCompletion("FROM ", "FROM ", CompletionType.keyword.name()))).test();
tester.buffer(buffer).from(4).to(6).expect(newHashSet(new InterpreterCompletion("SELECT ", "SELECT ", CompletionType.keyword.name()))).test();
tester.buffer(buffer).from(9).to(11).expect(newHashSet(new InterpreterCompletion("FROM ", "FROM ", CompletionType.keyword.name()))).test();
}
@Test
public void testMultipleCompletionSuggestions() {
String buffer = "SU";
tester.buffer(buffer).from(0).to(2).expect(newHashSet(
tester.buffer(buffer).from(1).to(2).expect(newHashSet(
new InterpreterCompletion("SUBCLASS_ORIGIN", "SUBCLASS_ORIGIN", CompletionType.keyword.name()),
new InterpreterCompletion("SUM", "SUM", CompletionType.keyword.name()),
new InterpreterCompletion("SUBSTRING", "SUBSTRING", CompletionType.keyword.name()))

View file

@ -15,7 +15,9 @@
package org.apache.zeppelin.completer;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
@ -26,8 +28,12 @@ import jline.internal.Preconditions;
* Case-insensitive completer for a set of strings.
*/
public class StringsCompleter implements Completer {
// Sorted set of the strings this completer can offer.
// NOTE(review): two declarations of the same field are shown because this is a diff
// rendering without +/- markers; the first (natural ordering) appears to be the removed
// form and the second (explicit comparator) the added one - confirm against the real file.
private final SortedSet<String> strings = new TreeSet<String>();
private final SortedSet<String> strings = new TreeSet<String>(new Comparator<String>() {
@Override
public int compare(String o1, String o2) {
// Order entries without regard to case, so the tailSet(buffer) lookup used during
// completion is positioned the same way whatever the case of the typed text.
return o1.compareToIgnoreCase(o2);
}
});
public StringsCompleter() {
}
@ -42,12 +48,19 @@ public class StringsCompleter implements Completer {
}
// List-based entry point: rejects a null candidates list and hands the work to
// completeCollection, returning whatever it returns.
// NOTE(review): the same Preconditions.checkNotNull(candidates) call is also shown at the
// top of completeCollection; in this diff rendering one of the two is probably the
// removed line - confirm.
public int complete(final String buffer, final int cursor, final List<CharSequence> candidates) {
Preconditions.checkNotNull(candidates);
return completeCollection(buffer, cursor, candidates);
}
// Set-based overload: same delegation to completeCollection, for callers that gather
// candidates into a Set (SqlCompleter does this when merging results from several
// completers so that repeated names collapse).
public int complete(final String buffer, final int cursor, final Set<CharSequence> candidates) {
return completeCollection(buffer, cursor, candidates);
}
private int completeCollection(final String buffer, final int cursor,
final Collection<CharSequence> candidates) {
Preconditions.checkNotNull(candidates);
if (buffer == null) {
candidates.addAll(strings);
}
else {
} else {
String bufferTmp = buffer.toUpperCase();
for (String match : strings.tailSet(buffer)) {
String matchTmp = match.toUpperCase();
@ -59,10 +72,6 @@ public class StringsCompleter implements Completer {
}
}
if (candidates.size() == 1) {
candidates.set(0, candidates.get(0) + " ");
}
return candidates.isEmpty() ? -1 : 0;
}
}