expose roles from ldap + fix substitution code

This commit is contained in:
Herval Freire 2017-07-19 11:23:43 -07:00
parent 87705a9324
commit ed1b0f8502
3 changed files with 212 additions and 35 deletions

View file

@ -33,6 +33,7 @@ import org.apache.shiro.crypto.hash.HashService;
import org.apache.shiro.realm.ldap.JndiLdapRealm;
import org.apache.shiro.realm.ldap.LdapContextFactory;
import org.apache.shiro.realm.ldap.LdapUtils;
import org.apache.shiro.session.Session;
import org.apache.shiro.subject.MutablePrincipalCollection;
import org.apache.shiro.subject.PrincipalCollection;
import org.apache.shiro.util.StringUtils;
@ -246,7 +247,7 @@ public class LdapRealm extends JndiLdapRealm {
* if any LDAP errors occur during the search.
*/
@Override
protected AuthorizationInfo queryForAuthorizationInfo(final PrincipalCollection principals,
public AuthorizationInfo queryForAuthorizationInfo(final PrincipalCollection principals,
final LdapContextFactory ldapContextFactory) throws NamingException {
if (!isAuthorizationEnabled()) {
return null;
@ -286,7 +287,8 @@ public class LdapRealm extends JndiLdapRealm {
LdapContext systemLdapCtx = null;
try {
systemLdapCtx = ldapContextFactory.getSystemLdapContext();
return rolesFor(principals, username, systemLdapCtx, ldapContextFactory);
return rolesFor(principals, username, systemLdapCtx,
ldapContextFactory, SecurityUtils.getSubject().getSession());
} catch (AuthenticationException ae) {
ae.printStackTrace();
return Collections.emptySet();
@ -295,9 +297,9 @@ public class LdapRealm extends JndiLdapRealm {
}
}
private Set<String> rolesFor(PrincipalCollection principals,
protected Set<String> rolesFor(PrincipalCollection principals,
String userNameIn, final LdapContext ldapCtx,
final LdapContextFactory ldapContextFactory) throws NamingException {
final LdapContextFactory ldapContextFactory, Session session) throws NamingException {
final Set<String> roleNames = new HashSet<>();
final Set<String> groupNames = new HashSet<>();
final String userName;
@ -308,14 +310,7 @@ public class LdapRealm extends JndiLdapRealm {
userName = userNameIn;
}
String userDn;
if (userSearchAttributeName == null || userSearchAttributeName.isEmpty()) {
// memberAttributeValuePrefix and memberAttributeValueSuffix
// were computed from memberAttributeValueTemplate
userDn = memberAttributeValuePrefix + userName + memberAttributeValueSuffix;
} else {
userDn = getUserDn(userName);
}
String userDn = getUserDnForSearch(userName);
// Activate paged results
int pageSize = getPagingSize();
@ -364,8 +359,7 @@ public class LdapRealm extends JndiLdapRealm {
// If group search filter is defined in Shiro config, then use it
if (groupSearchFilter != null) {
Matcher matchedPrincipal = matchPrincipal(userDn);
searchFilter = expandTemplate(groupSearchFilter, matchedPrincipal);
searchFilter = expandTemplate(groupSearchFilter, userName);
//searchFilter = String.format("%1$s", groupSearchFilter);
}
if (log.isDebugEnabled()) {
@ -402,8 +396,8 @@ public class LdapRealm extends JndiLdapRealm {
}
// save role names and group names in session so that they can be
// easily looked up outside of this object
SecurityUtils.getSubject().getSession().setAttribute(SUBJECT_USER_ROLES, roleNames);
SecurityUtils.getSubject().getSession().setAttribute(SUBJECT_USER_GROUPS, groupNames);
session.setAttribute(SUBJECT_USER_ROLES, roleNames);
session.setAttribute(SUBJECT_USER_GROUPS, groupNames);
if (!groupNames.isEmpty() && (principals instanceof MutablePrincipalCollection)) {
((MutablePrincipalCollection) principals).addAll(groupNames, getName());
}
@ -413,7 +407,17 @@ public class LdapRealm extends JndiLdapRealm {
return roleNames;
}
private void addRoleIfMember(final String userDn, final SearchResult group,
protected String getUserDnForSearch(String userName) {
if (userSearchAttributeName == null || userSearchAttributeName.isEmpty()) {
// memberAttributeValuePrefix and memberAttributeValueSuffix
// were computed from memberAttributeValueTemplate
return memberDn(userName);
} else {
return getUserDn(userName);
}
}
private void addRoleIfMember(final String userDn, final SearchResult group,
final Set<String> roleNames, final Set<String> groupNames,
final LdapContextFactory ldapContextFactory) throws NamingException {
@ -446,8 +450,9 @@ public class LdapRealm extends JndiLdapRealm {
}
}
} else {
// posix groups' members don' include the entire dn
if (groupObjectClass.equalsIgnoreCase(POSIX_GROUP)) {
attrValue = memberAttributeValuePrefix + attrValue + memberAttributeValueSuffix;
attrValue = memberDn(attrValue);
}
if (userLdapDn.equals(new LdapName(attrValue))) {
groupNames.add(groupName);
@ -474,7 +479,11 @@ public class LdapRealm extends JndiLdapRealm {
}
}
}
private String memberDn(String attrValue) {
return memberAttributeValuePrefix + attrValue + memberAttributeValueSuffix;
}
public Map<String, String> getListRoles() {
Map<String, String> groupToRoles = getRolesByGroup();
Map<String, String> roles = new HashMap<>();
@ -804,7 +813,7 @@ public class LdapRealm extends JndiLdapRealm {
return searchControls;
}
private SearchControls getGroupSearchControls() {
protected SearchControls getGroupSearchControls() {
SearchControls searchControls = SUBTREE_SCOPE;
if ("onelevel".equalsIgnoreCase(groupSearchScope)) {
searchControls = ONELEVEL_SCOPE;
@ -819,13 +828,13 @@ public class LdapRealm extends JndiLdapRealm {
userDnTemplate = template;
}
private Matcher matchPrincipal(final String principal) {
private String matchPrincipal(final String principal) {
Matcher matchedPrincipal = principalPattern.matcher(principal);
if (!matchedPrincipal.matches()) {
throw new IllegalArgumentException("Principal "
+ principal + " does not match " + principalRegex);
}
return matchedPrincipal;
return matchedPrincipal.group();
}
/**
@ -856,7 +865,7 @@ public class LdapRealm extends JndiLdapRealm {
protected String getUserDn(final String principal) throws IllegalArgumentException,
IllegalStateException {
String userDn;
Matcher matchedPrincipal = matchPrincipal(principal);
String matchedPrincipal = matchPrincipal(principal);
String userSearchBase = getUserSearchBase();
String userSearchAttributeName = getUserSearchAttributeName();
@ -938,16 +947,7 @@ public class LdapRealm extends JndiLdapRealm {
getName());
}
private static final String expandTemplate(final String template, final Matcher input) {
String output = template;
Matcher matcher = TEMPLATE_PATTERN.matcher(output);
while (matcher.find()) {
String lookupStr = matcher.group(1);
int lookupIndex = Integer.parseInt(lookupStr);
String lookupValue = input.group(lookupIndex);
output = matcher.replaceFirst(lookupValue == null ? "" : lookupValue);
matcher = TEMPLATE_PATTERN.matcher(output);
}
return output;
protected static final String expandTemplate(final String template, final String input) {
return template.replace(MEMBER_SUBSTITUTION_TOKEN, input);
}
}

View file

@ -26,10 +26,14 @@ import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import javax.naming.NamingException;
import org.apache.shiro.authz.AuthorizationInfo;
import org.apache.shiro.config.IniSecurityManagerFactory;
import org.apache.shiro.mgt.SecurityManager;
import org.apache.shiro.realm.Realm;
import org.apache.shiro.realm.text.IniRealm;
import org.apache.shiro.subject.SimplePrincipalCollection;
import org.apache.shiro.subject.Subject;
import org.apache.shiro.util.ThreadContext;
import org.apache.shiro.web.mgt.DefaultWebSecurityManager;
@ -129,7 +133,15 @@ public class SecurityUtils {
allRoles = ((IniRealm) realm).getIni().get("roles");
break;
} else if (name.equals("org.apache.zeppelin.realm.LdapRealm")) {
allRoles = ((LdapRealm) realm).getListRoles();
try {
AuthorizationInfo auth = ((LdapRealm) realm).queryForAuthorizationInfo(
new SimplePrincipalCollection(subject.getPrincipal(), realm.getName()),
((LdapRealm) realm).getContextFactory()
);
roles = new HashSet<>(auth.getRoles());
} catch (NamingException e) {
log.error("Can't fetch roles", e);
}
break;
} else if (name.equals("org.apache.zeppelin.realm.ActiveDirectoryGroupRealm")) {
allRoles = ((ActiveDirectoryGroupRealm) realm).getListRoles();

View file

@ -0,0 +1,165 @@
package org.apache.zeppelin.realm;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import javax.naming.NamingEnumeration;
import javax.naming.NamingException;
import javax.naming.directory.BasicAttributes;
import javax.naming.directory.SearchResult;
import javax.naming.ldap.LdapContext;
import org.apache.shiro.realm.ldap.LdapContextFactory;
import org.apache.shiro.session.Session;
import org.apache.shiro.subject.SimplePrincipalCollection;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.verify;
public class LdapRealmTest {
@Test
public void testGetUserDn() {
LdapRealm realm = new LdapRealm();
// without a user search filter
realm.setUserSearchFilter(null);
assertEquals(
"foo ",
realm.getUserDn("foo ")
);
// with a user search filter
realm.setUserSearchFilter("memberUid={0}");
assertEquals(
"foo",
realm.getUserDn("foo")
);
}
@Test
public void testExpandTemplate() {
assertEquals(
"uid=foo,cn=users,dc=ods,dc=foo",
LdapRealm.expandTemplate("uid={0},cn=users,dc=ods,dc=foo", "foo")
);
}
@Test
public void getUserDnForSearch() {
LdapRealm realm = new LdapRealm();
realm.setUserSearchAttributeName("uid");
assertEquals(
"foo",
realm.getUserDnForSearch("foo")
);
// using a template
realm.setUserSearchAttributeName(null);
realm.setMemberAttributeValueTemplate("cn={0},ou=people,dc=hadoop,dc=apache");
assertEquals(
"cn=foo,ou=people,dc=hadoop,dc=apache",
realm.getUserDnForSearch("foo")
);
}
@Test
public void testRolesFor() throws NamingException {
LdapRealm realm = new LdapRealm();
realm.setGroupSearchBase("cn=groups,dc=apache");
realm.setGroupObjectClass("posixGroup");
realm.setMemberAttributeValueTemplate("cn={0},ou=people,dc=apache");
HashMap<String, String> rolesByGroups = new HashMap<>();
rolesByGroups.put("group-three", "zeppelin-role");
realm.setRolesByGroup(rolesByGroups);
LdapContextFactory ldapContextFactory = mock(LdapContextFactory.class);
LdapContext ldapCtx = mock(LdapContext.class);
Session session = mock(Session.class);
// expected search results
BasicAttributes group1 = new BasicAttributes();
group1.put(realm.getGroupIdAttribute(), "group-one");
group1.put(realm.getMemberAttribute(), "principal");
// user doesn't belong to this group
BasicAttributes group2 = new BasicAttributes();
group2.put(realm.getGroupIdAttribute(), "group-two");
group2.put(realm.getMemberAttribute(), "someoneelse");
// mapped to a different Zeppelin role
BasicAttributes group3 = new BasicAttributes();
group3.put(realm.getGroupIdAttribute(), "group-three");
group3.put(realm.getMemberAttribute(), "principal");
NamingEnumeration<SearchResult> results = enumerationOf(group1, group2, group3);
when(ldapCtx.search(any(String.class), any(String.class), any())).thenReturn(results);
Set<String> roles = realm.rolesFor(
new SimplePrincipalCollection("principal", "ldapRealm"),
"principal",
ldapCtx,
ldapContextFactory,
session
);
verify(ldapCtx).search(
"cn=groups,dc=apache",
"(objectclass=posixGroup)",
realm.getGroupSearchControls()
);
assertEquals(
new HashSet(Arrays.asList("group-one", "zeppelin-role")),
roles
);
// group search matching enabled
// group search filter supplied
//
}
private NamingEnumeration<SearchResult> enumerationOf(BasicAttributes... attrs) {
Iterator<BasicAttributes> iterator = Arrays.asList(attrs).iterator();
return new NamingEnumeration<SearchResult>() {
@Override
public SearchResult next() throws NamingException {
return nextElement();
}
@Override
public boolean hasMore() throws NamingException {
return iterator.hasNext();
}
@Override
public void close() throws NamingException {
}
@Override
public boolean hasMoreElements() {
return iterator.hasNext();
}
@Override
public SearchResult nextElement() {
BasicAttributes attrs = iterator.next();
return new SearchResult(null, null, attrs);
}
};
}
}