Create connection manager class

This commit is contained in:
conker84 2017-06-28 18:14:00 +02:00
parent 35b4e29de8
commit 8e4690e80c
2 changed files with 175 additions and 96 deletions

View file

@ -0,0 +1,122 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.zeppelin.graph.neo4j;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.lang.StringUtils;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.resource.Resource;
import org.apache.zeppelin.resource.ResourcePool;
import org.neo4j.driver.v1.AuthToken;
import org.neo4j.driver.v1.Config;
import org.neo4j.driver.v1.Driver;
import org.neo4j.driver.v1.GraphDatabase;
import org.neo4j.driver.v1.Session;
import org.neo4j.driver.v1.StatementResult;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Neo4j connection manager for Zeppelin.
*/
public class Neo4jConnectionManager {
static final Logger LOGGER = LoggerFactory.getLogger(Neo4jConnectionManager.class);
private static final Pattern PROPERTY_PATTERN = Pattern.compile("\\{\\w+\\}");
private static final String REPLACE_CURLY_BRACKETS = "\\{|\\}";
private static final Pattern $_PATTERN = Pattern.compile("\\$\\w+\\}");
private static final String REPLACE_$ = "\\$";
private Driver driver = null;
private final String neo4jUrl;
private final Config config;
private final AuthToken authToken;
public Neo4jConnectionManager(String neo4jUrl, AuthToken authToken,
Config config) {
this.neo4jUrl = neo4jUrl;
this.authToken = authToken;
this.config = config;
}
private Driver getDriver() {
if (driver == null) {
driver = GraphDatabase.driver(this.neo4jUrl, this.authToken, this.config);
}
return driver;
}
public void open() {
getDriver();
}
public void close() {
getDriver().close();
}
private Session getSession() {
return getDriver().session();
}
public StatementResult execute(String cypherQuery,
InterpreterContext interpreterContext) {
Map<String, Object> params = new HashMap<>();
if (interpreterContext != null) {
ResourcePool resourcePool = interpreterContext.getResourcePool();
Set<String> keys = extractParams(cypherQuery, PROPERTY_PATTERN, REPLACE_CURLY_BRACKETS);
keys.addAll(extractParams(cypherQuery, $_PATTERN, REPLACE_$));
for (String key : keys) {
Resource resource = resourcePool.get(key);
if (resource != null) {
params.put(key, resource.get());
}
}
}
LOGGER.debug("Executing cypher query {} with params {}", cypherQuery, params);
StatementResult result;
try (Session session = getSession()) {
result = params.isEmpty()
? getSession().run(cypherQuery) : getSession().run(cypherQuery, params);
}
return result;
}
public StatementResult execute(String cypherQuery) {
return execute(cypherQuery, null);
}
private Set<String> extractParams(String cypherQuery, Pattern pattern, String replaceChar) {
Matcher matcher = pattern.matcher(cypherQuery);
Set<String> keys = new HashSet<>();
while (matcher.find()) {
keys.add(matcher.group().replaceAll(replaceChar, StringUtils.EMPTY));
}
return keys;
}
}

View file

