SCIM Entra ID support (#28832)

For #28196

This PR adds full patching for SCIM Users and Groups, and adds the
ability to filter Groups by displayName.

The changes have been tested with [Entra ID SCIM
Validator](67dfd91c0c/docs/Contributing/SCIM-integration.md (entra-id-integration))
and Okta SCIM 2.0 SPEC Test (to make sure we didn't break Okta).

# Checklist for submitter
- [x] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.
- [x] Added/updated automated tests
- [x] A detailed QA plan exists on the associated ticket (if it isn't
there, work with the product group's QA engineer to add it)
- [x] Manual QA for all new/changed functionality
This commit is contained in:
Victor Lyuboslavsky 2025-05-08 13:02:49 -05:00 committed by GitHub
parent 89c0386572
commit 6f9030ee3c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 2522 additions and 215 deletions

View file

@ -0,0 +1 @@
Added ability to sync end user's IdP information with Microsoft Entra ID using SCIM protocol.

View file

@ -51,12 +51,20 @@ Run test using [Runscope](https://www.runscope.com/). See [instructions](https:/
### Testing Entra ID integration
Use [scimvalidator.microsoft.com](https://scimvalidator.microsoft.com/). Only test the attributes that we have implemented. To see our supported attributes, check the schema:
Use [scimvalidator.microsoft.com](https://scimvalidator.microsoft.com/). Only test the attributes that we have implemented.
![SCIM-Entra-ID-Validator-User-attributes.png](assets/SCIM-Entra-ID-Validator-User-attributes.png)
![SCIM-Entra-ID-Validator-Group-attributes.png](assets/SCIM-Entra-ID-Validator-Group-attributes.png)
To see our supported attributes, check the schema:
```
GET https://localhost:8080/api/latest/fleet/scim/Schemas
```
Results (2025/05/06)
![SCIM-Entra-ID-Validator-results.png](assets/SCIM-Entra-ID-Validator-results.png)
## Authentication
We use same authentication as API. HTTP header: `Authorization: Bearer xyz`

Binary file not shown.

After

Width:  |  Height:  |  Size: 28 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 56 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 257 KiB

File diff suppressed because it is too large Load diff

View file

@ -1,8 +1,10 @@
package scim
import (
"context"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
@ -12,6 +14,7 @@ import (
"github.com/fleetdm/fleet/v4/server/fleet"
kitlog "github.com/go-kit/log"
"github.com/go-kit/log/level"
"github.com/scim2/filter-parser/v2"
)
const (
@ -145,7 +148,6 @@ func createGroupResource(group *fleet.ScimGroup) scim.Resource {
for _, userID := range group.ScimUsers {
members = append(members, map[string]interface{}{
"value": scimUserID(userID),
"$ref": "Users/" + scimUserID(userID),
"type": "User",
})
}
@ -170,15 +172,37 @@ func (g *GroupHandler) GetAll(r *http.Request, params scim.ListRequestParams) (s
count = maxResults
}
opts := fleet.ScimListOptions{
StartIndex: uint(startIndex), // nolint:gosec // ignore G115
PerPage: uint(count), // nolint:gosec // ignore G115
opts := fleet.ScimGroupsListOptions{
ScimListOptions: fleet.ScimListOptions{
StartIndex: uint(startIndex), // nolint:gosec // ignore G115
PerPage: uint(count), // nolint:gosec // ignore G115
},
}
resourceFilter := r.URL.Query().Get("filter")
if resourceFilter != "" {
level.Info(g.logger).Log("msg", "group filter not supported", "filter", resourceFilter)
return scim.Page{}, nil
expr, err := filter.ParseAttrExp([]byte(resourceFilter))
if err != nil {
level.Error(g.logger).Log("msg", "failed to parse filter", "filter", resourceFilter, "err", err)
return scim.Page{}, errors.ScimErrorInvalidFilter
}
if !strings.EqualFold(expr.AttributePath.String(), "displayName") || expr.Operator != "eq" {
level.Info(g.logger).Log("msg", "unsupported filter", "filter", resourceFilter)
return scim.Page{}, nil
}
displayName, ok := expr.CompareValue.(string)
if !ok {
level.Error(g.logger).Log("msg", "unsupported value", "value", expr.CompareValue)
return scim.Page{}, nil
}
// Decode URL-encoded characters
displayName, err = url.QueryUnescape(displayName)
if err != nil {
level.Error(g.logger).Log("msg", "failed to decode displayName", "displayName", displayName, "err", err)
return scim.Page{}, nil
}
opts.DisplayNameFilter = &displayName
}
groups, totalResults, err := g.ds.ListScimGroups(r.Context(), opts)
@ -256,8 +280,7 @@ func (g *GroupHandler) Delete(r *http.Request, id string) error {
}
// Patch
// Only supporting replacing the "displayName" attribute.
// Note: Okta does not use PATCH endpoint to update groups (2025/04/01)
// Supporting add/replace/remove operations for "displayName", "externalId", and "members" attributes.
func (g *GroupHandler) Patch(r *http.Request, id string, operations []scim.PatchOperation) (scim.Resource, error) {
idUint, err := extractGroupIDFromValue(id)
if err != nil {
@ -275,53 +298,312 @@ func (g *GroupHandler) Patch(r *http.Request, id string, operations []scim.Patch
}
for _, op := range operations {
if op.Op != "replace" {
if op.Op != scim.PatchOperationAdd && op.Op != scim.PatchOperationReplace && op.Op != scim.PatchOperationRemove {
level.Info(g.logger).Log("msg", "unsupported patch operation", "op", op.Op)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
switch {
case op.Path == nil:
if op.Op == scim.PatchOperationRemove {
level.Info(g.logger).Log("msg", "the 'path' attribute is REQUIRED for 'remove' operations", "op", op.Op)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
newValues, ok := op.Value.(map[string]interface{})
if !ok {
level.Info(g.logger).Log("msg", "unsupported patch value", "value", op.Value)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
if len(newValues) != 1 {
level.Info(g.logger).Log("msg", "too many patch values", "value", op.Value)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
for k, v := range newValues {
switch k {
case externalIdAttr:
err = g.patchExternalId(op.Op, v, group)
if err != nil {
return scim.Resource{}, err
}
case displayNameAttr:
err = g.patchDisplayName(op.Op, v, group)
if err != nil {
return scim.Resource{}, err
}
case membersAttr:
err = g.patchMembers(r.Context(), op.Op, v, group)
if err != nil {
return scim.Resource{}, err
}
default:
level.Info(g.logger).Log("msg", "unsupported patch value field", "field", k)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
}
displayName, err := getRequiredResource[string](newValues, displayNameAttr)
case op.Path.String() == externalIdAttr:
err = g.patchExternalId(op.Op, op.Value, group)
if err != nil {
level.Info(g.logger).Log("msg", "failed to get active value", "value", op.Value)
return scim.Resource{}, err
}
group.DisplayName = displayName
case op.Path.String() == displayNameAttr:
displayName, ok := op.Value.(string)
if !ok {
level.Error(g.logger).Log("msg", "unsupported 'displayName' patch value", "value", op.Value)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
err = g.patchDisplayName(op.Op, op.Value, group)
if err != nil {
return scim.Resource{}, err
}
case op.Path.String() == membersAttr:
err = g.patchMembers(r.Context(), op.Op, op.Value, group)
if err != nil {
return scim.Resource{}, err
}
case op.Path.AttributePath.String() == membersAttr:
err = g.patchMembersWithPathFiltering(r.Context(), op, group)
if err != nil {
return scim.Resource{}, err
}
group.DisplayName = displayName
default:
level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
}
err = g.ds.ReplaceScimGroup(r.Context(), group)
switch {
case fleet.IsNotFound(err):
level.Info(g.logger).Log("msg", "failed to find group to patch", "id", id)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
case err != nil:
level.Error(g.logger).Log("msg", "failed to patch group", "id", id, "err", err)
return scim.Resource{}, err
if len(operations) != 0 {
err = g.ds.ReplaceScimGroup(r.Context(), group)
switch {
case fleet.IsNotFound(err):
level.Info(g.logger).Log("msg", "failed to find group to patch", "id", id)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
case err != nil:
level.Error(g.logger).Log("msg", "failed to patch group", "id", id, "err", err)
return scim.Resource{}, err
}
}
return createGroupResource(group), nil
}
func (g *GroupHandler) patchExternalId(op string, v interface{}, group *fleet.ScimGroup) error {
if op == scim.PatchOperationRemove || v == nil {
group.ExternalID = nil
return nil
}
externalId, ok := v.(string)
if !ok {
level.Info(g.logger).Log("msg", fmt.Sprintf("unsupported '%s' value", externalIdAttr), "value", v)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
}
group.ExternalID = &externalId
return nil
}
func (g *GroupHandler) patchDisplayName(op string, v interface{}, group *fleet.ScimGroup) error {
if op == scim.PatchOperationRemove {
level.Info(g.logger).Log("msg", "cannot remove required attribute", "attribute", displayNameAttr)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
displayName, ok := v.(string)
if !ok {
level.Info(g.logger).Log("msg", fmt.Sprintf("unsupported '%s' value", displayNameAttr), "value", v)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
}
if displayName == "" {
level.Info(g.logger).Log("msg", fmt.Sprintf("'%s' cannot be empty", displayNameAttr), "value", v)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
}
group.DisplayName = displayName
return nil
}
// patchMembers handles add/replace/remove operations for the members attribute
func (g *GroupHandler) patchMembers(ctx context.Context, op string, v interface{}, group *fleet.ScimGroup) error {
if op == scim.PatchOperationRemove {
// Remove all members
group.ScimUsers = []uint{}
return nil
}
// For add and replace operations, we need to extract the member IDs
var membersList []interface{}
// Handle different value formats
switch val := v.(type) {
case []interface{}:
// Direct array of members
membersList = val
case map[string]interface{}:
// Single member as a map
membersList = []interface{}{val}
case []map[string]interface{}:
// Array of member maps
for _, m := range val {
membersList = append(membersList, m)
}
default:
level.Info(g.logger).Log("msg", "unsupported members value format", "value", v)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
}
// Process the members
userIDs := make([]uint, 0, len(membersList))
valueStrings := make([]string, 0, len(membersList))
for _, memberIntf := range membersList {
member, ok := memberIntf.(map[string]interface{})
if !ok {
level.Info(g.logger).Log("msg", "member must be an object", "member", memberIntf)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", memberIntf)})
}
// Get the value attribute which contains the user ID
valueIntf, ok := member["value"]
if !ok || valueIntf == nil {
level.Info(g.logger).Log("msg", "member missing value attribute", "member", member)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", member)})
}
valueStr, ok := valueIntf.(string)
if !ok {
level.Info(g.logger).Log("msg", "member value must be a string", "value", valueIntf)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", valueIntf)})
}
valueStrings = append(valueStrings, valueStr)
// Extract user ID from the value
userID, err := extractUserIDFromValue(valueStr)
if err != nil {
level.Info(g.logger).Log("msg", "invalid user ID format", "value", valueStr, "err", err)
return errors.ScimErrorBadParams([]string{valueStr})
}
userIDs = append(userIDs, userID)
}
// Verify all users exist in a single database call
if len(userIDs) > 0 {
allExist, err := g.ds.ScimUsersExist(ctx, userIDs)
if err != nil {
level.Error(g.logger).Log("msg", "error checking users existence", "err", err)
return err
}
if !allExist {
level.Info(g.logger).Log("msg", "one or more users not found", "userIDs", userIDs)
return errors.ScimErrorBadParams(valueStrings)
}
}
// For add operation, append to existing members
if op == scim.PatchOperationAdd {
// Create a map to track existing user IDs to avoid duplicates
existingUsers := make(map[uint]bool)
for _, id := range group.ScimUsers {
existingUsers[id] = true
}
// Add new users that don't already exist in the group
for _, id := range userIDs {
if !existingUsers[id] {
group.ScimUsers = append(group.ScimUsers, id)
existingUsers[id] = true
}
}
} else {
// For replace operation, replace all members
group.ScimUsers = userIDs
}
return nil
}
// patchMembersWithPathFiltering handles patch operations with path filtering for members
// This supports paths like members[value eq "422"] for add/replace/remove operations
func (g *GroupHandler) patchMembersWithPathFiltering(ctx context.Context, op scim.PatchOperation, group *fleet.ScimGroup) error {
memberID, err := g.getMemberID(op)
if err != nil {
return err
}
// Check if the member exists in the group
memberFound := false
var memberIndex int
for i, id := range group.ScimUsers {
if id == memberID {
memberIndex = i
memberFound = true
break
}
}
// For remove operations, remove the member if found
if op.Op == scim.PatchOperationRemove {
if !memberFound {
level.Info(g.logger).Log("msg", "member not found", "member_id", memberID, "op", fmt.Sprintf("%v", op))
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
group.ScimUsers = append(group.ScimUsers[:memberIndex], group.ScimUsers[memberIndex+1:]...)
return nil
}
// For add operations, add the member if not found
if op.Op == scim.PatchOperationAdd && !memberFound {
// Verify the user exists
userExists, err := g.ds.ScimUsersExist(ctx, []uint{memberID})
if err != nil {
level.Error(g.logger).Log("msg", "error checking user existence", "err", err)
return err
}
if !userExists {
level.Info(g.logger).Log("msg", "user not found", "user_id", memberID)
return errors.ScimErrorBadParams([]string{scimUserID(memberID)})
}
group.ScimUsers = append(group.ScimUsers, memberID)
return nil
}
// For replace operations with a value
if op.Op == scim.PatchOperationReplace {
if !memberFound {
level.Info(g.logger).Log("msg", "member not found for replace operation", "members.value", memberID, "op", fmt.Sprintf("%v", op))
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
// If the value is nil or an empty object, remove the member
if op.Value == nil {
group.ScimUsers = append(group.ScimUsers[:memberIndex], group.ScimUsers[memberIndex+1:]...)
return nil
}
// Otherwise, we don't change anything since we're already filtering by the member ID
// and there are no other attributes to modify for a member
return nil
}
return nil
}
// getMemberID extracts the member ID from a path expression like members[value eq "422"]
func (g *GroupHandler) getMemberID(op scim.PatchOperation) (uint, error) {
attrExpression, ok := op.Path.ValueExpression.(*filter.AttributeExpression)
if !ok {
level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path)
return 0, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
// Only matching by member value (user ID) is supported
if attrExpression.AttributePath.String() != valueAttr || attrExpression.Operator != filter.EQ {
level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path, "expression", attrExpression.AttributePath.String())
return 0, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
memberIDStr, ok := attrExpression.CompareValue.(string)
if !ok {
level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path, "compare_value", attrExpression.CompareValue)
return 0, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
// Extract user ID from the value
userID, err := extractUserIDFromValue(memberIDStr)
if err != nil {
level.Info(g.logger).Log("msg", "invalid user ID format", "value", memberIDStr, "err", err)
return 0, errors.ScimErrorBadParams([]string{memberIDStr})
}
return userID, nil
}
func scimGroupID(groupID uint) string {
return fmt.Sprintf("group-%d", groupID)
}

View file

@ -33,6 +33,8 @@ func RegisterSCIM(
config := scim.ServiceProviderConfig{
DocumentationURI: optional.NewString("https://fleetdm.com/docs/get-started/why-fleet"),
MaxResults: maxResults,
SupportFiltering: true,
SupportPatch: true,
}
// The common attributes are id, externalId, and meta.
@ -136,18 +138,14 @@ func RegisterSCIM(
Mutability: schema.AttributeMutabilityImmutable(),
Name: "value",
}),
schema.SimpleReferenceParams(schema.ReferenceParams{
Description: optional.NewString("The URI corresponding to a SCIM resource that is a member of this Group."),
Mutability: schema.AttributeMutabilityImmutable(),
Name: "$ref",
ReferenceTypes: []schema.AttributeReferenceType{"User"},
}),
schema.SimpleStringParams(schema.StringParams{
CanonicalValues: []string{"User"},
Description: optional.NewString("A label indicating the type of resource, e.g., 'User' or 'Group'."),
Mutability: schema.AttributeMutabilityImmutable(),
Name: "type",
}),
// Note (2025/05/06): Microsoft does not properly support $ref attribute on group members
// https://learn.microsoft.com/en-us/answers/questions/1457148/scim-validator-patch-group-remove-member-test-comp
},
}),
},

View file

@ -406,9 +406,7 @@ func (u *UserHandler) Delete(r *http.Request, id string) error {
return nil
}
// Patch
// Okta only requires patching the "active" attribute:
// https://developer.okta.com/docs/api/openapi/okta-scim/guides/scim-20/#update-a-specific-user-patch
// Patch - https://datatracker.ietf.org/doc/html/rfc7644#section-3.5.2
func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchOperation) (scim.Resource, error) {
idUint, err := extractUserIDFromValue(id)
if err != nil {
@ -426,12 +424,17 @@ func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchO
}
for _, op := range operations {
if op.Op != "replace" {
if op.Op != scim.PatchOperationAdd && op.Op != scim.PatchOperationReplace && op.Op != scim.PatchOperationRemove {
level.Info(u.logger).Log("msg", "unsupported patch operation", "op", op.Op)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
switch {
// If path is not specified, we look for the path in the value attribute.
case op.Path == nil:
if op.Op == scim.PatchOperationRemove {
level.Info(u.logger).Log("msg", "the 'path' attribute is REQUIRED for 'remove' operations", "op", op.Op)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
newValues, ok := op.Value.(map[string]interface{})
if !ok {
level.Info(u.logger).Log("msg", "unsupported patch value", "value", op.Value)
@ -439,27 +442,28 @@ func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchO
}
for k, v := range newValues {
switch k {
case externalIdAttr:
err = u.patchExternalId(op.Op, v, user)
if err != nil {
return scim.Resource{}, err
}
case userNameAttr:
err = u.patchUserName(v, user)
err = u.patchUserName(op.Op, v, user)
if err != nil {
return scim.Resource{}, err
}
case activeAttr:
if v == nil {
user.Active = nil
continue
}
err = u.patchActive(v, user)
err = u.patchActive(op.Op, v, user)
if err != nil {
return scim.Resource{}, err
}
case nameAttr + "." + givenNameAttr:
err = u.patchGivenName(v, user)
err = u.patchGivenName(op.Op, v, user)
if err != nil {
return scim.Resource{}, err
}
case nameAttr + "." + familyNameAttr:
err = u.patchFamilyName(v, user)
err = u.patchFamilyName(op.Op, v, user)
if err != nil {
return scim.Resource{}, err
}
@ -478,27 +482,28 @@ func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchO
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
}
case op.Path.String() == externalIdAttr:
err = u.patchExternalId(op.Op, op.Value, user)
if err != nil {
return scim.Resource{}, err
}
case op.Path.String() == userNameAttr:
err = u.patchUserName(op.Value, user)
err = u.patchUserName(op.Op, op.Value, user)
if err != nil {
return scim.Resource{}, err
}
case op.Path.String() == activeAttr:
if op.Value == nil {
user.Active = nil
continue
}
err = u.patchActive(op.Value, user)
err = u.patchActive(op.Op, op.Value, user)
if err != nil {
return scim.Resource{}, err
}
case op.Path.String() == nameAttr+"."+givenNameAttr:
err = u.patchGivenName(op.Value, user)
err = u.patchGivenName(op.Op, op.Value, user)
if err != nil {
return scim.Resource{}, err
}
case op.Path.String() == nameAttr+"."+familyNameAttr:
err = u.patchFamilyName(op.Value, user)
err = u.patchFamilyName(op.Op, op.Value, user)
if err != nil {
return scim.Resource{}, err
}
@ -513,84 +518,10 @@ func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchO
return scim.Resource{}, err
}
case op.Path.AttributePath.String() == emailsAttr:
emailType, err := u.getEmailType(op)
err = u.patchEmailsWithPathFiltering(op, user)
if err != nil {
return scim.Resource{}, err
}
emailFound := false
var emailIndex int
for i, email := range user.Emails {
if email.Type != nil && *email.Type == emailType {
emailIndex = i
emailFound = true
break
}
}
if !emailFound {
level.Info(u.logger).Log("msg", "email not found", "email_type", emailType)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
if op.Path.SubAttribute == nil {
// The value for emails comes in as an array.
userEmails, ok := op.Value.([]interface{})
if !ok {
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", emailsAttr), "value", op.Value)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
if len(userEmails) == 0 {
user.Emails = slices.Delete(user.Emails, emailIndex, emailIndex)
continue
}
if len(userEmails) != 1 {
level.Info(u.logger).Log("msg", "only 1 email should be present for replacement", "emails", userEmails)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
userEmail, err := u.extractEmail(userEmails[0], op)
if err != nil {
return scim.Resource{}, err
}
// If setting primary to true, then unset true from other emails
if userEmail.Primary != nil && *userEmail.Primary {
clearPrimaryFlagFromEmails(user)
}
user.Emails[emailIndex] = userEmail
continue
}
switch *op.Path.SubAttribute {
case primaryAttr:
if op.Value == nil {
user.Emails[emailIndex].Primary = nil
continue
}
primary, err := getConcreteType[bool](u, op.Value, primaryAttr)
if err != nil {
return scim.Resource{}, err
}
// If setting primary to true, then unset true from other emails
if primary {
clearPrimaryFlagFromEmails(user)
}
user.Emails[emailIndex].Primary = &primary
case valueAttr:
value, err := getConcreteType[string](u, op.Value, valueAttr)
if err != nil {
return scim.Resource{}, err
}
user.Emails[emailIndex].Email = value
case typeAttr:
if op.Value == nil {
user.Emails[emailIndex].Type = nil
continue
}
newEmailType, err := getConcreteType[string](u, op.Value, typeAttr)
if err != nil {
return scim.Resource{}, err
}
user.Emails[emailIndex].Type = &newEmailType
default:
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
default:
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
@ -612,6 +543,146 @@ func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchO
return createUserResource(user), nil
}
func (u *UserHandler) patchEmailsWithPathFiltering(op scim.PatchOperation, user *fleet.ScimUser) error {
emailType, err := u.getEmailType(op)
if err != nil {
return err
}
emailFound := false
var emailIndex int
for i, email := range user.Emails {
if email.Type != nil && *email.Type == emailType {
emailIndex = i
emailFound = true
break
}
}
if !emailFound && op.Op != scim.PatchOperationAdd {
level.Info(u.logger).Log("msg", "email not found", "email_type", emailType, "op", fmt.Sprintf("%v", op))
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
if op.Path.SubAttribute == nil {
if op.Op == scim.PatchOperationRemove {
user.Emails = slices.Delete(user.Emails, emailIndex, emailIndex+1)
return nil
}
// For add and replace operations, we need to extract the emails
var emailsList []interface{}
// Handle different value formats
switch val := op.Value.(type) {
case []interface{}:
// Direct array of members
emailsList = val
case map[string]interface{}:
// Single member as a map
emailsList = []interface{}{val}
default:
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", emailsAttr), "value", op.Value)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
switch op.Op {
case scim.PatchOperationReplace:
if len(emailsList) == 0 {
user.Emails = slices.Delete(user.Emails, emailIndex, emailIndex+1)
return nil
}
if len(emailsList) != 1 {
level.Info(u.logger).Log("msg", "only 1 email should be present for replacement", "emails", emailsList)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
userEmail, err := u.extractEmail(emailsList[0], op)
if err != nil {
return err
}
// If setting primary to true, then unset true from other emails
if userEmail.Primary != nil && *userEmail.Primary {
clearPrimaryFlagFromEmails(user)
}
user.Emails[emailIndex] = userEmail
case scim.PatchOperationAdd:
if len(emailsList) == 0 {
level.Info(u.logger).Log("msg", "no emails provided to add", "emails", emailsList)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
var newEmails []fleet.ScimUserEmail
for e := range emailsList {
userEmail, err := u.extractEmail(emailsList[e], op)
if err != nil {
return err
}
userEmail.Type = &emailType
newEmails = append(newEmails, userEmail)
}
primaryExists, err := u.checkEmailPrimary(newEmails)
if err != nil {
return err
}
if primaryExists {
clearPrimaryFlagFromEmails(user)
}
user.Emails = append(user.Emails, newEmails...)
}
return nil
}
if op.Op == scim.PatchOperationAdd && !emailFound {
user.Emails = append(user.Emails, fleet.ScimUserEmail{
Type: ptr.String(emailType),
})
emailIndex = len(user.Emails) - 1
}
switch *op.Path.SubAttribute {
case primaryAttr:
if op.Op == scim.PatchOperationRemove {
user.Emails[emailIndex].Primary = nil
return nil
}
if op.Value == nil {
user.Emails[emailIndex].Primary = nil
return nil
}
primary, err := getConcreteType[bool](u, op.Value, primaryAttr)
if err != nil {
return err
}
// If setting primary to true, then unset true from other emails
if primary {
clearPrimaryFlagFromEmails(user)
}
user.Emails[emailIndex].Primary = &primary
case valueAttr:
if op.Op == scim.PatchOperationRemove {
// The operation of removing an email value doesn't make sense, but we allow it.
user.Emails[emailIndex].Email = ""
return nil
}
value, err := getConcreteType[string](u, op.Value, valueAttr)
if err != nil {
return err
}
user.Emails[emailIndex].Email = value
case typeAttr:
if op.Op == scim.PatchOperationRemove {
user.Emails[emailIndex].Type = nil
return nil
}
if op.Value == nil {
user.Emails[emailIndex].Type = nil
return nil
}
newEmailType, err := getConcreteType[string](u, op.Value, typeAttr)
if err != nil {
return err
}
user.Emails[emailIndex].Type = &newEmailType
default:
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
return nil
}
func (u *UserHandler) getEmailType(op scim.PatchOperation) (string, error) {
attrExpression, ok := op.Path.ValueExpression.(*filter.AttributeExpression)
if !ok {
@ -641,7 +712,11 @@ func getConcreteType[T string | bool](u *UserHandler, v interface{}, name string
return concreteType, nil
}
func (u *UserHandler) patchFamilyName(v interface{}, user *fleet.ScimUser) error {
func (u *UserHandler) patchFamilyName(op string, v interface{}, user *fleet.ScimUser) error {
if op == scim.PatchOperationRemove {
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", nameAttr+"."+familyNameAttr)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
familyName, err := getConcreteType[string](u, v, nameAttr+"."+familyNameAttr)
if err != nil {
return err
@ -650,7 +725,11 @@ func (u *UserHandler) patchFamilyName(v interface{}, user *fleet.ScimUser) error
return nil
}
func (u *UserHandler) patchGivenName(v interface{}, user *fleet.ScimUser) error {
func (u *UserHandler) patchGivenName(op string, v interface{}, user *fleet.ScimUser) error {
if op == scim.PatchOperationRemove {
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", nameAttr+"."+givenNameAttr)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
givenName, err := getConcreteType[string](u, v, nameAttr+"."+givenNameAttr)
if err != nil {
return err
@ -659,7 +738,11 @@ func (u *UserHandler) patchGivenName(v interface{}, user *fleet.ScimUser) error
return nil
}
func (u *UserHandler) patchActive(v interface{}, user *fleet.ScimUser) error {
func (u *UserHandler) patchActive(op string, v interface{}, user *fleet.ScimUser) error {
if op == scim.PatchOperationRemove || v == nil {
user.Active = nil
return nil
}
active, err := getConcreteType[bool](u, v, activeAttr)
if err != nil {
return err
@ -668,7 +751,24 @@ func (u *UserHandler) patchActive(v interface{}, user *fleet.ScimUser) error {
return nil
}
func (u *UserHandler) patchUserName(v interface{}, user *fleet.ScimUser) error {
func (u *UserHandler) patchExternalId(op string, v interface{}, user *fleet.ScimUser) error {
if op == scim.PatchOperationRemove || v == nil {
user.ExternalID = nil
return nil
}
externalId, err := getConcreteType[string](u, v, externalIdAttr)
if err != nil {
return err
}
user.ExternalID = ptr.String(externalId)
return nil
}
func (u *UserHandler) patchUserName(op string, v interface{}, user *fleet.ScimUser) error {
if op == scim.PatchOperationRemove {
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", userNameAttr)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
userName, err := getConcreteType[string](u, v, userNameAttr)
if err != nil {
return err
@ -690,43 +790,69 @@ func clearPrimaryFlagFromEmails(user *fleet.ScimUser) {
}
func (u *UserHandler) patchEmails(v interface{}, op scim.PatchOperation, user *fleet.ScimUser) error {
emailsValue, ok := v.([]interface{})
if !ok {
if op.Op == scim.PatchOperationRemove {
user.Emails = nil
return nil
}
// For add and replace operations, we need to extract the emails
var emailsList []interface{}
// Handle different value formats
switch val := v.(type) {
case []interface{}:
// Direct array of members
emailsList = val
case map[string]interface{}:
// Single member as a map
emailsList = []interface{}{val}
default:
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", emailsAttr), "value", op.Value)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
if op.Op == scim.PatchOperationAdd && len(emailsList) == 0 {
level.Info(u.logger).Log("msg", "no emails provided to add", "emails", emailsList)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
// Convert the emails to the expected format
userEmails := make([]fleet.ScimUserEmail, 0, len(emailsValue))
for _, emailIntf := range emailsValue {
userEmails := make([]fleet.ScimUserEmail, 0, len(emailsList))
for _, emailIntf := range emailsList {
userEmail, err := u.extractEmail(emailIntf, op)
if err != nil {
return err
}
userEmails = append(userEmails, userEmail)
}
err := u.checkEmailPrimary(userEmails)
primaryExists, err := u.checkEmailPrimary(userEmails)
if err != nil {
return err
}
if op.Op == scim.PatchOperationAdd {
if primaryExists {
// Clear the primary flag from current emails because we are merging the two email lists and a new email has that flag.
clearPrimaryFlagFromEmails(user)
}
userEmails = append(user.Emails, userEmails...)
}
user.Emails = userEmails
return nil
}
// checkEmailPrimary ensures at most one email is marked as primary
func (u *UserHandler) checkEmailPrimary(userEmails []fleet.ScimUserEmail) error {
func (u *UserHandler) checkEmailPrimary(userEmails []fleet.ScimUserEmail) (bool, error) {
primaryEmailCount := 0
for _, email := range userEmails {
if email.Primary != nil && *email.Primary {
primaryEmailCount++
if primaryEmailCount > 1 {
level.Info(u.logger).Log("msg", "multiple primary emails found")
return errors.ScimErrorBadParams([]string{"Only one email can be marked as primary"})
return false, errors.ScimErrorBadParams([]string{"Only one email can be marked as primary"})
}
}
}
return nil
return primaryEmailCount > 0, nil
}
func (u *UserHandler) extractEmail(emailIntf interface{}, op scim.PatchOperation) (fleet.ScimUserEmail, error) {
@ -770,6 +896,10 @@ func (u *UserHandler) extractEmail(emailIntf interface{}, op scim.PatchOperation
}
func (u *UserHandler) patchName(v interface{}, op scim.PatchOperation, user *fleet.ScimUser) error {
if op.Op == scim.PatchOperationRemove {
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", nameAttr)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
name, ok := v.(map[string]interface{})
if !ok {
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", nameAttr), "value", op.Value)

View file

@ -222,6 +222,24 @@ func (ds *Datastore) ReplaceScimUser(ctx context.Context, user *fleet.ScimUser)
return err
}
// Validate that at most one email is marked as primary
primaryCount := 0
for _, email := range user.Emails {
if email.Primary != nil && *email.Primary {
primaryCount++
}
}
if primaryCount > 1 {
return ctxerr.New(ctx, "only one email can be marked as primary")
}
// Get current emails and check if they need to be updated
currentEmails, err := ds.getScimUserEmails(ctx, user.ID)
if err != nil {
return err
}
emailsNeedUpdate := emailsRequireUpdate(currentEmails, user.Emails)
return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
// load the username before updating the user, to check if it changed
var oldUsername string
@ -265,24 +283,23 @@ func (ds *Datastore) ReplaceScimUser(ctx context.Context, user *fleet.ScimUser)
}
usernameChanged := oldUsername != user.UserName
// We assume that email is not blank/null.
// However, we do not assume that email/type are unique for a user. To keep the code simple, we:
// 1. Delete all existing emails
// 2. Insert all new emails
// This is less efficient and can be optimized if we notice a load on these tables in production.
// Only update emails if they've changed
if emailsNeedUpdate {
// We assume that email is not blank/null.
// However, we do not assume that email/type are unique for a user. To keep the code simple, we:
// 1. Delete all existing emails
// 2. Insert all new emails
// This is less efficient and can be optimized if we notice a load on these tables in production.
// TODO: Check if emails need to be updated at all. If so, update the user updated_at timestamp if emails have been updated
// TODO: Check that only 1 email is primary
const deleteEmailsQuery = `DELETE FROM scim_user_emails WHERE scim_user_id = ?`
_, err = tx.ExecContext(ctx, deleteEmailsQuery, user.ID)
if err != nil {
return ctxerr.Wrap(ctx, err, "delete scim user emails")
}
err = insertEmails(ctx, tx, user)
if err != nil {
return err
const deleteEmailsQuery = `DELETE FROM scim_user_emails WHERE scim_user_id = ?`
_, err = tx.ExecContext(ctx, deleteEmailsQuery, user.ID)
if err != nil {
return ctxerr.Wrap(ctx, err, "delete scim user emails")
}
err = insertEmails(ctx, tx, user)
if err != nil {
return err
}
}
// Get the user's groups
@ -867,7 +884,7 @@ func (ds *Datastore) DeleteScimGroup(ctx context.Context, id uint) error {
}
// ListScimGroups retrieves a list of SCIM groups with pagination
func (ds *Datastore) ListScimGroups(ctx context.Context, opts fleet.ScimListOptions) (groups []fleet.ScimGroup, totalResults uint, err error) {
func (ds *Datastore) ListScimGroups(ctx context.Context, opts fleet.ScimGroupsListOptions) (groups []fleet.ScimGroup, totalResults uint, err error) {
// Default pagination values if not provided
if opts.StartIndex == 0 {
opts.StartIndex = 1
@ -883,16 +900,25 @@ func (ds *Datastore) ListScimGroups(ctx context.Context, opts fleet.ScimListOpti
FROM scim_groups
`
// Add where clause based on filters
var whereClause string
var params []interface{}
if opts.DisplayNameFilter != nil {
whereClause = " WHERE scim_groups.display_name = ?"
params = append(params, *opts.DisplayNameFilter)
}
// First, get the total count without pagination
countQuery := "SELECT COUNT(DISTINCT id) FROM (" + baseQuery + ") AS filtered_groups"
err = sqlx.GetContext(ctx, ds.reader(ctx), &totalResults, countQuery)
countQuery := "SELECT COUNT(DISTINCT id) FROM (" + baseQuery + whereClause + ") AS filtered_groups"
err = sqlx.GetContext(ctx, ds.reader(ctx), &totalResults, countQuery, params...)
if err != nil {
return nil, 0, ctxerr.Wrap(ctx, err, "count total scim groups")
}
// Add pagination to the main query
query := baseQuery + " ORDER BY scim_groups.id LIMIT ? OFFSET ?"
params := []interface{}{opts.PerPage, opts.StartIndex - 1}
query := baseQuery + whereClause + " ORDER BY scim_groups.id LIMIT ? OFFSET ?"
params = append(params, opts.PerPage, opts.StartIndex-1)
// Execute the query
err = sqlx.SelectContext(ctx, ds.reader(ctx), &groups, query, params...)
@ -1208,3 +1234,76 @@ func triggerResendProfilesUsingVariables(ctx context.Context, tx sqlx.ExtContext
}
return nil
}
// emailsRequireUpdate compares two slices of emails and returns true if they are different
// and require an update in the database.
func emailsRequireUpdate(currentEmails, newEmails []fleet.ScimUserEmail) bool {
if len(currentEmails) != len(newEmails) {
return true
}
// Create maps for efficient comparison
currentEmailMap := make(map[string]fleet.ScimUserEmail)
for i := range currentEmails {
key := currentEmails[i].GenerateComparisonKey()
currentEmailMap[key] = currentEmails[i]
}
// Check if all new emails exist in current emails with the same attributes
for i := range newEmails {
key := newEmails[i].GenerateComparisonKey()
if _, exists := currentEmailMap[key]; !exists {
return true
}
}
return false
}
// ScimUsersExist checks if all the provided SCIM user IDs exist in the datastore
// If the slice is empty, it returns true
// This method processes IDs in batches to handle large numbers of IDs efficiently
func (ds *Datastore) ScimUsersExist(ctx context.Context, ids []uint) (bool, error) {
if len(ids) == 0 {
return true, nil
}
// Create a map to track which IDs we've found
foundIDs := make(map[uint]bool, len(ids))
batchSize := 10000
err := common_mysql.BatchProcessSimple(ids, batchSize, func(batchIDs []uint) error {
query, args, err := sqlx.In(`
SELECT id
FROM scim_users
WHERE id IN (?)
`, batchIDs)
if err != nil {
return ctxerr.Wrap(ctx, err, "prepare scim users exist batch query")
}
var foundBatchIDs []uint
err = sqlx.SelectContext(ctx, ds.reader(ctx), &foundBatchIDs, query, args...)
if err != nil {
return ctxerr.Wrap(ctx, err, "check if scim users exist in batch")
}
// Mark found IDs
for _, id := range foundBatchIDs {
foundIDs[id] = true
}
return nil
})
if err != nil {
return false, err
}
// Check if all IDs were found
for _, id := range ids {
if !foundIDs[id] {
return false, nil
}
}
return true, nil
}

View file

@ -29,6 +29,7 @@ func TestScim(t *testing.T) {
{"ScimUserByUserNameOrEmail", testScimUserByUserNameOrEmail},
{"ScimUserByHostID", testScimUserByHostID},
{"ReplaceScimUser", testReplaceScimUser},
{"ReplaceScimUserEmails", testReplaceScimUserEmails},
{"ReplaceScimUserValidation", testScimUserReplaceValidation},
{"DeleteScimUser", testDeleteScimUser},
{"ListScimUsers", testListScimUsers},
@ -41,6 +42,7 @@ func TestScim(t *testing.T) {
{"DeleteScimGroup", testDeleteScimGroup},
{"ListScimGroups", testListScimGroups},
{"ScimLastRequest", testScimLastRequest},
{"ScimUsersExist", testScimUsersExist},
{"TriggerResendIdPProfiles", testTriggerResendIdPProfiles},
{"TriggerResendIdPProfilesOnTeam", testTriggerResendIdPProfilesOnTeam},
}
@ -423,6 +425,171 @@ func testReplaceScimUser(t *testing.T, ds *Datastore) {
assert.True(t, fleet.IsNotFound(err))
}
func testReplaceScimUserEmails(t *testing.T, ds *Datastore) {
// Create a test user
user := fleet.ScimUser{
UserName: "email-test-user",
ExternalID: ptr.String("ext-email-123"),
GivenName: ptr.String("Email"),
FamilyName: ptr.String("Test"),
Active: ptr.Bool(true),
Emails: []fleet.ScimUserEmail{
{
Email: "original.email@example.com",
Primary: ptr.Bool(true),
Type: ptr.String("work"),
},
},
}
var err error
user.ID, err = ds.CreateScimUser(t.Context(), &user)
require.Nil(t, err)
// Smoke test email optimization - replacing with the same emails should not update emails
// First, get the current user to have a reference point
currentUser, err := ds.ScimUserByID(t.Context(), user.ID)
require.NoError(t, err)
// Create a copy of the user with the same emails
sameEmailsUser := fleet.ScimUser{
ID: user.ID,
UserName: "multi-update@example.com",
ExternalID: ptr.String("ext-replace-456"),
GivenName: ptr.String("Multiple"),
FamilyName: ptr.String("Updates"),
Active: ptr.Bool(true),
Emails: currentUser.Emails, // Same emails as current user
}
// Replace the user
err = ds.ReplaceScimUser(t.Context(), &sameEmailsUser)
require.NoError(t, err)
// Verify the user was updated correctly but emails remain the same
sameEmailsResult, err := ds.ScimUserByID(t.Context(), user.ID)
require.NoError(t, err)
assert.Equal(t, sameEmailsUser.UserName, sameEmailsResult.UserName)
assert.Equal(t, sameEmailsUser.ExternalID, sameEmailsResult.ExternalID)
assert.Equal(t, sameEmailsUser.GivenName, sameEmailsResult.GivenName)
assert.Equal(t, sameEmailsUser.FamilyName, sameEmailsResult.FamilyName)
assert.Equal(t, sameEmailsUser.Active, sameEmailsResult.Active)
// Verify emails are the same as before
assert.Equal(t, len(currentUser.Emails), len(sameEmailsResult.Emails))
for i := range currentUser.Emails {
assert.Equal(t, currentUser.Emails[i].Email, sameEmailsResult.Emails[i].Email)
assert.Equal(t, currentUser.Emails[i].Type, sameEmailsResult.Emails[i].Type)
assert.Equal(t, currentUser.Emails[i].Primary, sameEmailsResult.Emails[i].Primary)
}
// Test validation for multiple primary emails
multiPrimaryUser := fleet.ScimUser{
ID: user.ID,
UserName: "multi-primary@example.com",
ExternalID: ptr.String("ext-multi-primary"),
GivenName: ptr.String("Multi"),
FamilyName: ptr.String("Primary"),
Active: ptr.Bool(true),
Emails: []fleet.ScimUserEmail{
{
Email: "primary1@example.com",
Primary: ptr.Bool(true), // First primary
Type: ptr.String("work"),
},
{
Email: "primary2@example.com",
Primary: ptr.Bool(true), // Second primary - should cause validation error
Type: ptr.String("home"),
},
},
}
// This should fail with a validation error
err = ds.ReplaceScimUser(t.Context(), &multiPrimaryUser)
assert.Error(t, err)
assert.Contains(t, err.Error(), "only one email can be marked as primary")
// Test email comparison behavior with different combinations of nil/non-nil fields
// First, create a user with an email that has all fields set
userWithAllFields := fleet.ScimUser{
ID: user.ID,
UserName: "all-fields@example.com",
ExternalID: ptr.String("ext-all-fields"),
GivenName: ptr.String("All"),
FamilyName: ptr.String("Fields"),
Active: ptr.Bool(true),
Emails: []fleet.ScimUserEmail{
{
Email: "all-fields@example.com",
Primary: ptr.Bool(true),
Type: ptr.String("work"),
},
},
}
err = ds.ReplaceScimUser(t.Context(), &userWithAllFields)
require.NoError(t, err)
// Now create a user with the same email but with nil Primary field
userWithNilPrimary := fleet.ScimUser{
ID: user.ID,
UserName: "all-fields@example.com",
ExternalID: ptr.String("ext-all-fields"),
GivenName: ptr.String("All"),
FamilyName: ptr.String("Fields"),
Active: ptr.Bool(true),
Emails: []fleet.ScimUserEmail{
{
Email: "all-fields@example.com",
Primary: nil, // Changed from true to nil
Type: ptr.String("work"),
},
},
}
// This should update the emails since the Primary field changed
err = ds.ReplaceScimUser(t.Context(), &userWithNilPrimary)
require.NoError(t, err)
// Verify the email was updated
var nilPrimaryUser *fleet.ScimUser
nilPrimaryUser, err = ds.ScimUserByID(t.Context(), user.ID)
require.NoError(t, err)
require.Len(t, nilPrimaryUser.Emails, 1)
assert.Equal(t, "all-fields@example.com", nilPrimaryUser.Emails[0].Email)
assert.Nil(t, nilPrimaryUser.Emails[0].Primary, "Primary field should be nil")
// Now create a user with the same email but with nil Type field
userWithNilType := fleet.ScimUser{
ID: user.ID,
UserName: "all-fields@example.com",
ExternalID: ptr.String("ext-all-fields"),
GivenName: ptr.String("All"),
FamilyName: ptr.String("Fields"),
Active: ptr.Bool(true),
Emails: []fleet.ScimUserEmail{
{
Email: "all-fields@example.com",
Primary: nil,
Type: nil, // Changed from "work" to nil
},
},
}
// This should update the emails since the Type field changed
err = ds.ReplaceScimUser(t.Context(), &userWithNilType)
require.NoError(t, err)
// Verify the email was updated
var nilTypeUser *fleet.ScimUser
nilTypeUser, err = ds.ScimUserByID(t.Context(), user.ID)
require.NoError(t, err)
require.Len(t, nilTypeUser.Emails, 1)
assert.Equal(t, "all-fields@example.com", nilTypeUser.Emails[0].Email)
assert.Nil(t, nilTypeUser.Emails[0].Type, "Type field should be nil")
}
func testDeleteScimUser(t *testing.T, ds *Datastore) {
// Create a test user
user := fleet.ScimUser{
@ -1047,9 +1214,11 @@ func testListScimGroups(t *testing.T, ds *Datastore) {
}
// Test 1: List all groups
allGroups, totalResults, err := ds.ListScimGroups(t.Context(), fleet.ScimListOptions{
StartIndex: 1,
PerPage: 10,
allGroups, totalResults, err := ds.ListScimGroups(t.Context(), fleet.ScimGroupsListOptions{
ScimListOptions: fleet.ScimListOptions{
StartIndex: 1,
PerPage: 10,
},
})
require.Nil(t, err)
assert.GreaterOrEqual(t, len(allGroups), 3) // There might be other groups from previous tests
@ -1068,18 +1237,22 @@ func testListScimGroups(t *testing.T, ds *Datastore) {
assert.Equal(t, 3, foundGroups)
// Test 2: Pagination - first page with 2 items
page1Groups, totalPage1, err := ds.ListScimGroups(t.Context(), fleet.ScimListOptions{
StartIndex: 1,
PerPage: 2,
page1Groups, totalPage1, err := ds.ListScimGroups(t.Context(), fleet.ScimGroupsListOptions{
ScimListOptions: fleet.ScimListOptions{
StartIndex: 1,
PerPage: 2,
},
})
require.Nil(t, err)
assert.Equal(t, 2, len(page1Groups))
assert.GreaterOrEqual(t, totalPage1, uint(3)) // Total should be at least 3
// Test 3: Pagination - second page with 2 items
page2Groups, totalPage2, err := ds.ListScimGroups(t.Context(), fleet.ScimListOptions{
StartIndex: 3, // StartIndex is 1-based, so for the second page with 2 items per page, we start at index 3
PerPage: 2,
page2Groups, totalPage2, err := ds.ListScimGroups(t.Context(), fleet.ScimGroupsListOptions{
ScimListOptions: fleet.ScimListOptions{
StartIndex: 3, // StartIndex is 1-based, so for the second page with 2 items per page, we start at index 3
PerPage: 2,
},
})
require.Nil(t, err)
assert.GreaterOrEqual(t, len(page2Groups), 1) // At least 1 item on the second page
@ -1091,6 +1264,33 @@ func testListScimGroups(t *testing.T, ds *Datastore) {
assert.NotEqual(t, p1Group.ID, p2Group.ID, "Groups should not appear on multiple pages")
}
}
// Test 4: Filter by display name
displayName := "List Test Group 2"
filteredGroups, totalFilteredResults, err := ds.ListScimGroups(t.Context(), fleet.ScimGroupsListOptions{
ScimListOptions: fleet.ScimListOptions{
StartIndex: 1,
PerPage: 10,
},
DisplayNameFilter: &displayName,
})
require.Nil(t, err)
assert.Equal(t, 1, len(filteredGroups), "Should find exactly one group with the specified display name")
assert.Equal(t, uint(1), totalFilteredResults)
assert.Equal(t, displayName, filteredGroups[0].DisplayName)
// Test 5: Filter by non-existent display name
nonExistentName := "Non-Existent Group"
emptyResults, totalEmptyResults, err := ds.ListScimGroups(t.Context(), fleet.ScimGroupsListOptions{
ScimListOptions: fleet.ScimListOptions{
StartIndex: 1,
PerPage: 10,
},
DisplayNameFilter: &nonExistentName,
})
require.Nil(t, err)
assert.Empty(t, emptyResults, "Should find no groups with a non-existent display name")
assert.Equal(t, uint(0), totalEmptyResults)
}
func testScimUserCreateValidation(t *testing.T, ds *Datastore) {
@ -1967,3 +2167,60 @@ func forceSetHostProfileStatus(t *testing.T, ds *Datastore, hostUUID string, pro
return err
})
}
func testScimUsersExist(t *testing.T, ds *Datastore) {
// Create test users
users := createTestScimUsers(t, ds)
userIDs := make([]uint, len(users))
for i, user := range users {
userIDs[i] = user.ID
}
// Test 1: Empty slice should return true
exist, err := ds.ScimUsersExist(t.Context(), []uint{})
require.NoError(t, err)
assert.True(t, exist, "Empty slice should return true")
// Test 2: All existing users should return true
exist, err = ds.ScimUsersExist(t.Context(), userIDs)
require.NoError(t, err)
assert.True(t, exist, "All existing users should return true")
// Test 3: Mix of existing and non-existing users should return false
nonExistentIDs := userIDs
nonExistentIDs = append(nonExistentIDs, 99999)
exist, err = ds.ScimUsersExist(t.Context(), nonExistentIDs)
require.NoError(t, err)
assert.False(t, exist, "Mix of existing and non-existing users should return false")
// Test 4: Only non-existing users should return false
exist, err = ds.ScimUsersExist(t.Context(), []uint{99999, 100000})
require.NoError(t, err)
assert.False(t, exist, "Only non-existing users should return false")
// Test 5: Test with a large number of IDs to verify batching works
// First, create a large number of test users
largeUserIDs := make([]uint, 0, 25000)
largeUserIDs = append(largeUserIDs, userIDs...) // Add existing users
// Add some non-existent IDs to test batching with mixed results
for i := 0; i < 24990; i++ {
largeUserIDs = append(largeUserIDs, uint(1000000)+uint(i)) // nolint:gosec // dismiss G115 integer overflow
}
exist, err = ds.ScimUsersExist(t.Context(), largeUserIDs)
require.NoError(t, err)
assert.False(t, exist, "Large batch with non-existing users should return false")
// Test 6: Test with a large number of existing IDs
// This is a bit tricky to test thoroughly without creating thousands of users,
// so we'll just verify the function handles a large slice without errors
largeExistingIDs := make([]uint, 0, 25000)
for i := 0; i < 25000; i++ {
largeExistingIDs = append(largeExistingIDs, userIDs[i%len(userIDs)])
}
exist, err = ds.ScimUsersExist(t.Context(), largeExistingIDs)
require.NoError(t, err)
assert.True(t, exist, "Large batch with only existing users should return true")
}

View file

@ -2067,6 +2067,9 @@ type Datastore interface {
ScimUserByUserNameOrEmail(ctx context.Context, userName string, email string) (*ScimUser, error)
// ScimUserByHostID retrieves a SCIM user associated with a host ID
ScimUserByHostID(ctx context.Context, hostID uint) (*ScimUser, error)
// ScimUsersExist checks if all the provided SCIM user IDs exist in the datastore
// If the slice is empty, it returns true
ScimUsersExist(ctx context.Context, ids []uint) (bool, error)
// ReplaceScimUser replaces an existing SCIM user in the database
ReplaceScimUser(ctx context.Context, user *ScimUser) error
// DeleteScimUser deletes a SCIM user from the database
@ -2084,7 +2087,7 @@ type Datastore interface {
// DeleteScimGroup deletes a SCIM group from the database
DeleteScimGroup(ctx context.Context, id uint) error
// ListScimGroups retrieves a list of SCIM groups with pagination
ListScimGroups(ctx context.Context, opts ScimListOptions) (groups []ScimGroup, totalResults uint, err error)
ListScimGroups(ctx context.Context, opts ScimGroupsListOptions) (groups []ScimGroup, totalResults uint, err error)
// ScimLastRequest retrieves the last SCIM request info
ScimLastRequest(ctx context.Context) (*ScimLastRequest, error)
// UpdateScimLastRequest updates the last SCIM request info

View file

@ -48,6 +48,28 @@ type ScimUserEmail struct {
Type *string `db:"type"`
}
// GenerateComparisonKey generates a unique string representation of the email
// that can be used for comparison, properly handling nil values.
func (e ScimUserEmail) GenerateComparisonKey() string {
// Handle Type field which can be nil
typeValue := "nil"
if e.Type != nil {
typeValue = *e.Type
}
// Handle Primary field which can be nil
primaryValue := "nil"
if e.Primary != nil {
if *e.Primary {
primaryValue = "true"
} else {
primaryValue = "false"
}
}
return e.Email + ":" + typeValue + ":" + primaryValue
}
type ScimListOptions struct {
// 1-based index of the first result to return (must be positive integer)
StartIndex uint
@ -69,6 +91,13 @@ type ScimUsersListOptions struct {
EmailValueFilter *string
}
type ScimGroupsListOptions struct {
ScimListOptions
// DisplayNameFilter filters by displayName
DisplayNameFilter *string
}
type ScimGroup struct {
ID uint `db:"id"`
ExternalID *string `db:"external_id"`

View file

@ -1320,6 +1320,8 @@ type ScimUserByUserNameOrEmailFunc func(ctx context.Context, userName string, em
type ScimUserByHostIDFunc func(ctx context.Context, hostID uint) (*fleet.ScimUser, error)
type ScimUsersExistFunc func(ctx context.Context, ids []uint) (bool, error)
type ReplaceScimUserFunc func(ctx context.Context, user *fleet.ScimUser) error
type DeleteScimUserFunc func(ctx context.Context, id uint) error
@ -1336,7 +1338,7 @@ type ReplaceScimGroupFunc func(ctx context.Context, group *fleet.ScimGroup) erro
type DeleteScimGroupFunc func(ctx context.Context, id uint) error
type ListScimGroupsFunc func(ctx context.Context, opts fleet.ScimListOptions) (groups []fleet.ScimGroup, totalResults uint, err error)
type ListScimGroupsFunc func(ctx context.Context, opts fleet.ScimGroupsListOptions) (groups []fleet.ScimGroup, totalResults uint, err error)
type ScimLastRequestFunc func(ctx context.Context) (*fleet.ScimLastRequest, error)
@ -3290,6 +3292,9 @@ type DataStore struct {
ScimUserByHostIDFunc ScimUserByHostIDFunc
ScimUserByHostIDFuncInvoked bool
ScimUsersExistFunc ScimUsersExistFunc
ScimUsersExistFuncInvoked bool
ReplaceScimUserFunc ReplaceScimUserFunc
ReplaceScimUserFuncInvoked bool
@ -7869,6 +7874,13 @@ func (s *DataStore) ScimUserByHostID(ctx context.Context, hostID uint) (*fleet.S
return s.ScimUserByHostIDFunc(ctx, hostID)
}
func (s *DataStore) ScimUsersExist(ctx context.Context, ids []uint) (bool, error) {
s.mu.Lock()
s.ScimUsersExistFuncInvoked = true
s.mu.Unlock()
return s.ScimUsersExistFunc(ctx, ids)
}
func (s *DataStore) ReplaceScimUser(ctx context.Context, user *fleet.ScimUser) error {
s.mu.Lock()
s.ReplaceScimUserFuncInvoked = true
@ -7925,7 +7937,7 @@ func (s *DataStore) DeleteScimGroup(ctx context.Context, id uint) error {
return s.DeleteScimGroupFunc(ctx, id)
}
func (s *DataStore) ListScimGroups(ctx context.Context, opts fleet.ScimListOptions) (groups []fleet.ScimGroup, totalResults uint, err error) {
func (s *DataStore) ListScimGroups(ctx context.Context, opts fleet.ScimGroupsListOptions) (groups []fleet.ScimGroup, totalResults uint, err error) {
s.mu.Lock()
s.ListScimGroupsFuncInvoked = true
s.mu.Unlock()