fleet/ee/server/scim/groups.go
Victor Lyuboslavsky 8658608c37
Add SCIM Groups (#27702)
For #27287

This PR adds SCIM Groups to Fleet's SCIM endpoint as a follow-on to SCIM
Users. The logic has been manually tested with Okta; integration
tests will follow in the next PR.

# Checklist for submitter
- [x] Added/updated automated tests
- [x] Manual QA for all new/changed functionality
2025-04-02 17:10:40 -05:00

327 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package scim
import (
"fmt"
"net/http"
"strconv"
"strings"
"github.com/elimity-com/scim"
"github.com/elimity-com/scim/errors"
"github.com/elimity-com/scim/optional"
"github.com/fleetdm/fleet/v4/server/fleet"
kitlog "github.com/go-kit/log"
"github.com/go-kit/log/level"
)
// SCIM Group attribute names.
// See https://datatracker.ietf.org/doc/html/rfc7643#section-4.2
const (
	// displayNameAttr is the attribute key for the group's display name.
	displayNameAttr = "displayName"
	// membersAttr is the attribute key for the group's member list.
	membersAttr = "members"
)
// GroupHandler serves SCIM Group resources. It holds the dependencies shared
// by all of its handler methods.
type GroupHandler struct {
	// ds is the datastore the handler methods use to read and write SCIM groups.
	ds fleet.Datastore
	// logger receives the handler's informational and error log entries.
	logger kitlog.Logger
}
// Compile-time assertion that *GroupHandler satisfies scim.ResourceHandler.
var _ scim.ResourceHandler = (*GroupHandler)(nil)
// NewGroupHandler returns a scim.ResourceHandler for SCIM groups that uses ds
// for storage and logger for diagnostics.
func NewGroupHandler(ds fleet.Datastore, logger kitlog.Logger) scim.ResourceHandler {
	handler := GroupHandler{
		ds:     ds,
		logger: logger,
	}
	return &handler
}
// Create creates a SCIM group from the given attributes and returns the
// stored group as a SCIM resource.
//
// The displayName attribute is required. If a group with the same displayName
// already exists, errors.ScimErrorUniqueness is returned. Any other datastore
// or attribute-parsing error is logged and returned as is.
func (g *GroupHandler) Create(r *http.Request, attributes scim.ResourceAttributes) (scim.Resource, error) {
	displayName, err := getRequiredResource[string](attributes, displayNameAttr)
	if err != nil {
		level.Error(g.logger).Log("msg", "failed to get displayName", "err", err)
		return scim.Resource{}, err
	}

	// Microsoft's SCIM implementation (Entra ID) imposes additional constraints, like enforcing uniqueness on a
	// group's displayName, that the SCIM spec itself does not mandate.
	// In effect, Microsoft's implementation diverges from strict SCIM compliance by making displayName behave like
	// a unique key. SCIM only mandates that each group's "id" is unique.
	//
	// NOTE(review): the lookup below and the insert further down are separate datastore calls, so two concurrent
	// creates with the same displayName could both pass this check unless the datastore also enforces
	// uniqueness. Confirm against the datastore implementation.
	_, err = g.ds.ScimGroupByDisplayName(r.Context(), displayName)
	switch {
	case err != nil && !fleet.IsNotFound(err):
		level.Error(g.logger).Log("msg", "failed to check for displayName uniqueness", displayNameAttr, displayName, "err", err)
		return scim.Resource{}, err
	case err == nil:
		// The lookup succeeded, so the name is already taken.
		level.Info(g.logger).Log("msg", "group already exists", displayNameAttr, displayName)
		return scim.Resource{}, errors.ScimErrorUniqueness
	}

	group, err := createGroupFromAttributes(attributes)
	if err != nil {
		level.Error(g.logger).Log("msg", "failed to create group from attributes", displayNameAttr, displayName, "err", err)
		return scim.Resource{}, err
	}

	group.ID, err = g.ds.CreateScimGroup(r.Context(), group)
	if err != nil {
		// Log the datastore failure like every other failure path in this handler.
		level.Error(g.logger).Log("msg", "failed to create group", displayNameAttr, displayName, "err", err)
		return scim.Resource{}, err
	}
	return createGroupResource(group), nil
}
// createGroupFromAttributes converts SCIM resource attributes into a
// fleet.ScimGroup.
//
// displayName is required and externalId is optional. Each entry of the
// members attribute contributes the user ID parsed from its "value" field.
// Entries whose "value" is missing or nil are skipped. A "value" that is not a
// string, or that cannot be parsed as a user ID, yields a bad-params error.
// The returned group has no ID set.
func createGroupFromAttributes(attributes scim.ResourceAttributes) (*fleet.ScimGroup, error) {
	displayName, err := getRequiredResource[string](attributes, displayNameAttr)
	if err != nil {
		return nil, err
	}
	externalID, err := getOptionalResource[string](attributes, externalIdAttr)
	if err != nil {
		return nil, err
	}
	rawMembers, err := getComplexResourceSlice(attributes, membersAttr)
	if err != nil {
		return nil, err
	}

	memberIDs := make([]uint, 0, len(rawMembers))
	for _, rawMember := range rawMembers {
		// The "value" sub-attribute carries the SCIM ID of the member user.
		rawValue, present := rawMember["value"]
		if !present || rawValue == nil {
			continue
		}
		userRef, isString := rawValue.(string)
		if !isString {
			return nil, errors.ScimErrorBadParams([]string{"value"})
		}
		memberID, parseErr := extractUserIDFromValue(userRef)
		if parseErr != nil {
			return nil, errors.ScimErrorBadParams([]string{"value"})
		}
		memberIDs = append(memberIDs, memberID)
	}

	return &fleet.ScimGroup{
		DisplayName: displayName,
		ExternalID:  externalID,
		ScimUsers:   memberIDs,
	}, nil
}
// Get returns the SCIM group identified by id, which must have the form
// "group-123" (SCIM resource IDs must be unique across all resources).
//
// A malformed id or an unknown group yields a SCIM resource-not-found error.
// Any other datastore error is logged and returned as is.
func (g *GroupHandler) Get(r *http.Request, id string) (scim.Resource, error) {
	groupID, parseErr := extractGroupIDFromValue(id)
	if parseErr != nil {
		level.Info(g.logger).Log("msg", "failed to parse id", "id", id, "err", parseErr)
		return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
	}

	found, err := g.ds.ScimGroupByID(r.Context(), groupID)
	if fleet.IsNotFound(err) {
		level.Info(g.logger).Log("msg", "failed to find group", "id", id)
		return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
	}
	if err != nil {
		level.Error(g.logger).Log("msg", "failed to get group", "id", id, "err", err)
		return scim.Resource{}, err
	}
	return createGroupResource(found), nil
}
// createGroupResource converts a fleet.ScimGroup into its SCIM resource
// representation.
//
// The resource always carries the group ID and displayName. externalId is set
// only when the group has one, and the members attribute is set only when the
// group has at least one user. Each member is rendered with "value", "$ref"
// and "type" fields.
func createGroupResource(group *fleet.ScimGroup) scim.Resource {
	attrs := scim.ResourceAttributes{displayNameAttr: group.DisplayName}
	resource := scim.Resource{
		ID:         scimGroupID(group.ID),
		Attributes: attrs,
	}
	if group.ExternalID != nil {
		resource.ExternalID = optional.NewString(*group.ExternalID)
	}

	// Leave the members attribute out entirely for groups without users.
	if len(group.ScimUsers) == 0 {
		return resource
	}

	memberList := make([]scim.ResourceAttributes, len(group.ScimUsers))
	for i, userID := range group.ScimUsers {
		userRef := scimUserID(userID)
		memberList[i] = scim.ResourceAttributes{
			"value": userRef,
			"$ref":  "Users/" + userRef,
			"type":  "User",
		}
	}
	// attrs is the same map as resource.Attributes, so this updates the resource.
	attrs[membersAttr] = memberList
	return resource
}
// GetAll lists SCIM groups.
//
// A requested count above maxResults is rejected with errors.ScimErrorTooMany.
// A count below 1 falls back to maxResults, and a start index below 1 is
// treated as 1. Filters are not supported: when the request carries a "filter"
// query parameter, an empty page is returned without error.
//
// NOTE(review): params.StartIndex is forwarded as the datastore page number.
// SCIM defines startIndex as the 1-based index of the first result, not a page
// number. Confirm the datastore interprets Page accordingly.
func (g *GroupHandler) GetAll(r *http.Request, params scim.ListRequestParams) (scim.Page, error) {
	if params.Count > maxResults {
		return scim.Page{}, errors.ScimErrorTooMany
	}
	perPage := params.Count
	if perPage < 1 {
		perPage = maxResults
	}
	pageNum := params.StartIndex
	if pageNum < 1 {
		pageNum = 1
	}

	if filter := r.URL.Query().Get("filter"); filter != "" {
		level.Info(g.logger).Log("msg", "group filter not supported", "filter", filter)
		return scim.Page{}, nil
	}

	listOpts := fleet.ScimListOptions{
		Page:    uint(pageNum), // nolint:gosec // ignore G115
		PerPage: uint(perPage), // nolint:gosec // ignore G115
	}
	groups, total, err := g.ds.ListScimGroups(r.Context(), listOpts)
	if err != nil {
		level.Error(g.logger).Log("msg", "failed to list groups", "err", err)
		return scim.Page{}, err
	}

	resources := make([]scim.Resource, len(groups))
	for i := range groups {
		resources[i] = createGroupResource(&groups[i])
	}
	return scim.Page{
		TotalResults: int(total), // nolint:gosec // ignore G115
		Resources:    resources,
	}, nil
}
// Replace overwrites the SCIM group identified by id (format "group-123")
// with a group built from attributes, and returns the replacement as a SCIM
// resource.
//
// A malformed id or an unknown group yields a SCIM resource-not-found error.
// Attribute-parsing errors and other datastore errors are logged and returned
// as is.
func (g *GroupHandler) Replace(r *http.Request, id string, attributes scim.ResourceAttributes) (scim.Resource, error) {
	groupID, parseErr := extractGroupIDFromValue(id)
	if parseErr != nil {
		level.Info(g.logger).Log("msg", "failed to parse id", "id", id, "err", parseErr)
		return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
	}

	replacement, buildErr := createGroupFromAttributes(attributes)
	if buildErr != nil {
		level.Error(g.logger).Log("msg", "failed to create group from attributes", "id", id, "err", buildErr)
		return scim.Resource{}, buildErr
	}
	// The ID comes from the request path, not from the attributes.
	replacement.ID = groupID

	err := g.ds.ReplaceScimGroup(r.Context(), replacement)
	if fleet.IsNotFound(err) {
		level.Info(g.logger).Log("msg", "failed to find group to replace", "id", id)
		return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
	}
	if err != nil {
		level.Error(g.logger).Log("msg", "failed to replace group", "id", id, "err", err)
		return scim.Resource{}, err
	}
	return createGroupResource(replacement), nil
}
// Delete removes the SCIM group identified by id (format "group-123").
//
// A malformed id or an unknown group yields a SCIM resource-not-found error.
// Any other datastore error is logged and returned as is.
func (g *GroupHandler) Delete(r *http.Request, id string) error {
	groupID, parseErr := extractGroupIDFromValue(id)
	if parseErr != nil {
		level.Info(g.logger).Log("msg", "failed to parse id", "id", id, "err", parseErr)
		return errors.ScimErrorResourceNotFound(id)
	}

	err := g.ds.DeleteScimGroup(r.Context(), groupID)
	if fleet.IsNotFound(err) {
		level.Info(g.logger).Log("msg", "failed to find group to delete", "id", id)
		return errors.ScimErrorResourceNotFound(id)
	}
	if err != nil {
		level.Error(g.logger).Log("msg", "failed to delete group", "id", id, "err", err)
		return err
	}
	return nil
}
// Patch applies SCIM PATCH operations to the group identified by id (format
// "group-123") and returns the updated group as a SCIM resource.
//
// Only replacing the "displayName" attribute is supported. An operation must
// be a "replace" and must take one of two forms:
//   - no path, with a value object containing exactly one entry, displayName
//   - path "displayName", with a string value
//
// Any other operation is rejected with a bad-params error and nothing is
// written. A malformed id or an unknown group yields a SCIM resource-not-found
// error. Other datastore errors are logged and returned as is.
//
// Note: Okta does not use PATCH endpoint to update groups (2025/04/01)
func (g *GroupHandler) Patch(r *http.Request, id string, operations []scim.PatchOperation) (scim.Resource, error) {
	idUint, err := extractGroupIDFromValue(id)
	if err != nil {
		level.Info(g.logger).Log("msg", "failed to parse id", "id", id, "err", err)
		return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
	}
	group, err := g.ds.ScimGroupByID(r.Context(), idUint)
	switch {
	case fleet.IsNotFound(err):
		level.Info(g.logger).Log("msg", "failed to find group to patch", "id", id)
		return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
	case err != nil:
		level.Error(g.logger).Log("msg", "failed to get group to patch", "id", id, "err", err)
		return scim.Resource{}, err
	}

	// Apply every operation to the in-memory group first. The datastore is only
	// written once all operations have been validated.
	for _, op := range operations {
		if op.Op != "replace" {
			level.Info(g.logger).Log("msg", "unsupported patch operation", "op", op.Op)
			return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
		}
		switch {
		case op.Path == nil:
			// No path: the value must be an object holding only displayName.
			newValues, ok := op.Value.(map[string]interface{})
			if !ok {
				level.Info(g.logger).Log("msg", "unsupported patch value", "value", op.Value)
				return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
			}
			if len(newValues) != 1 {
				// This also covers an empty object, not only one with extra attributes.
				level.Info(g.logger).Log("msg", "expected exactly one patch value", "value", op.Value)
				return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
			}
			displayName, err := getRequiredResource[string](newValues, displayNameAttr)
			if err != nil {
				level.Info(g.logger).Log("msg", "failed to get displayName value", "value", op.Value, "err", err)
				return scim.Resource{}, err
			}
			group.DisplayName = displayName
		case op.Path.String() == displayNameAttr:
			// Explicit path: the value is the new displayName itself.
			displayName, ok := op.Value.(string)
			if !ok {
				// Bad client input, logged at Info like the other rejections above.
				level.Info(g.logger).Log("msg", "unsupported 'displayName' patch value", "value", op.Value)
				return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
			}
			group.DisplayName = displayName
		default:
			level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path)
			return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
		}
	}

	err = g.ds.ReplaceScimGroup(r.Context(), group)
	switch {
	case fleet.IsNotFound(err):
		level.Info(g.logger).Log("msg", "failed to find group to patch", "id", id)
		return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
	case err != nil:
		level.Error(g.logger).Log("msg", "failed to patch group", "id", id, "err", err)
		return scim.Resource{}, err
	}
	return createGroupResource(group), nil
}
// scimGroupID renders a numeric group ID as its SCIM resource ID, for example
// 123 becomes "group-123".
func scimGroupID(groupID uint) string {
	const prefix = "group-"
	return prefix + fmt.Sprint(groupID)
}
// extractGroupIDFromValue extracts the numeric group ID from a SCIM group ID
// such as "group-123".
//
// It returns an error when value does not start with "group-", or when the
// remainder is not a base-10 unsigned integer that fits in a uint.
func extractGroupIDFromValue(value string) (uint, error) {
	const prefix = "group-"
	if !strings.HasPrefix(value, prefix) {
		return 0, fmt.Errorf("value %q does not match the expected format 'group-<id>'", value)
	}
	// Parse with the platform's uint width so that an out-of-range ID is
	// rejected here. Parsing as 64 bits and then converting would silently
	// truncate on 32-bit builds, making one ID alias another group.
	id, err := strconv.ParseUint(value[len(prefix):], 10, strconv.IntSize)
	if err != nil {
		return 0, fmt.Errorf("failed to parse group ID from value %q: %w", value, err)
	}
	return uint(id), nil
}