@ -18,7 +18,6 @@
package org.apache.zeppelin.graph.neo4j;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
@ -26,8 +25,6 @@ import java.util.Map;
import java.util.Map.Entry;
import java.util.Properties;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.lang.StringUtils;
import org.apache.zeppelin.graph.neo4j.utils.Neo4jConversionUtils;
@ -36,33 +33,25 @@ import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.interpreter.InterpreterResult.Code;
import org.apache.zeppelin.interpreter.graph.GraphResult;
import org.apache.zeppelin.resource.Resource;
import org.apache.zeppelin.resource.ResourcePool;
import org.apache.zeppelin.scheduler.Scheduler;
import org.apache.zeppelin.scheduler.SchedulerFactory;
import org.neo4j.driver.internal.types.InternalTypeSystem;
import org.neo4j.driver.internal.util.Iterables;
import org.neo4j.driver.v1.AuthToken;
import org.neo4j.driver.v1.AuthTokens;
import org.neo4j.driver.v1.Config;
import org.neo4j.driver.v1.Driver;
import org.neo4j.driver.v1.GraphDatabase;
import org.neo4j.driver.v1.Record;
import org.neo4j.driver.v1.Session;
import org.neo4j.driver.v1.StatementResult;
import org.neo4j.driver.v1.Value;
import org.neo4j.driver.v1.types.Node;
import org.neo4j.driver.v1.types.Relationship;
import org.neo4j.driver.v1.types.TypeSystem;
import org.neo4j.driver.v1.util.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Neo4j interpreter for Zeppelin.
*/
public class Neo4jCypherInterpreter extends Interpreter {
static final Logger LOGGER = LoggerFactory.getLogger(Neo4jCypherInterpreter.class);
public static final String NEO4J_SERVER_URL = "neo4j.url";
public static final String NEO4J_AUTH_TYPE = "neo4j.auth.type";
public static final String NEO4J_AUTH_USER = "neo4j.auth.user";
@ -73,12 +62,6 @@ public class Neo4jCypherInterpreter extends Interpreter {
public static final String NEW_LINE = "\n";
public static final String TAB = "\t";
private static final Pattern PROPERTY_PATTERN = Pattern.compile("\\{\\w+\\}");
private static final String REPLACE_CURLY_BRACKETS = "\\{|\\}";
private static final Pattern $_PATTERN = Pattern.compile("\\$\\w+\\}");
private static final String REPLACE_$ = "\\$";
private static final String MAP_KEY_TEMPLATE = "%s.%s";
private static final String ARAY_KEY_TEMPLATE = "%s[%d]";
@ -89,47 +72,46 @@ public class Neo4jCypherInterpreter extends Interpreter {
*/
public enum Neo4jAuthType {NONE, BASIC}
private Driver driver = null;
private Map<String, String> labels;
private Set<String> types;
private final Neo4jConnectionManager neo4jConnectionManager;
public Neo4jCypherInterpreter(Properties properties) {
super(properties);
}
private Driver getDriver() {
if (driver == null) {
Config config = Config.build()
.withMaxIdleSessions(Integer.parseInt(getProperty(NEO4J_MAX_CONCURRENCY)))
.toConfig();
String authType = getProperty(NEO4J_AUTH_TYPE);
AuthToken authToken = null;
switch (Neo4jAuthType.valueOf(authType.toUpperCase())) {
case BASIC:
authToken = AuthTokens.basic(getProperty(NEO4J_AUTH_USER),
getProperty(NEO4J_AUTH_PASSWORD));
break;
case NONE:
authToken = AuthTokens.none();
break;
default:
throw new RuntimeException("Neo4j authentication type not supported");
}
driver = GraphDatabase.driver(getProperty(NEO4J_SERVER_URL), authToken, config);
Config config = Config.build()
.withMaxIdleSessions(Integer.parseInt(getProperty(NEO4J_MAX_CONCURRENCY)))
.toConfig();
String authType = getProperty(NEO4J_AUTH_TYPE);
AuthToken authToken = null;
switch (Neo4jAuthType.valueOf(authType.toUpperCase())) {
case BASIC:
String username = getProperty(NEO4J_AUTH_USER);
String password = getProperty(NEO4J_AUTH_PASSWORD);
logger.debug("Creating a BASIC authentication to neo4j with user '{}' and password '{}'",
username, password);
authToken = AuthTokens.basic(username, password);
break;
case NONE:
logger.debug("Creating NONE authentication");
authToken = AuthTokens.none();
break;
default:
throw new RuntimeException("Neo4j authentication type not supported");
}
return driver;
this.neo4jConnectionManager = new Neo4jConnectionManager(
getProperty(NEO4J_SERVER_URL), authToken, config);
}
@Override
public void open() {
getDriver();
this.neo4jConnectionManager.open();
}
@Override
public void close() {
getDriver().close();
this.neo4jConnectionManager.close();
}
public Map<String, String> getLabels(boolean refresh) {
@ -137,19 +119,17 @@ public class Neo4jCypherInterpreter extends Interpreter {
Map<String, String> old = labels == null ?
new LinkedHashMap<String, String>() : new LinkedHashMap<>(labels);
labels = new LinkedHashMap<>();
try (Session session = getDriver().session()) {
StatementResult result = session.run("CALL db.labels()");
Set<String> colors = new HashSet<>();
while (result.hasNext()) {
Record record = result.next();
String label = record.get("label").asString();
String color = old.get(label);
while (color == null || colors.contains(color)) {
color = Neo4jConversionUtils.getRandomLabelColor();
}
colors.add(color);
labels.put(label, color);
StatementResult result = this.neo4jConnectionManager.execute("CALL db.labels()");
Set<String> colors = new HashSet<>();
while (result.hasNext()) {
Record record = result.next();
String label = record.get("label").asString();
String color = old.get(label);
while (color == null || colors.contains(color)) {
color = Neo4jConversionUtils.getRandomLabelColor();
}
colors.add(color);
labels.put(label, color);
}
}
return labels;
@ -158,12 +138,10 @@ public class Neo4jCypherInterpreter extends Interpreter {
private Set<String> getTypes(boolean refresh) {
if (types == null || refresh) {
types = new HashSet<>();
try (Session session = getDriver().session()) {
StatementResult result = session.run("CALL db.relationshipTypes()");
while (result.hasNext()) {
Record record = result.next();
types.add(record.get("relationshipType").asString());
}
StatementResult result = this.neo4jConnectionManager.execute("CALL db.relationshipTypes()");
while (result.hasNext()) {
Record record = result.next();
types.add(record.get("relationshipType").asString());
}
}
return types;
@ -175,8 +153,9 @@ public class Neo4jCypherInterpreter extends Interpreter {
if (StringUtils.isEmpty(cypherQuery)) {
return new InterpreterResult(Code.ERROR, "Cypher query is Empty");
}
try (Session session = getDriver().session()){
StatementResult result = execute(session, cypherQuery, interpreterContext);
try {
StatementResult result = this.neo4jConnectionManager.execute(cypherQuery,
interpreterContext);
Set<Node> nodes = new HashSet<>();
Set<Relationship> relationships = new HashSet<>();
List<String> columns = new ArrayList<>();
@ -186,24 +165,26 @@ public class Neo4jCypherInterpreter extends Interpreter {
List<Pair<String, Value>> fields = record.fields();
List<String> line = new ArrayList<>();
for (Pair<String, Value> field : fields) {
if (field.value().hasType(session.typeSystem().NODE())) {
if (field.value().hasType(InternalTypeSystem.TYPE_SYSTEM.NODE())) {
nodes.add(field.value().asNode());
} else if (field.value().hasType(session.typeSystem().RELATIONSHIP())) {
} else if (field.value().hasType(InternalTypeSystem.TYPE_SYSTEM.RELATIONSHIP())) {
relationships.add(field.value().asRelationship());
} else if (field.value().hasType(session.typeSystem().PATH())) {
} else if (field.value().hasType(InternalTypeSystem.TYPE_SYSTEM.PATH())) {
nodes.addAll(Iterables.asList(field.value().asPath().nodes()));
relationships.addAll(Iterables.asList(field.value().asPath().relationships()));
} else if (field.value().hasType(session.typeSystem().LIST())) {
} else if (field.value().hasType(InternalTypeSystem.TYPE_SYSTEM.LIST())) {
List<Object> list = field.value().asList();
for (Object elem : list) {
List<String> lineList = new ArrayList<>();
setTabularResult(field.key(), elem, columns, lineList, session.typeSystem());
setTabularResult(field.key(), elem, columns, lineList,
InternalTypeSystem.TYPE_SYSTEM);
if (!lineList.isEmpty()) {
lines.add(lineList);
}
}
} else {
setTabularResult(field.key(), field.value(), columns, line, session.typeSystem());
setTabularResult(field.key(), field.value(), columns, line,
InternalTypeSystem.TYPE_SYSTEM);
}
}
if (!line.isEmpty()) {
@ -253,7 +234,8 @@ public class Neo4jCypherInterpreter extends Interpreter {
List<Object> list = (List<Object>) obj;
for (int i = 0; i < list.size(); i++) {
Object elem = list.get(i);
setTabularResult(String.format(ARAY_KEY_TEMPLATE, key, i), elem, columns, line, typeSystem);
setTabularResult(String.format(ARAY_KEY_TEMPLATE, key, i), elem, columns,
line, typeSystem);
}
} else {
addLine(key, columns, line, obj);
@ -273,31 +255,6 @@ public class Neo4jCypherInterpreter extends Interpreter {
line.set(position, value == null ? null : value.toString());
}
private StatementResult execute(Session session, String cypherQuery,
InterpreterContext interpreterContext) {
Map<String, Object> params = new HashMap<>();
ResourcePool resourcePool = interpreterContext.getResourcePool();
Set<String> keys = extractParams(cypherQuery, PROPERTY_PATTERN, REPLACE_CURLY_BRACKETS);
keys.addAll(extractParams(cypherQuery, $_PATTERN, REPLACE_$));
for (String key : keys) {
Resource resource = resourcePool.get(key);
if (resource != null) {
params.put(key, resource.get());
}
}
logger.info("Executing cypher query {} with params {}", cypherQuery, params);
return params.isEmpty() ? session.run(cypherQuery) : session.run(cypherQuery, params);
}
private Set<String> extractParams(String cypherQuery, Pattern pattern, String replaceChar) {
Matcher matcher = pattern.matcher(cypherQuery);
Set<String> keys = new HashSet<>();
while (matcher.find()) {
keys.add(matcher.group().replaceAll(replaceChar, StringUtils.EMPTY));
}
return keys;
}
private InterpreterResult renderTable(List<String> cols, List<List<String>> lines) {
logger.info("Executing renderTable method");
StringBuilder msg = new StringBuilder(TABLE);