package scim import ( "context" "fmt" "net/http" "net/url" "strconv" "strings" "github.com/elimity-com/scim" "github.com/elimity-com/scim/errors" "github.com/elimity-com/scim/optional" "github.com/fleetdm/fleet/v4/server/fleet" kitlog "github.com/go-kit/log" "github.com/go-kit/log/level" "github.com/scim2/filter-parser/v2" ) const ( // Group attributes: https://datatracker.ietf.org/doc/html/rfc7643#section-4.2 displayNameAttr = "displayName" membersAttr = "members" ) type GroupHandler struct { ds fleet.Datastore logger kitlog.Logger } // Compile-time check var _ scim.ResourceHandler = &GroupHandler{} func NewGroupHandler(ds fleet.Datastore, logger kitlog.Logger) scim.ResourceHandler { return &GroupHandler{ds: ds, logger: logger} } // Create creates a SCIM group func (g *GroupHandler) Create(r *http.Request, attributes scim.ResourceAttributes) (scim.Resource, error) { displayName, err := getRequiredResource[string](attributes, displayNameAttr) if err != nil { level.Error(g.logger).Log("msg", "failed to get displayName", "err", err) return scim.Resource{}, err } // Microsoft’s SCIM implementation (Entra ID) imposes additional constraints—like enforcing uniqueness on a group’s // displayName—that the SCIM spec itself does not mandate. // In effect, Microsoft’s implementation diverges from strict SCIM compliance by making displayName behave like a unique key. // SCIM only mandates that each group’s "id" is unique _, err = g.ds.ScimGroupByDisplayName(r.Context(), displayName) switch { case err != nil && !fleet.IsNotFound(err): level.Error(g.logger).Log("msg", "failed to check for displayName uniqueness", displayNameAttr, displayName, "err", err) return scim.Resource{}, err case err == nil: level.Info(g.logger).Log("msg", "group already exists", displayNameAttr, displayName) return scim.Resource{}, errors.ScimErrorUniqueness } group, err := createGroupFromAttributes(attributes) if err != nil { level.Error(g.logger).Log("msg", "failed to create group from attributes", displayNameAttr, displayName, "err", err) return scim.Resource{}, err } group.ID, err = g.ds.CreateScimGroup(r.Context(), group) if err != nil { return scim.Resource{}, err } return createGroupResource(group), nil } func createGroupFromAttributes(attributes scim.ResourceAttributes) (*fleet.ScimGroup, error) { group := fleet.ScimGroup{} var err error group.DisplayName, err = getRequiredResource[string](attributes, displayNameAttr) if err != nil { return nil, err } group.ExternalID, err = getOptionalResource[string](attributes, externalIdAttr) if err != nil { return nil, err } // Process members members, err := getComplexResourceSlice(attributes, membersAttr) if err != nil { return nil, err } userIDs := make([]uint, 0, len(members)) for _, member := range members { // Get the value attribute which contains the user ID valueIntf, ok := member["value"] if !ok || valueIntf == nil { continue } valueStr, ok := valueIntf.(string) if !ok { return nil, errors.ScimErrorBadParams([]string{"value"}) } // Extract user ID from the value userID, err := extractUserIDFromValue(valueStr) if err != nil { return nil, errors.ScimErrorBadParams([]string{"value"}) } userIDs = append(userIDs, userID) } group.ScimUsers = userIDs return &group, nil } // areMembersExcluded checks if the members attribute is excluded in the request func areMembersExcluded(r *http.Request) bool { excludedAttrs := r.URL.Query().Get("excludedAttributes") if excludedAttrs == "" { return false } // Split the excluded attributes by comma attrs := strings.Split(excludedAttrs, ",") for _, attr := range attrs { // Trim spaces and check if it's "members" if strings.TrimSpace(attr) == membersAttr { return true } } return false } // Get the Scim group by ID. The group id is of the format: group-123 // SCIM resource IDs must be unique across all resources. func (g *GroupHandler) Get(r *http.Request, id string) (scim.Resource, error) { idUint, err := extractGroupIDFromValue(id) if err != nil { level.Info(g.logger).Log("msg", "failed to parse id", "id", id, "err", err) return scim.Resource{}, errors.ScimErrorResourceNotFound(id) } group, err := g.ds.ScimGroupByID(r.Context(), idUint, areMembersExcluded(r)) switch { case fleet.IsNotFound(err): level.Info(g.logger).Log("msg", "failed to find group", "id", id) return scim.Resource{}, errors.ScimErrorResourceNotFound(id) case err != nil: level.Error(g.logger).Log("msg", "failed to get group", "id", id, "err", err) return scim.Resource{}, err } return createGroupResource(group), nil } func createGroupResource(group *fleet.ScimGroup) scim.Resource { groupResource := scim.Resource{} groupResource.ID = scimGroupID(group.ID) if group.ExternalID != nil { groupResource.ExternalID = optional.NewString(*group.ExternalID) } groupResource.Attributes = scim.ResourceAttributes{} groupResource.Attributes[displayNameAttr] = group.DisplayName // Add members if any if len(group.ScimUsers) > 0 { members := make([]scim.ResourceAttributes, 0, len(group.ScimUsers)) for _, userID := range group.ScimUsers { members = append(members, map[string]interface{}{ "value": scimUserID(userID), "type": "User", }) } groupResource.Attributes[membersAttr] = members } return groupResource } // GetAll // Pagination is 1-indexed on the startIndex. The startIndex is the index of the resource (not the index of the page), per RFC7644. func (g *GroupHandler) GetAll(r *http.Request, params scim.ListRequestParams) (scim.Page, error) { startIndex := params.StartIndex if startIndex < 1 { startIndex = 1 } count := params.Count if count > maxResults { return scim.Page{}, errors.ScimErrorTooMany } if count < 1 { count = maxResults } opts := fleet.ScimGroupsListOptions{ ScimListOptions: fleet.ScimListOptions{ StartIndex: uint(startIndex), // nolint:gosec // ignore G115 PerPage: uint(count), // nolint:gosec // ignore G115 }, ExcludeUsers: areMembersExcluded(r), } resourceFilter := r.URL.Query().Get("filter") if resourceFilter != "" { expr, err := filter.ParseAttrExp([]byte(resourceFilter)) if err != nil { level.Error(g.logger).Log("msg", "failed to parse filter", "filter", resourceFilter, "err", err) return scim.Page{}, errors.ScimErrorInvalidFilter } if !strings.EqualFold(expr.AttributePath.String(), "displayName") || expr.Operator != "eq" { level.Info(g.logger).Log("msg", "unsupported filter", "filter", resourceFilter) return scim.Page{}, nil } displayName, ok := expr.CompareValue.(string) if !ok { level.Error(g.logger).Log("msg", "unsupported value", "value", expr.CompareValue) return scim.Page{}, nil } // Decode URL-encoded characters displayName, err = url.QueryUnescape(displayName) if err != nil { level.Error(g.logger).Log("msg", "failed to decode displayName", "displayName", displayName, "err", err) return scim.Page{}, nil } opts.DisplayNameFilter = &displayName } groups, totalResults, err := g.ds.ListScimGroups(r.Context(), opts) if err != nil { level.Error(g.logger).Log("msg", "failed to list groups", "err", err) return scim.Page{}, err } result := scim.Page{ TotalResults: int(totalResults), // nolint:gosec // ignore G115 Resources: make([]scim.Resource, 0, len(groups)), } for i := range groups { result.Resources = append(result.Resources, createGroupResource(&groups[i])) } return result, nil } func (g *GroupHandler) Replace(r *http.Request, id string, attributes scim.ResourceAttributes) (scim.Resource, error) { idUint, err := extractGroupIDFromValue(id) if err != nil { level.Info(g.logger).Log("msg", "failed to parse id", "id", id, "err", err) return scim.Resource{}, errors.ScimErrorResourceNotFound(id) } group, err := createGroupFromAttributes(attributes) if err != nil { level.Error(g.logger).Log("msg", "failed to create group from attributes", "id", id, "err", err) return scim.Resource{}, err } group.ID = idUint // Display name is unique to comply with Entra ID requirements, // so we must check if another group already exists with that display name to return a clear error groupWithSameDisplayName, err := g.ds.ScimGroupByDisplayName(r.Context(), group.DisplayName) switch { case err != nil && !fleet.IsNotFound(err): level.Error(g.logger).Log("msg", "failed to check for displayName uniqueness", displayNameAttr, group.DisplayName, "err", err) return scim.Resource{}, err case err == nil && group.ID != groupWithSameDisplayName.ID: level.Info(g.logger).Log("msg", "group already exists with this displayName", displayNameAttr, group.DisplayName) return scim.Resource{}, errors.ScimErrorUniqueness // Otherwise, we assume that we are replacing the displayName with this operation. } err = g.ds.ReplaceScimGroup(r.Context(), group) switch { case fleet.IsNotFound(err): level.Info(g.logger).Log("msg", "failed to find group to replace", "id", id) return scim.Resource{}, errors.ScimErrorResourceNotFound(id) case err != nil: level.Error(g.logger).Log("msg", "failed to replace group", "id", id, "err", err) return scim.Resource{}, err } return createGroupResource(group), nil } func (g *GroupHandler) Delete(r *http.Request, id string) error { idUint, err := extractGroupIDFromValue(id) if err != nil { level.Info(g.logger).Log("msg", "failed to parse id", "id", id, "err", err) return errors.ScimErrorResourceNotFound(id) } err = g.ds.DeleteScimGroup(r.Context(), idUint) switch { case fleet.IsNotFound(err): level.Info(g.logger).Log("msg", "failed to find group to delete", "id", id) return errors.ScimErrorResourceNotFound(id) case err != nil: level.Error(g.logger).Log("msg", "failed to delete group", "id", id, "err", err) return err } return nil } // Patch // Supporting add/replace/remove operations for "displayName", "externalId", and "members" attributes. func (g *GroupHandler) Patch(r *http.Request, id string, operations []scim.PatchOperation) (scim.Resource, error) { idUint, err := extractGroupIDFromValue(id) if err != nil { level.Info(g.logger).Log("msg", "failed to parse id", "id", id, "err", err) return scim.Resource{}, errors.ScimErrorResourceNotFound(id) } group, err := g.ds.ScimGroupByID(r.Context(), idUint, false) switch { case fleet.IsNotFound(err): level.Info(g.logger).Log("msg", "failed to find group to patch", "id", id) return scim.Resource{}, errors.ScimErrorResourceNotFound(id) case err != nil: level.Error(g.logger).Log("msg", "failed to get group to patch", "id", id, "err", err) return scim.Resource{}, err } for _, op := range operations { if op.Op != scim.PatchOperationAdd && op.Op != scim.PatchOperationReplace && op.Op != scim.PatchOperationRemove { level.Info(g.logger).Log("msg", "unsupported patch operation", "op", op.Op) return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)}) } switch { case op.Path == nil: if op.Op == scim.PatchOperationRemove { level.Info(g.logger).Log("msg", "the 'path' attribute is REQUIRED for 'remove' operations", "op", op.Op) return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)}) } newValues, ok := op.Value.(map[string]interface{}) if !ok { level.Info(g.logger).Log("msg", "unsupported patch value", "value", op.Value) return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)}) } for k, v := range newValues { switch k { case externalIdAttr: err = g.patchExternalId(op.Op, v, group) if err != nil { return scim.Resource{}, err } case displayNameAttr: err = g.patchDisplayName(op.Op, v, group) if err != nil { return scim.Resource{}, err } case membersAttr: err = g.patchMembers(r.Context(), op.Op, v, group) if err != nil { return scim.Resource{}, err } default: level.Info(g.logger).Log("msg", "unsupported patch value field", "field", k) return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)}) } } case op.Path.String() == externalIdAttr: err = g.patchExternalId(op.Op, op.Value, group) if err != nil { return scim.Resource{}, err } case op.Path.String() == displayNameAttr: err = g.patchDisplayName(op.Op, op.Value, group) if err != nil { return scim.Resource{}, err } case op.Path.String() == membersAttr: err = g.patchMembers(r.Context(), op.Op, op.Value, group) if err != nil { return scim.Resource{}, err } case op.Path.AttributePath.String() == membersAttr: err = g.patchMembersWithPathFiltering(r.Context(), op, group) if err != nil { return scim.Resource{}, err } default: level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path) return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)}) } } if len(operations) != 0 { err = g.ds.ReplaceScimGroup(r.Context(), group) switch { case fleet.IsNotFound(err): level.Info(g.logger).Log("msg", "failed to find group to patch", "id", id) return scim.Resource{}, errors.ScimErrorResourceNotFound(id) case err != nil: level.Error(g.logger).Log("msg", "failed to patch group", "id", id, "err", err) return scim.Resource{}, err } } return createGroupResource(group), nil } func (g *GroupHandler) patchExternalId(op string, v interface{}, group *fleet.ScimGroup) error { if op == scim.PatchOperationRemove || v == nil { group.ExternalID = nil return nil } externalId, ok := v.(string) if !ok { level.Info(g.logger).Log("msg", fmt.Sprintf("unsupported '%s' value", externalIdAttr), "value", v) return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)}) } group.ExternalID = &externalId return nil } func (g *GroupHandler) patchDisplayName(op string, v interface{}, group *fleet.ScimGroup) error { if op == scim.PatchOperationRemove { level.Info(g.logger).Log("msg", "cannot remove required attribute", "attribute", displayNameAttr) return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)}) } displayName, ok := v.(string) if !ok { level.Info(g.logger).Log("msg", fmt.Sprintf("unsupported '%s' value", displayNameAttr), "value", v) return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)}) } if displayName == "" { level.Info(g.logger).Log("msg", fmt.Sprintf("'%s' cannot be empty", displayNameAttr), "value", v) return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)}) } group.DisplayName = displayName return nil } // patchMembers handles add/replace/remove operations for the members attribute func (g *GroupHandler) patchMembers(ctx context.Context, op string, v interface{}, group *fleet.ScimGroup) error { if op == scim.PatchOperationRemove { // Remove all members group.ScimUsers = []uint{} return nil } // For add and replace operations, we need to extract the member IDs var membersList []interface{} // Handle different value formats switch val := v.(type) { case []interface{}: // Direct array of members membersList = val case map[string]interface{}: // Single member as a map membersList = []interface{}{val} case []map[string]interface{}: // Array of member maps for _, m := range val { membersList = append(membersList, m) } default: level.Info(g.logger).Log("msg", "unsupported members value format", "value", v) return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)}) } // Process the members userIDs := make([]uint, 0, len(membersList)) valueStrings := make([]string, 0, len(membersList)) for _, memberIntf := range membersList { member, ok := memberIntf.(map[string]interface{}) if !ok { level.Info(g.logger).Log("msg", "member must be an object", "member", memberIntf) return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", memberIntf)}) } // Get the value attribute which contains the user ID valueIntf, ok := member["value"] if !ok || valueIntf == nil { level.Info(g.logger).Log("msg", "member missing value attribute", "member", member) return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", member)}) } valueStr, ok := valueIntf.(string) if !ok { level.Info(g.logger).Log("msg", "member value must be a string", "value", valueIntf) return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", valueIntf)}) } valueStrings = append(valueStrings, valueStr) // Extract user ID from the value userID, err := extractUserIDFromValue(valueStr) if err != nil { level.Info(g.logger).Log("msg", "invalid user ID format", "value", valueStr, "err", err) return errors.ScimErrorBadParams([]string{valueStr}) } userIDs = append(userIDs, userID) } // Verify all users exist in a single database call if len(userIDs) > 0 { allExist, err := g.ds.ScimUsersExist(ctx, userIDs) if err != nil { level.Error(g.logger).Log("msg", "error checking users existence", "err", err) return err } if !allExist { level.Info(g.logger).Log("msg", "one or more users not found", "userIDs", userIDs) return errors.ScimErrorBadParams(valueStrings) } } // For add operation, append to existing members if op == scim.PatchOperationAdd { // Create a map to track existing user IDs to avoid duplicates existingUsers := make(map[uint]bool) for _, id := range group.ScimUsers { existingUsers[id] = true } // Add new users that don't already exist in the group for _, id := range userIDs { if !existingUsers[id] { group.ScimUsers = append(group.ScimUsers, id) existingUsers[id] = true } } } else { // For replace operation, replace all members group.ScimUsers = userIDs // FIXME: List should be deduplicated by us? See https://github.com/fleetdm/fleet/issues/30086 } return nil } // patchMembersWithPathFiltering handles patch operations with path filtering for members // This supports paths like members[value eq "422"] for add/replace/remove operations func (g *GroupHandler) patchMembersWithPathFiltering(ctx context.Context, op scim.PatchOperation, group *fleet.ScimGroup) error { memberID, err := g.getMemberID(op) if err != nil { return err } // Check if the member exists in the group memberFound := false var memberIndex int for i, id := range group.ScimUsers { if id == memberID { memberIndex = i memberFound = true break } } // For remove operations, remove the member if found if op.Op == scim.PatchOperationRemove { if !memberFound { level.Info(g.logger).Log("msg", "member not found in group", "member_id", memberID, "op", fmt.Sprintf("%v", op)) // The member may have been removed already from this group. For example, if the member was deleted. return nil } group.ScimUsers = append(group.ScimUsers[:memberIndex], group.ScimUsers[memberIndex+1:]...) return nil } // For add operations, add the member if not found if op.Op == scim.PatchOperationAdd && !memberFound { // Verify the user exists userExists, err := g.ds.ScimUsersExist(ctx, []uint{memberID}) if err != nil { level.Error(g.logger).Log("msg", "error checking user existence", "err", err) return err } if !userExists { level.Info(g.logger).Log("msg", "user not found", "user_id", memberID) return errors.ScimErrorBadParams([]string{scimUserID(memberID)}) } group.ScimUsers = append(group.ScimUsers, memberID) return nil } // For replace operations with a value if op.Op == scim.PatchOperationReplace { if !memberFound { level.Info(g.logger).Log("msg", "member not found for replace operation", "members.value", memberID, "op", fmt.Sprintf("%v", op)) return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)}) } // If the value is nil or an empty object, remove the member if op.Value == nil { group.ScimUsers = append(group.ScimUsers[:memberIndex], group.ScimUsers[memberIndex+1:]...) return nil } // Otherwise, we don't change anything since we're already filtering by the member ID // and there are no other attributes to modify for a member return nil } return nil } // getMemberID extracts the member ID from a path expression like members[value eq "422"] func (g *GroupHandler) getMemberID(op scim.PatchOperation) (uint, error) { attrExpression, ok := op.Path.ValueExpression.(*filter.AttributeExpression) if !ok { level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path) return 0, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)}) } // Only matching by member value (user ID) is supported if attrExpression.AttributePath.String() != valueAttr || attrExpression.Operator != filter.EQ { level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path, "expression", attrExpression.AttributePath.String()) return 0, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)}) } memberIDStr, ok := attrExpression.CompareValue.(string) if !ok { level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path, "compare_value", attrExpression.CompareValue) return 0, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)}) } // Extract user ID from the value userID, err := extractUserIDFromValue(memberIDStr) if err != nil { level.Info(g.logger).Log("msg", "invalid user ID format", "value", memberIDStr, "err", err) return 0, errors.ScimErrorBadParams([]string{memberIDStr}) } return userID, nil } func scimGroupID(groupID uint) string { return fmt.Sprintf("group-%d", groupID) } // extractGroupIDFromValue extracts the group ID from a value like "group-123" func extractGroupIDFromValue(value string) (uint, error) { if !strings.HasPrefix(value, "group-") { return 0, fmt.Errorf("value %q does not match the expected format 'group-'", value) } idStr := strings.TrimPrefix(value, "group-") id, err := strconv.ParseUint(idStr, 10, 64) if err != nil { return 0, fmt.Errorf("failed to parse group ID from value %q: %w", value, err) } return uint(id), nil }