mirror of
https://github.com/fleetdm/fleet
synced 2026-05-24 01:18:42 +00:00
SCIM Entra ID support (#28832)
For #28196
This PR adds full patching for SCIM Users and Groups, and adds the
ability to filter Groups by displayName.
The changes have been tested with [Entra ID SCIM
Validator](67dfd91c0c/docs/Contributing/SCIM-integration.md (entra-id-integration))
and Okta SCIM 2.0 SPEC Test (to make sure we didn't break Okta).
# Checklist for submitter
- [x] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.
- [x] Added/updated automated tests
- [x] A detailed QA plan exists on the associated ticket (if it isn't
there, work with the product group's QA engineer to add it)
- [x] Manual QA for all new/changed functionality
This commit is contained in:
parent
89c0386572
commit
6f9030ee3c
14 changed files with 2522 additions and 215 deletions
1
changes/28196-SCIM-for-Entra-ID
Normal file
1
changes/28196-SCIM-for-Entra-ID
Normal file
|
|
@ -0,0 +1 @@
|
|||
Added ability to sync end user's IdP information with Microsoft Entra ID using SCIM protocol.
|
||||
|
|
@ -51,12 +51,20 @@ Run test using [Runscope](https://www.runscope.com/). See [instructions](https:/
|
|||
|
||||
### Testing Entra ID integration
|
||||
|
||||
Use [scimvalidator.microsoft.com](https://scimvalidator.microsoft.com/). Only test the attributes that we have implemented. To see our supported attributes, check the schema:
|
||||
Use [scimvalidator.microsoft.com](https://scimvalidator.microsoft.com/). Only test the attributes that we have implemented.
|
||||
|
||||

|
||||

|
||||
|
||||
To see our supported attributes, check the schema:
|
||||
```
|
||||
GET https://localhost:8080/api/latest/fleet/scim/Schemas
|
||||
```
|
||||
|
||||
Results (2025/05/06)
|
||||
|
||||

|
||||
|
||||
## Authentication
|
||||
|
||||
We use same authentication as API. HTTP header: `Authorization: Bearer xyz`
|
||||
|
|
|
|||
Binary file not shown.
|
After Width: | Height: | Size: 28 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 56 KiB |
BIN
docs/Contributing/assets/SCIM-Entra-ID-Validator-results.png
Normal file
BIN
docs/Contributing/assets/SCIM-Entra-ID-Validator-results.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 257 KiB |
File diff suppressed because it is too large
Load diff
|
|
@ -1,8 +1,10 @@
|
|||
package scim
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
|
|
@ -12,6 +14,7 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
kitlog "github.com/go-kit/log"
|
||||
"github.com/go-kit/log/level"
|
||||
"github.com/scim2/filter-parser/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -145,7 +148,6 @@ func createGroupResource(group *fleet.ScimGroup) scim.Resource {
|
|||
for _, userID := range group.ScimUsers {
|
||||
members = append(members, map[string]interface{}{
|
||||
"value": scimUserID(userID),
|
||||
"$ref": "Users/" + scimUserID(userID),
|
||||
"type": "User",
|
||||
})
|
||||
}
|
||||
|
|
@ -170,15 +172,37 @@ func (g *GroupHandler) GetAll(r *http.Request, params scim.ListRequestParams) (s
|
|||
count = maxResults
|
||||
}
|
||||
|
||||
opts := fleet.ScimListOptions{
|
||||
StartIndex: uint(startIndex), // nolint:gosec // ignore G115
|
||||
PerPage: uint(count), // nolint:gosec // ignore G115
|
||||
opts := fleet.ScimGroupsListOptions{
|
||||
ScimListOptions: fleet.ScimListOptions{
|
||||
StartIndex: uint(startIndex), // nolint:gosec // ignore G115
|
||||
PerPage: uint(count), // nolint:gosec // ignore G115
|
||||
},
|
||||
}
|
||||
|
||||
resourceFilter := r.URL.Query().Get("filter")
|
||||
if resourceFilter != "" {
|
||||
level.Info(g.logger).Log("msg", "group filter not supported", "filter", resourceFilter)
|
||||
return scim.Page{}, nil
|
||||
expr, err := filter.ParseAttrExp([]byte(resourceFilter))
|
||||
if err != nil {
|
||||
level.Error(g.logger).Log("msg", "failed to parse filter", "filter", resourceFilter, "err", err)
|
||||
return scim.Page{}, errors.ScimErrorInvalidFilter
|
||||
}
|
||||
if !strings.EqualFold(expr.AttributePath.String(), "displayName") || expr.Operator != "eq" {
|
||||
level.Info(g.logger).Log("msg", "unsupported filter", "filter", resourceFilter)
|
||||
return scim.Page{}, nil
|
||||
}
|
||||
displayName, ok := expr.CompareValue.(string)
|
||||
if !ok {
|
||||
level.Error(g.logger).Log("msg", "unsupported value", "value", expr.CompareValue)
|
||||
return scim.Page{}, nil
|
||||
}
|
||||
|
||||
// Decode URL-encoded characters
|
||||
displayName, err = url.QueryUnescape(displayName)
|
||||
if err != nil {
|
||||
level.Error(g.logger).Log("msg", "failed to decode displayName", "displayName", displayName, "err", err)
|
||||
return scim.Page{}, nil
|
||||
}
|
||||
opts.DisplayNameFilter = &displayName
|
||||
}
|
||||
|
||||
groups, totalResults, err := g.ds.ListScimGroups(r.Context(), opts)
|
||||
|
|
@ -256,8 +280,7 @@ func (g *GroupHandler) Delete(r *http.Request, id string) error {
|
|||
}
|
||||
|
||||
// Patch
|
||||
// Only supporting replacing the "displayName" attribute.
|
||||
// Note: Okta does not use PATCH endpoint to update groups (2025/04/01)
|
||||
// Supporting add/replace/remove operations for "displayName", "externalId", and "members" attributes.
|
||||
func (g *GroupHandler) Patch(r *http.Request, id string, operations []scim.PatchOperation) (scim.Resource, error) {
|
||||
idUint, err := extractGroupIDFromValue(id)
|
||||
if err != nil {
|
||||
|
|
@ -275,53 +298,312 @@ func (g *GroupHandler) Patch(r *http.Request, id string, operations []scim.Patch
|
|||
}
|
||||
|
||||
for _, op := range operations {
|
||||
if op.Op != "replace" {
|
||||
if op.Op != scim.PatchOperationAdd && op.Op != scim.PatchOperationReplace && op.Op != scim.PatchOperationRemove {
|
||||
level.Info(g.logger).Log("msg", "unsupported patch operation", "op", op.Op)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
switch {
|
||||
case op.Path == nil:
|
||||
if op.Op == scim.PatchOperationRemove {
|
||||
level.Info(g.logger).Log("msg", "the 'path' attribute is REQUIRED for 'remove' operations", "op", op.Op)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
newValues, ok := op.Value.(map[string]interface{})
|
||||
if !ok {
|
||||
level.Info(g.logger).Log("msg", "unsupported patch value", "value", op.Value)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
if len(newValues) != 1 {
|
||||
level.Info(g.logger).Log("msg", "too many patch values", "value", op.Value)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
for k, v := range newValues {
|
||||
switch k {
|
||||
case externalIdAttr:
|
||||
err = g.patchExternalId(op.Op, v, group)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case displayNameAttr:
|
||||
err = g.patchDisplayName(op.Op, v, group)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case membersAttr:
|
||||
err = g.patchMembers(r.Context(), op.Op, v, group)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
default:
|
||||
level.Info(g.logger).Log("msg", "unsupported patch value field", "field", k)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
}
|
||||
displayName, err := getRequiredResource[string](newValues, displayNameAttr)
|
||||
case op.Path.String() == externalIdAttr:
|
||||
err = g.patchExternalId(op.Op, op.Value, group)
|
||||
if err != nil {
|
||||
level.Info(g.logger).Log("msg", "failed to get active value", "value", op.Value)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
group.DisplayName = displayName
|
||||
case op.Path.String() == displayNameAttr:
|
||||
displayName, ok := op.Value.(string)
|
||||
if !ok {
|
||||
level.Error(g.logger).Log("msg", "unsupported 'displayName' patch value", "value", op.Value)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
err = g.patchDisplayName(op.Op, op.Value, group)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case op.Path.String() == membersAttr:
|
||||
err = g.patchMembers(r.Context(), op.Op, op.Value, group)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case op.Path.AttributePath.String() == membersAttr:
|
||||
err = g.patchMembersWithPathFiltering(r.Context(), op, group)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
group.DisplayName = displayName
|
||||
default:
|
||||
level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
}
|
||||
|
||||
err = g.ds.ReplaceScimGroup(r.Context(), group)
|
||||
switch {
|
||||
case fleet.IsNotFound(err):
|
||||
level.Info(g.logger).Log("msg", "failed to find group to patch", "id", id)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
case err != nil:
|
||||
level.Error(g.logger).Log("msg", "failed to patch group", "id", id, "err", err)
|
||||
return scim.Resource{}, err
|
||||
if len(operations) != 0 {
|
||||
err = g.ds.ReplaceScimGroup(r.Context(), group)
|
||||
switch {
|
||||
case fleet.IsNotFound(err):
|
||||
level.Info(g.logger).Log("msg", "failed to find group to patch", "id", id)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
case err != nil:
|
||||
level.Error(g.logger).Log("msg", "failed to patch group", "id", id, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return createGroupResource(group), nil
|
||||
}
|
||||
|
||||
func (g *GroupHandler) patchExternalId(op string, v interface{}, group *fleet.ScimGroup) error {
|
||||
if op == scim.PatchOperationRemove || v == nil {
|
||||
group.ExternalID = nil
|
||||
return nil
|
||||
}
|
||||
externalId, ok := v.(string)
|
||||
if !ok {
|
||||
level.Info(g.logger).Log("msg", fmt.Sprintf("unsupported '%s' value", externalIdAttr), "value", v)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
|
||||
}
|
||||
group.ExternalID = &externalId
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *GroupHandler) patchDisplayName(op string, v interface{}, group *fleet.ScimGroup) error {
|
||||
if op == scim.PatchOperationRemove {
|
||||
level.Info(g.logger).Log("msg", "cannot remove required attribute", "attribute", displayNameAttr)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
displayName, ok := v.(string)
|
||||
if !ok {
|
||||
level.Info(g.logger).Log("msg", fmt.Sprintf("unsupported '%s' value", displayNameAttr), "value", v)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
|
||||
}
|
||||
if displayName == "" {
|
||||
level.Info(g.logger).Log("msg", fmt.Sprintf("'%s' cannot be empty", displayNameAttr), "value", v)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
|
||||
}
|
||||
group.DisplayName = displayName
|
||||
return nil
|
||||
}
|
||||
|
||||
// patchMembers handles add/replace/remove operations for the members attribute
|
||||
func (g *GroupHandler) patchMembers(ctx context.Context, op string, v interface{}, group *fleet.ScimGroup) error {
|
||||
if op == scim.PatchOperationRemove {
|
||||
// Remove all members
|
||||
group.ScimUsers = []uint{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// For add and replace operations, we need to extract the member IDs
|
||||
var membersList []interface{}
|
||||
|
||||
// Handle different value formats
|
||||
switch val := v.(type) {
|
||||
case []interface{}:
|
||||
// Direct array of members
|
||||
membersList = val
|
||||
case map[string]interface{}:
|
||||
// Single member as a map
|
||||
membersList = []interface{}{val}
|
||||
case []map[string]interface{}:
|
||||
// Array of member maps
|
||||
for _, m := range val {
|
||||
membersList = append(membersList, m)
|
||||
}
|
||||
default:
|
||||
level.Info(g.logger).Log("msg", "unsupported members value format", "value", v)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
|
||||
}
|
||||
|
||||
// Process the members
|
||||
userIDs := make([]uint, 0, len(membersList))
|
||||
valueStrings := make([]string, 0, len(membersList))
|
||||
|
||||
for _, memberIntf := range membersList {
|
||||
member, ok := memberIntf.(map[string]interface{})
|
||||
if !ok {
|
||||
level.Info(g.logger).Log("msg", "member must be an object", "member", memberIntf)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", memberIntf)})
|
||||
}
|
||||
|
||||
// Get the value attribute which contains the user ID
|
||||
valueIntf, ok := member["value"]
|
||||
if !ok || valueIntf == nil {
|
||||
level.Info(g.logger).Log("msg", "member missing value attribute", "member", member)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", member)})
|
||||
}
|
||||
|
||||
valueStr, ok := valueIntf.(string)
|
||||
if !ok {
|
||||
level.Info(g.logger).Log("msg", "member value must be a string", "value", valueIntf)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", valueIntf)})
|
||||
}
|
||||
valueStrings = append(valueStrings, valueStr)
|
||||
|
||||
// Extract user ID from the value
|
||||
userID, err := extractUserIDFromValue(valueStr)
|
||||
if err != nil {
|
||||
level.Info(g.logger).Log("msg", "invalid user ID format", "value", valueStr, "err", err)
|
||||
return errors.ScimErrorBadParams([]string{valueStr})
|
||||
}
|
||||
|
||||
userIDs = append(userIDs, userID)
|
||||
}
|
||||
|
||||
// Verify all users exist in a single database call
|
||||
if len(userIDs) > 0 {
|
||||
allExist, err := g.ds.ScimUsersExist(ctx, userIDs)
|
||||
if err != nil {
|
||||
level.Error(g.logger).Log("msg", "error checking users existence", "err", err)
|
||||
return err
|
||||
}
|
||||
if !allExist {
|
||||
level.Info(g.logger).Log("msg", "one or more users not found", "userIDs", userIDs)
|
||||
return errors.ScimErrorBadParams(valueStrings)
|
||||
}
|
||||
}
|
||||
|
||||
// For add operation, append to existing members
|
||||
if op == scim.PatchOperationAdd {
|
||||
// Create a map to track existing user IDs to avoid duplicates
|
||||
existingUsers := make(map[uint]bool)
|
||||
for _, id := range group.ScimUsers {
|
||||
existingUsers[id] = true
|
||||
}
|
||||
|
||||
// Add new users that don't already exist in the group
|
||||
for _, id := range userIDs {
|
||||
if !existingUsers[id] {
|
||||
group.ScimUsers = append(group.ScimUsers, id)
|
||||
existingUsers[id] = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// For replace operation, replace all members
|
||||
group.ScimUsers = userIDs
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// patchMembersWithPathFiltering handles patch operations with path filtering for members
|
||||
// This supports paths like members[value eq "422"] for add/replace/remove operations
|
||||
func (g *GroupHandler) patchMembersWithPathFiltering(ctx context.Context, op scim.PatchOperation, group *fleet.ScimGroup) error {
|
||||
memberID, err := g.getMemberID(op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if the member exists in the group
|
||||
memberFound := false
|
||||
var memberIndex int
|
||||
for i, id := range group.ScimUsers {
|
||||
if id == memberID {
|
||||
memberIndex = i
|
||||
memberFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// For remove operations, remove the member if found
|
||||
if op.Op == scim.PatchOperationRemove {
|
||||
if !memberFound {
|
||||
level.Info(g.logger).Log("msg", "member not found", "member_id", memberID, "op", fmt.Sprintf("%v", op))
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
group.ScimUsers = append(group.ScimUsers[:memberIndex], group.ScimUsers[memberIndex+1:]...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// For add operations, add the member if not found
|
||||
if op.Op == scim.PatchOperationAdd && !memberFound {
|
||||
// Verify the user exists
|
||||
userExists, err := g.ds.ScimUsersExist(ctx, []uint{memberID})
|
||||
if err != nil {
|
||||
level.Error(g.logger).Log("msg", "error checking user existence", "err", err)
|
||||
return err
|
||||
}
|
||||
if !userExists {
|
||||
level.Info(g.logger).Log("msg", "user not found", "user_id", memberID)
|
||||
return errors.ScimErrorBadParams([]string{scimUserID(memberID)})
|
||||
}
|
||||
group.ScimUsers = append(group.ScimUsers, memberID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// For replace operations with a value
|
||||
if op.Op == scim.PatchOperationReplace {
|
||||
if !memberFound {
|
||||
level.Info(g.logger).Log("msg", "member not found for replace operation", "members.value", memberID, "op", fmt.Sprintf("%v", op))
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
|
||||
// If the value is nil or an empty object, remove the member
|
||||
if op.Value == nil {
|
||||
group.ScimUsers = append(group.ScimUsers[:memberIndex], group.ScimUsers[memberIndex+1:]...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Otherwise, we don't change anything since we're already filtering by the member ID
|
||||
// and there are no other attributes to modify for a member
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getMemberID extracts the member ID from a path expression like members[value eq "422"]
|
||||
func (g *GroupHandler) getMemberID(op scim.PatchOperation) (uint, error) {
|
||||
attrExpression, ok := op.Path.ValueExpression.(*filter.AttributeExpression)
|
||||
if !ok {
|
||||
level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path)
|
||||
return 0, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
|
||||
// Only matching by member value (user ID) is supported
|
||||
if attrExpression.AttributePath.String() != valueAttr || attrExpression.Operator != filter.EQ {
|
||||
level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path, "expression", attrExpression.AttributePath.String())
|
||||
return 0, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
|
||||
memberIDStr, ok := attrExpression.CompareValue.(string)
|
||||
if !ok {
|
||||
level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path, "compare_value", attrExpression.CompareValue)
|
||||
return 0, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
|
||||
// Extract user ID from the value
|
||||
userID, err := extractUserIDFromValue(memberIDStr)
|
||||
if err != nil {
|
||||
level.Info(g.logger).Log("msg", "invalid user ID format", "value", memberIDStr, "err", err)
|
||||
return 0, errors.ScimErrorBadParams([]string{memberIDStr})
|
||||
}
|
||||
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
func scimGroupID(groupID uint) string {
|
||||
return fmt.Sprintf("group-%d", groupID)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,6 +33,8 @@ func RegisterSCIM(
|
|||
config := scim.ServiceProviderConfig{
|
||||
DocumentationURI: optional.NewString("https://fleetdm.com/docs/get-started/why-fleet"),
|
||||
MaxResults: maxResults,
|
||||
SupportFiltering: true,
|
||||
SupportPatch: true,
|
||||
}
|
||||
|
||||
// The common attributes are id, externalId, and meta.
|
||||
|
|
@ -136,18 +138,14 @@ func RegisterSCIM(
|
|||
Mutability: schema.AttributeMutabilityImmutable(),
|
||||
Name: "value",
|
||||
}),
|
||||
schema.SimpleReferenceParams(schema.ReferenceParams{
|
||||
Description: optional.NewString("The URI corresponding to a SCIM resource that is a member of this Group."),
|
||||
Mutability: schema.AttributeMutabilityImmutable(),
|
||||
Name: "$ref",
|
||||
ReferenceTypes: []schema.AttributeReferenceType{"User"},
|
||||
}),
|
||||
schema.SimpleStringParams(schema.StringParams{
|
||||
CanonicalValues: []string{"User"},
|
||||
Description: optional.NewString("A label indicating the type of resource, e.g., 'User' or 'Group'."),
|
||||
Mutability: schema.AttributeMutabilityImmutable(),
|
||||
Name: "type",
|
||||
}),
|
||||
// Note (2025/05/06): Microsoft does not properly support $ref attribute on group members
|
||||
// https://learn.microsoft.com/en-us/answers/questions/1457148/scim-validator-patch-group-remove-member-test-comp
|
||||
},
|
||||
}),
|
||||
},
|
||||
|
|
|
|||
|
|
@ -406,9 +406,7 @@ func (u *UserHandler) Delete(r *http.Request, id string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Patch
|
||||
// Okta only requires patching the "active" attribute:
|
||||
// https://developer.okta.com/docs/api/openapi/okta-scim/guides/scim-20/#update-a-specific-user-patch
|
||||
// Patch - https://datatracker.ietf.org/doc/html/rfc7644#section-3.5.2
|
||||
func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchOperation) (scim.Resource, error) {
|
||||
idUint, err := extractUserIDFromValue(id)
|
||||
if err != nil {
|
||||
|
|
@ -426,12 +424,17 @@ func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchO
|
|||
}
|
||||
|
||||
for _, op := range operations {
|
||||
if op.Op != "replace" {
|
||||
if op.Op != scim.PatchOperationAdd && op.Op != scim.PatchOperationReplace && op.Op != scim.PatchOperationRemove {
|
||||
level.Info(u.logger).Log("msg", "unsupported patch operation", "op", op.Op)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
switch {
|
||||
// If path is not specified, we look for the path in the value attribute.
|
||||
case op.Path == nil:
|
||||
if op.Op == scim.PatchOperationRemove {
|
||||
level.Info(u.logger).Log("msg", "the 'path' attribute is REQUIRED for 'remove' operations", "op", op.Op)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
newValues, ok := op.Value.(map[string]interface{})
|
||||
if !ok {
|
||||
level.Info(u.logger).Log("msg", "unsupported patch value", "value", op.Value)
|
||||
|
|
@ -439,27 +442,28 @@ func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchO
|
|||
}
|
||||
for k, v := range newValues {
|
||||
switch k {
|
||||
case externalIdAttr:
|
||||
err = u.patchExternalId(op.Op, v, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case userNameAttr:
|
||||
err = u.patchUserName(v, user)
|
||||
err = u.patchUserName(op.Op, v, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case activeAttr:
|
||||
if v == nil {
|
||||
user.Active = nil
|
||||
continue
|
||||
}
|
||||
err = u.patchActive(v, user)
|
||||
err = u.patchActive(op.Op, v, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case nameAttr + "." + givenNameAttr:
|
||||
err = u.patchGivenName(v, user)
|
||||
err = u.patchGivenName(op.Op, v, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case nameAttr + "." + familyNameAttr:
|
||||
err = u.patchFamilyName(v, user)
|
||||
err = u.patchFamilyName(op.Op, v, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
|
|
@ -478,27 +482,28 @@ func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchO
|
|||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
}
|
||||
case op.Path.String() == externalIdAttr:
|
||||
err = u.patchExternalId(op.Op, op.Value, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case op.Path.String() == userNameAttr:
|
||||
err = u.patchUserName(op.Value, user)
|
||||
err = u.patchUserName(op.Op, op.Value, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case op.Path.String() == activeAttr:
|
||||
if op.Value == nil {
|
||||
user.Active = nil
|
||||
continue
|
||||
}
|
||||
err = u.patchActive(op.Value, user)
|
||||
err = u.patchActive(op.Op, op.Value, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case op.Path.String() == nameAttr+"."+givenNameAttr:
|
||||
err = u.patchGivenName(op.Value, user)
|
||||
err = u.patchGivenName(op.Op, op.Value, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case op.Path.String() == nameAttr+"."+familyNameAttr:
|
||||
err = u.patchFamilyName(op.Value, user)
|
||||
err = u.patchFamilyName(op.Op, op.Value, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
|
|
@ -513,84 +518,10 @@ func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchO
|
|||
return scim.Resource{}, err
|
||||
}
|
||||
case op.Path.AttributePath.String() == emailsAttr:
|
||||
emailType, err := u.getEmailType(op)
|
||||
err = u.patchEmailsWithPathFiltering(op, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
emailFound := false
|
||||
var emailIndex int
|
||||
for i, email := range user.Emails {
|
||||
if email.Type != nil && *email.Type == emailType {
|
||||
emailIndex = i
|
||||
emailFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !emailFound {
|
||||
level.Info(u.logger).Log("msg", "email not found", "email_type", emailType)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
if op.Path.SubAttribute == nil {
|
||||
// The value for emails comes in as an array.
|
||||
userEmails, ok := op.Value.([]interface{})
|
||||
if !ok {
|
||||
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", emailsAttr), "value", op.Value)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
if len(userEmails) == 0 {
|
||||
user.Emails = slices.Delete(user.Emails, emailIndex, emailIndex)
|
||||
continue
|
||||
}
|
||||
if len(userEmails) != 1 {
|
||||
level.Info(u.logger).Log("msg", "only 1 email should be present for replacement", "emails", userEmails)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
userEmail, err := u.extractEmail(userEmails[0], op)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
// If setting primary to true, then unset true from other emails
|
||||
if userEmail.Primary != nil && *userEmail.Primary {
|
||||
clearPrimaryFlagFromEmails(user)
|
||||
}
|
||||
user.Emails[emailIndex] = userEmail
|
||||
continue
|
||||
}
|
||||
switch *op.Path.SubAttribute {
|
||||
case primaryAttr:
|
||||
if op.Value == nil {
|
||||
user.Emails[emailIndex].Primary = nil
|
||||
continue
|
||||
}
|
||||
primary, err := getConcreteType[bool](u, op.Value, primaryAttr)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
// If setting primary to true, then unset true from other emails
|
||||
if primary {
|
||||
clearPrimaryFlagFromEmails(user)
|
||||
}
|
||||
user.Emails[emailIndex].Primary = &primary
|
||||
case valueAttr:
|
||||
value, err := getConcreteType[string](u, op.Value, valueAttr)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
user.Emails[emailIndex].Email = value
|
||||
case typeAttr:
|
||||
if op.Value == nil {
|
||||
user.Emails[emailIndex].Type = nil
|
||||
continue
|
||||
}
|
||||
newEmailType, err := getConcreteType[string](u, op.Value, typeAttr)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
user.Emails[emailIndex].Type = &newEmailType
|
||||
default:
|
||||
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
default:
|
||||
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
|
|
@ -612,6 +543,146 @@ func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchO
|
|||
return createUserResource(user), nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) patchEmailsWithPathFiltering(op scim.PatchOperation, user *fleet.ScimUser) error {
|
||||
emailType, err := u.getEmailType(op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
emailFound := false
|
||||
var emailIndex int
|
||||
for i, email := range user.Emails {
|
||||
if email.Type != nil && *email.Type == emailType {
|
||||
emailIndex = i
|
||||
emailFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !emailFound && op.Op != scim.PatchOperationAdd {
|
||||
level.Info(u.logger).Log("msg", "email not found", "email_type", emailType, "op", fmt.Sprintf("%v", op))
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
if op.Path.SubAttribute == nil {
|
||||
if op.Op == scim.PatchOperationRemove {
|
||||
user.Emails = slices.Delete(user.Emails, emailIndex, emailIndex+1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// For add and replace operations, we need to extract the emails
|
||||
var emailsList []interface{}
|
||||
// Handle different value formats
|
||||
switch val := op.Value.(type) {
|
||||
case []interface{}:
|
||||
// Direct array of members
|
||||
emailsList = val
|
||||
case map[string]interface{}:
|
||||
// Single member as a map
|
||||
emailsList = []interface{}{val}
|
||||
default:
|
||||
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", emailsAttr), "value", op.Value)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
|
||||
switch op.Op {
|
||||
case scim.PatchOperationReplace:
|
||||
if len(emailsList) == 0 {
|
||||
user.Emails = slices.Delete(user.Emails, emailIndex, emailIndex+1)
|
||||
return nil
|
||||
}
|
||||
if len(emailsList) != 1 {
|
||||
level.Info(u.logger).Log("msg", "only 1 email should be present for replacement", "emails", emailsList)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
userEmail, err := u.extractEmail(emailsList[0], op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// If setting primary to true, then unset true from other emails
|
||||
if userEmail.Primary != nil && *userEmail.Primary {
|
||||
clearPrimaryFlagFromEmails(user)
|
||||
}
|
||||
user.Emails[emailIndex] = userEmail
|
||||
case scim.PatchOperationAdd:
|
||||
if len(emailsList) == 0 {
|
||||
level.Info(u.logger).Log("msg", "no emails provided to add", "emails", emailsList)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
var newEmails []fleet.ScimUserEmail
|
||||
for e := range emailsList {
|
||||
userEmail, err := u.extractEmail(emailsList[e], op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userEmail.Type = &emailType
|
||||
newEmails = append(newEmails, userEmail)
|
||||
}
|
||||
primaryExists, err := u.checkEmailPrimary(newEmails)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if primaryExists {
|
||||
clearPrimaryFlagFromEmails(user)
|
||||
}
|
||||
user.Emails = append(user.Emails, newEmails...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if op.Op == scim.PatchOperationAdd && !emailFound {
|
||||
user.Emails = append(user.Emails, fleet.ScimUserEmail{
|
||||
Type: ptr.String(emailType),
|
||||
})
|
||||
emailIndex = len(user.Emails) - 1
|
||||
}
|
||||
switch *op.Path.SubAttribute {
|
||||
case primaryAttr:
|
||||
if op.Op == scim.PatchOperationRemove {
|
||||
user.Emails[emailIndex].Primary = nil
|
||||
return nil
|
||||
}
|
||||
if op.Value == nil {
|
||||
user.Emails[emailIndex].Primary = nil
|
||||
return nil
|
||||
}
|
||||
primary, err := getConcreteType[bool](u, op.Value, primaryAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// If setting primary to true, then unset true from other emails
|
||||
if primary {
|
||||
clearPrimaryFlagFromEmails(user)
|
||||
}
|
||||
user.Emails[emailIndex].Primary = &primary
|
||||
case valueAttr:
|
||||
if op.Op == scim.PatchOperationRemove {
|
||||
// The operation of removing an email value doesn't make sense, but we allow it.
|
||||
user.Emails[emailIndex].Email = ""
|
||||
return nil
|
||||
}
|
||||
value, err := getConcreteType[string](u, op.Value, valueAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user.Emails[emailIndex].Email = value
|
||||
case typeAttr:
|
||||
if op.Op == scim.PatchOperationRemove {
|
||||
user.Emails[emailIndex].Type = nil
|
||||
return nil
|
||||
}
|
||||
if op.Value == nil {
|
||||
user.Emails[emailIndex].Type = nil
|
||||
return nil
|
||||
}
|
||||
newEmailType, err := getConcreteType[string](u, op.Value, typeAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user.Emails[emailIndex].Type = &newEmailType
|
||||
default:
|
||||
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) getEmailType(op scim.PatchOperation) (string, error) {
|
||||
attrExpression, ok := op.Path.ValueExpression.(*filter.AttributeExpression)
|
||||
if !ok {
|
||||
|
|
@ -641,7 +712,11 @@ func getConcreteType[T string | bool](u *UserHandler, v interface{}, name string
|
|||
return concreteType, nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) patchFamilyName(v interface{}, user *fleet.ScimUser) error {
|
||||
func (u *UserHandler) patchFamilyName(op string, v interface{}, user *fleet.ScimUser) error {
|
||||
if op == scim.PatchOperationRemove {
|
||||
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", nameAttr+"."+familyNameAttr)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
familyName, err := getConcreteType[string](u, v, nameAttr+"."+familyNameAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -650,7 +725,11 @@ func (u *UserHandler) patchFamilyName(v interface{}, user *fleet.ScimUser) error
|
|||
return nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) patchGivenName(v interface{}, user *fleet.ScimUser) error {
|
||||
func (u *UserHandler) patchGivenName(op string, v interface{}, user *fleet.ScimUser) error {
|
||||
if op == scim.PatchOperationRemove {
|
||||
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", nameAttr+"."+givenNameAttr)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
givenName, err := getConcreteType[string](u, v, nameAttr+"."+givenNameAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -659,7 +738,11 @@ func (u *UserHandler) patchGivenName(v interface{}, user *fleet.ScimUser) error
|
|||
return nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) patchActive(v interface{}, user *fleet.ScimUser) error {
|
||||
func (u *UserHandler) patchActive(op string, v interface{}, user *fleet.ScimUser) error {
|
||||
if op == scim.PatchOperationRemove || v == nil {
|
||||
user.Active = nil
|
||||
return nil
|
||||
}
|
||||
active, err := getConcreteType[bool](u, v, activeAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -668,7 +751,24 @@ func (u *UserHandler) patchActive(v interface{}, user *fleet.ScimUser) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) patchUserName(v interface{}, user *fleet.ScimUser) error {
|
||||
func (u *UserHandler) patchExternalId(op string, v interface{}, user *fleet.ScimUser) error {
|
||||
if op == scim.PatchOperationRemove || v == nil {
|
||||
user.ExternalID = nil
|
||||
return nil
|
||||
}
|
||||
externalId, err := getConcreteType[string](u, v, externalIdAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user.ExternalID = ptr.String(externalId)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) patchUserName(op string, v interface{}, user *fleet.ScimUser) error {
|
||||
if op == scim.PatchOperationRemove {
|
||||
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", userNameAttr)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
userName, err := getConcreteType[string](u, v, userNameAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -690,43 +790,69 @@ func clearPrimaryFlagFromEmails(user *fleet.ScimUser) {
|
|||
}
|
||||
|
||||
func (u *UserHandler) patchEmails(v interface{}, op scim.PatchOperation, user *fleet.ScimUser) error {
|
||||
emailsValue, ok := v.([]interface{})
|
||||
if !ok {
|
||||
if op.Op == scim.PatchOperationRemove {
|
||||
user.Emails = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// For add and replace operations, we need to extract the emails
|
||||
var emailsList []interface{}
|
||||
// Handle different value formats
|
||||
switch val := v.(type) {
|
||||
case []interface{}:
|
||||
// Direct array of members
|
||||
emailsList = val
|
||||
case map[string]interface{}:
|
||||
// Single member as a map
|
||||
emailsList = []interface{}{val}
|
||||
default:
|
||||
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", emailsAttr), "value", op.Value)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
|
||||
if op.Op == scim.PatchOperationAdd && len(emailsList) == 0 {
|
||||
level.Info(u.logger).Log("msg", "no emails provided to add", "emails", emailsList)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
// Convert the emails to the expected format
|
||||
userEmails := make([]fleet.ScimUserEmail, 0, len(emailsValue))
|
||||
for _, emailIntf := range emailsValue {
|
||||
userEmails := make([]fleet.ScimUserEmail, 0, len(emailsList))
|
||||
for _, emailIntf := range emailsList {
|
||||
userEmail, err := u.extractEmail(emailIntf, op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userEmails = append(userEmails, userEmail)
|
||||
}
|
||||
|
||||
err := u.checkEmailPrimary(userEmails)
|
||||
primaryExists, err := u.checkEmailPrimary(userEmails)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if op.Op == scim.PatchOperationAdd {
|
||||
if primaryExists {
|
||||
// Clear the primary flag from current emails because we are merging the two email lists and a new email has that flag.
|
||||
clearPrimaryFlagFromEmails(user)
|
||||
}
|
||||
userEmails = append(user.Emails, userEmails...)
|
||||
}
|
||||
|
||||
user.Emails = userEmails
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkEmailPrimary ensures at most one email is marked as primary
|
||||
func (u *UserHandler) checkEmailPrimary(userEmails []fleet.ScimUserEmail) error {
|
||||
func (u *UserHandler) checkEmailPrimary(userEmails []fleet.ScimUserEmail) (bool, error) {
|
||||
primaryEmailCount := 0
|
||||
for _, email := range userEmails {
|
||||
if email.Primary != nil && *email.Primary {
|
||||
primaryEmailCount++
|
||||
if primaryEmailCount > 1 {
|
||||
level.Info(u.logger).Log("msg", "multiple primary emails found")
|
||||
return errors.ScimErrorBadParams([]string{"Only one email can be marked as primary"})
|
||||
return false, errors.ScimErrorBadParams([]string{"Only one email can be marked as primary"})
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return primaryEmailCount > 0, nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) extractEmail(emailIntf interface{}, op scim.PatchOperation) (fleet.ScimUserEmail, error) {
|
||||
|
|
@ -770,6 +896,10 @@ func (u *UserHandler) extractEmail(emailIntf interface{}, op scim.PatchOperation
|
|||
}
|
||||
|
||||
func (u *UserHandler) patchName(v interface{}, op scim.PatchOperation, user *fleet.ScimUser) error {
|
||||
if op.Op == scim.PatchOperationRemove {
|
||||
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", nameAttr)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
name, ok := v.(map[string]interface{})
|
||||
if !ok {
|
||||
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", nameAttr), "value", op.Value)
|
||||
|
|
|
|||
|
|
@ -222,6 +222,24 @@ func (ds *Datastore) ReplaceScimUser(ctx context.Context, user *fleet.ScimUser)
|
|||
return err
|
||||
}
|
||||
|
||||
// Validate that at most one email is marked as primary
|
||||
primaryCount := 0
|
||||
for _, email := range user.Emails {
|
||||
if email.Primary != nil && *email.Primary {
|
||||
primaryCount++
|
||||
}
|
||||
}
|
||||
if primaryCount > 1 {
|
||||
return ctxerr.New(ctx, "only one email can be marked as primary")
|
||||
}
|
||||
|
||||
// Get current emails and check if they need to be updated
|
||||
currentEmails, err := ds.getScimUserEmails(ctx, user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
emailsNeedUpdate := emailsRequireUpdate(currentEmails, user.Emails)
|
||||
|
||||
return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
||||
// load the username before updating the user, to check if it changed
|
||||
var oldUsername string
|
||||
|
|
@ -265,24 +283,23 @@ func (ds *Datastore) ReplaceScimUser(ctx context.Context, user *fleet.ScimUser)
|
|||
}
|
||||
usernameChanged := oldUsername != user.UserName
|
||||
|
||||
// We assume that email is not blank/null.
|
||||
// However, we do not assume that email/type are unique for a user. To keep the code simple, we:
|
||||
// 1. Delete all existing emails
|
||||
// 2. Insert all new emails
|
||||
// This is less efficient and can be optimized if we notice a load on these tables in production.
|
||||
// Only update emails if they've changed
|
||||
if emailsNeedUpdate {
|
||||
// We assume that email is not blank/null.
|
||||
// However, we do not assume that email/type are unique for a user. To keep the code simple, we:
|
||||
// 1. Delete all existing emails
|
||||
// 2. Insert all new emails
|
||||
// This is less efficient and can be optimized if we notice a load on these tables in production.
|
||||
|
||||
// TODO: Check if emails need to be updated at all. If so, update the user updated_at timestamp if emails have been updated
|
||||
// TODO: Check that only 1 email is primary
|
||||
|
||||
const deleteEmailsQuery = `DELETE FROM scim_user_emails WHERE scim_user_id = ?`
|
||||
_, err = tx.ExecContext(ctx, deleteEmailsQuery, user.ID)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "delete scim user emails")
|
||||
}
|
||||
|
||||
err = insertEmails(ctx, tx, user)
|
||||
if err != nil {
|
||||
return err
|
||||
const deleteEmailsQuery = `DELETE FROM scim_user_emails WHERE scim_user_id = ?`
|
||||
_, err = tx.ExecContext(ctx, deleteEmailsQuery, user.ID)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "delete scim user emails")
|
||||
}
|
||||
err = insertEmails(ctx, tx, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Get the user's groups
|
||||
|
|
@ -867,7 +884,7 @@ func (ds *Datastore) DeleteScimGroup(ctx context.Context, id uint) error {
|
|||
}
|
||||
|
||||
// ListScimGroups retrieves a list of SCIM groups with pagination
|
||||
func (ds *Datastore) ListScimGroups(ctx context.Context, opts fleet.ScimListOptions) (groups []fleet.ScimGroup, totalResults uint, err error) {
|
||||
func (ds *Datastore) ListScimGroups(ctx context.Context, opts fleet.ScimGroupsListOptions) (groups []fleet.ScimGroup, totalResults uint, err error) {
|
||||
// Default pagination values if not provided
|
||||
if opts.StartIndex == 0 {
|
||||
opts.StartIndex = 1
|
||||
|
|
@ -883,16 +900,25 @@ func (ds *Datastore) ListScimGroups(ctx context.Context, opts fleet.ScimListOpti
|
|||
FROM scim_groups
|
||||
`
|
||||
|
||||
// Add where clause based on filters
|
||||
var whereClause string
|
||||
var params []interface{}
|
||||
|
||||
if opts.DisplayNameFilter != nil {
|
||||
whereClause = " WHERE scim_groups.display_name = ?"
|
||||
params = append(params, *opts.DisplayNameFilter)
|
||||
}
|
||||
|
||||
// First, get the total count without pagination
|
||||
countQuery := "SELECT COUNT(DISTINCT id) FROM (" + baseQuery + ") AS filtered_groups"
|
||||
err = sqlx.GetContext(ctx, ds.reader(ctx), &totalResults, countQuery)
|
||||
countQuery := "SELECT COUNT(DISTINCT id) FROM (" + baseQuery + whereClause + ") AS filtered_groups"
|
||||
err = sqlx.GetContext(ctx, ds.reader(ctx), &totalResults, countQuery, params...)
|
||||
if err != nil {
|
||||
return nil, 0, ctxerr.Wrap(ctx, err, "count total scim groups")
|
||||
}
|
||||
|
||||
// Add pagination to the main query
|
||||
query := baseQuery + " ORDER BY scim_groups.id LIMIT ? OFFSET ?"
|
||||
params := []interface{}{opts.PerPage, opts.StartIndex - 1}
|
||||
query := baseQuery + whereClause + " ORDER BY scim_groups.id LIMIT ? OFFSET ?"
|
||||
params = append(params, opts.PerPage, opts.StartIndex-1)
|
||||
|
||||
// Execute the query
|
||||
err = sqlx.SelectContext(ctx, ds.reader(ctx), &groups, query, params...)
|
||||
|
|
@ -1208,3 +1234,76 @@ func triggerResendProfilesUsingVariables(ctx context.Context, tx sqlx.ExtContext
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// emailsRequireUpdate compares two slices of emails and returns true if they are different
|
||||
// and require an update in the database.
|
||||
func emailsRequireUpdate(currentEmails, newEmails []fleet.ScimUserEmail) bool {
|
||||
if len(currentEmails) != len(newEmails) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Create maps for efficient comparison
|
||||
currentEmailMap := make(map[string]fleet.ScimUserEmail)
|
||||
for i := range currentEmails {
|
||||
key := currentEmails[i].GenerateComparisonKey()
|
||||
currentEmailMap[key] = currentEmails[i]
|
||||
}
|
||||
|
||||
// Check if all new emails exist in current emails with the same attributes
|
||||
for i := range newEmails {
|
||||
key := newEmails[i].GenerateComparisonKey()
|
||||
if _, exists := currentEmailMap[key]; !exists {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ScimUsersExist checks if all the provided SCIM user IDs exist in the datastore
|
||||
// If the slice is empty, it returns true
|
||||
// This method processes IDs in batches to handle large numbers of IDs efficiently
|
||||
func (ds *Datastore) ScimUsersExist(ctx context.Context, ids []uint) (bool, error) {
|
||||
if len(ids) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Create a map to track which IDs we've found
|
||||
foundIDs := make(map[uint]bool, len(ids))
|
||||
|
||||
batchSize := 10000
|
||||
err := common_mysql.BatchProcessSimple(ids, batchSize, func(batchIDs []uint) error {
|
||||
query, args, err := sqlx.In(`
|
||||
SELECT id
|
||||
FROM scim_users
|
||||
WHERE id IN (?)
|
||||
`, batchIDs)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "prepare scim users exist batch query")
|
||||
}
|
||||
|
||||
var foundBatchIDs []uint
|
||||
err = sqlx.SelectContext(ctx, ds.reader(ctx), &foundBatchIDs, query, args...)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "check if scim users exist in batch")
|
||||
}
|
||||
|
||||
// Mark found IDs
|
||||
for _, id := range foundBatchIDs {
|
||||
foundIDs[id] = true
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Check if all IDs were found
|
||||
for _, id := range ids {
|
||||
if !foundIDs[id] {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ func TestScim(t *testing.T) {
|
|||
{"ScimUserByUserNameOrEmail", testScimUserByUserNameOrEmail},
|
||||
{"ScimUserByHostID", testScimUserByHostID},
|
||||
{"ReplaceScimUser", testReplaceScimUser},
|
||||
{"ReplaceScimUserEmails", testReplaceScimUserEmails},
|
||||
{"ReplaceScimUserValidation", testScimUserReplaceValidation},
|
||||
{"DeleteScimUser", testDeleteScimUser},
|
||||
{"ListScimUsers", testListScimUsers},
|
||||
|
|
@ -41,6 +42,7 @@ func TestScim(t *testing.T) {
|
|||
{"DeleteScimGroup", testDeleteScimGroup},
|
||||
{"ListScimGroups", testListScimGroups},
|
||||
{"ScimLastRequest", testScimLastRequest},
|
||||
{"ScimUsersExist", testScimUsersExist},
|
||||
{"TriggerResendIdPProfiles", testTriggerResendIdPProfiles},
|
||||
{"TriggerResendIdPProfilesOnTeam", testTriggerResendIdPProfilesOnTeam},
|
||||
}
|
||||
|
|
@ -423,6 +425,171 @@ func testReplaceScimUser(t *testing.T, ds *Datastore) {
|
|||
assert.True(t, fleet.IsNotFound(err))
|
||||
}
|
||||
|
||||
func testReplaceScimUserEmails(t *testing.T, ds *Datastore) {
|
||||
// Create a test user
|
||||
user := fleet.ScimUser{
|
||||
UserName: "email-test-user",
|
||||
ExternalID: ptr.String("ext-email-123"),
|
||||
GivenName: ptr.String("Email"),
|
||||
FamilyName: ptr.String("Test"),
|
||||
Active: ptr.Bool(true),
|
||||
Emails: []fleet.ScimUserEmail{
|
||||
{
|
||||
Email: "original.email@example.com",
|
||||
Primary: ptr.Bool(true),
|
||||
Type: ptr.String("work"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var err error
|
||||
user.ID, err = ds.CreateScimUser(t.Context(), &user)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Smoke test email optimization - replacing with the same emails should not update emails
|
||||
// First, get the current user to have a reference point
|
||||
currentUser, err := ds.ScimUserByID(t.Context(), user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a copy of the user with the same emails
|
||||
sameEmailsUser := fleet.ScimUser{
|
||||
ID: user.ID,
|
||||
UserName: "multi-update@example.com",
|
||||
ExternalID: ptr.String("ext-replace-456"),
|
||||
GivenName: ptr.String("Multiple"),
|
||||
FamilyName: ptr.String("Updates"),
|
||||
Active: ptr.Bool(true),
|
||||
Emails: currentUser.Emails, // Same emails as current user
|
||||
}
|
||||
|
||||
// Replace the user
|
||||
err = ds.ReplaceScimUser(t.Context(), &sameEmailsUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the user was updated correctly but emails remain the same
|
||||
sameEmailsResult, err := ds.ScimUserByID(t.Context(), user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, sameEmailsUser.UserName, sameEmailsResult.UserName)
|
||||
assert.Equal(t, sameEmailsUser.ExternalID, sameEmailsResult.ExternalID)
|
||||
assert.Equal(t, sameEmailsUser.GivenName, sameEmailsResult.GivenName)
|
||||
assert.Equal(t, sameEmailsUser.FamilyName, sameEmailsResult.FamilyName)
|
||||
assert.Equal(t, sameEmailsUser.Active, sameEmailsResult.Active)
|
||||
|
||||
// Verify emails are the same as before
|
||||
assert.Equal(t, len(currentUser.Emails), len(sameEmailsResult.Emails))
|
||||
for i := range currentUser.Emails {
|
||||
assert.Equal(t, currentUser.Emails[i].Email, sameEmailsResult.Emails[i].Email)
|
||||
assert.Equal(t, currentUser.Emails[i].Type, sameEmailsResult.Emails[i].Type)
|
||||
assert.Equal(t, currentUser.Emails[i].Primary, sameEmailsResult.Emails[i].Primary)
|
||||
}
|
||||
|
||||
// Test validation for multiple primary emails
|
||||
multiPrimaryUser := fleet.ScimUser{
|
||||
ID: user.ID,
|
||||
UserName: "multi-primary@example.com",
|
||||
ExternalID: ptr.String("ext-multi-primary"),
|
||||
GivenName: ptr.String("Multi"),
|
||||
FamilyName: ptr.String("Primary"),
|
||||
Active: ptr.Bool(true),
|
||||
Emails: []fleet.ScimUserEmail{
|
||||
{
|
||||
Email: "primary1@example.com",
|
||||
Primary: ptr.Bool(true), // First primary
|
||||
Type: ptr.String("work"),
|
||||
},
|
||||
{
|
||||
Email: "primary2@example.com",
|
||||
Primary: ptr.Bool(true), // Second primary - should cause validation error
|
||||
Type: ptr.String("home"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// This should fail with a validation error
|
||||
err = ds.ReplaceScimUser(t.Context(), &multiPrimaryUser)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "only one email can be marked as primary")
|
||||
|
||||
// Test email comparison behavior with different combinations of nil/non-nil fields
|
||||
// First, create a user with an email that has all fields set
|
||||
userWithAllFields := fleet.ScimUser{
|
||||
ID: user.ID,
|
||||
UserName: "all-fields@example.com",
|
||||
ExternalID: ptr.String("ext-all-fields"),
|
||||
GivenName: ptr.String("All"),
|
||||
FamilyName: ptr.String("Fields"),
|
||||
Active: ptr.Bool(true),
|
||||
Emails: []fleet.ScimUserEmail{
|
||||
{
|
||||
Email: "all-fields@example.com",
|
||||
Primary: ptr.Bool(true),
|
||||
Type: ptr.String("work"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = ds.ReplaceScimUser(t.Context(), &userWithAllFields)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now create a user with the same email but with nil Primary field
|
||||
userWithNilPrimary := fleet.ScimUser{
|
||||
ID: user.ID,
|
||||
UserName: "all-fields@example.com",
|
||||
ExternalID: ptr.String("ext-all-fields"),
|
||||
GivenName: ptr.String("All"),
|
||||
FamilyName: ptr.String("Fields"),
|
||||
Active: ptr.Bool(true),
|
||||
Emails: []fleet.ScimUserEmail{
|
||||
{
|
||||
Email: "all-fields@example.com",
|
||||
Primary: nil, // Changed from true to nil
|
||||
Type: ptr.String("work"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// This should update the emails since the Primary field changed
|
||||
err = ds.ReplaceScimUser(t.Context(), &userWithNilPrimary)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the email was updated
|
||||
var nilPrimaryUser *fleet.ScimUser
|
||||
nilPrimaryUser, err = ds.ScimUserByID(t.Context(), user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, nilPrimaryUser.Emails, 1)
|
||||
assert.Equal(t, "all-fields@example.com", nilPrimaryUser.Emails[0].Email)
|
||||
assert.Nil(t, nilPrimaryUser.Emails[0].Primary, "Primary field should be nil")
|
||||
|
||||
// Now create a user with the same email but with nil Type field
|
||||
userWithNilType := fleet.ScimUser{
|
||||
ID: user.ID,
|
||||
UserName: "all-fields@example.com",
|
||||
ExternalID: ptr.String("ext-all-fields"),
|
||||
GivenName: ptr.String("All"),
|
||||
FamilyName: ptr.String("Fields"),
|
||||
Active: ptr.Bool(true),
|
||||
Emails: []fleet.ScimUserEmail{
|
||||
{
|
||||
Email: "all-fields@example.com",
|
||||
Primary: nil,
|
||||
Type: nil, // Changed from "work" to nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// This should update the emails since the Type field changed
|
||||
err = ds.ReplaceScimUser(t.Context(), &userWithNilType)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the email was updated
|
||||
var nilTypeUser *fleet.ScimUser
|
||||
nilTypeUser, err = ds.ScimUserByID(t.Context(), user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, nilTypeUser.Emails, 1)
|
||||
assert.Equal(t, "all-fields@example.com", nilTypeUser.Emails[0].Email)
|
||||
assert.Nil(t, nilTypeUser.Emails[0].Type, "Type field should be nil")
|
||||
}
|
||||
|
||||
func testDeleteScimUser(t *testing.T, ds *Datastore) {
|
||||
// Create a test user
|
||||
user := fleet.ScimUser{
|
||||
|
|
@ -1047,9 +1214,11 @@ func testListScimGroups(t *testing.T, ds *Datastore) {
|
|||
}
|
||||
|
||||
// Test 1: List all groups
|
||||
allGroups, totalResults, err := ds.ListScimGroups(t.Context(), fleet.ScimListOptions{
|
||||
StartIndex: 1,
|
||||
PerPage: 10,
|
||||
allGroups, totalResults, err := ds.ListScimGroups(t.Context(), fleet.ScimGroupsListOptions{
|
||||
ScimListOptions: fleet.ScimListOptions{
|
||||
StartIndex: 1,
|
||||
PerPage: 10,
|
||||
},
|
||||
})
|
||||
require.Nil(t, err)
|
||||
assert.GreaterOrEqual(t, len(allGroups), 3) // There might be other groups from previous tests
|
||||
|
|
@ -1068,18 +1237,22 @@ func testListScimGroups(t *testing.T, ds *Datastore) {
|
|||
assert.Equal(t, 3, foundGroups)
|
||||
|
||||
// Test 2: Pagination - first page with 2 items
|
||||
page1Groups, totalPage1, err := ds.ListScimGroups(t.Context(), fleet.ScimListOptions{
|
||||
StartIndex: 1,
|
||||
PerPage: 2,
|
||||
page1Groups, totalPage1, err := ds.ListScimGroups(t.Context(), fleet.ScimGroupsListOptions{
|
||||
ScimListOptions: fleet.ScimListOptions{
|
||||
StartIndex: 1,
|
||||
PerPage: 2,
|
||||
},
|
||||
})
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, 2, len(page1Groups))
|
||||
assert.GreaterOrEqual(t, totalPage1, uint(3)) // Total should be at least 3
|
||||
|
||||
// Test 3: Pagination - second page with 2 items
|
||||
page2Groups, totalPage2, err := ds.ListScimGroups(t.Context(), fleet.ScimListOptions{
|
||||
StartIndex: 3, // StartIndex is 1-based, so for the second page with 2 items per page, we start at index 3
|
||||
PerPage: 2,
|
||||
page2Groups, totalPage2, err := ds.ListScimGroups(t.Context(), fleet.ScimGroupsListOptions{
|
||||
ScimListOptions: fleet.ScimListOptions{
|
||||
StartIndex: 3, // StartIndex is 1-based, so for the second page with 2 items per page, we start at index 3
|
||||
PerPage: 2,
|
||||
},
|
||||
})
|
||||
require.Nil(t, err)
|
||||
assert.GreaterOrEqual(t, len(page2Groups), 1) // At least 1 item on the second page
|
||||
|
|
@ -1091,6 +1264,33 @@ func testListScimGroups(t *testing.T, ds *Datastore) {
|
|||
assert.NotEqual(t, p1Group.ID, p2Group.ID, "Groups should not appear on multiple pages")
|
||||
}
|
||||
}
|
||||
|
||||
// Test 4: Filter by display name
|
||||
displayName := "List Test Group 2"
|
||||
filteredGroups, totalFilteredResults, err := ds.ListScimGroups(t.Context(), fleet.ScimGroupsListOptions{
|
||||
ScimListOptions: fleet.ScimListOptions{
|
||||
StartIndex: 1,
|
||||
PerPage: 10,
|
||||
},
|
||||
DisplayNameFilter: &displayName,
|
||||
})
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, 1, len(filteredGroups), "Should find exactly one group with the specified display name")
|
||||
assert.Equal(t, uint(1), totalFilteredResults)
|
||||
assert.Equal(t, displayName, filteredGroups[0].DisplayName)
|
||||
|
||||
// Test 5: Filter by non-existent display name
|
||||
nonExistentName := "Non-Existent Group"
|
||||
emptyResults, totalEmptyResults, err := ds.ListScimGroups(t.Context(), fleet.ScimGroupsListOptions{
|
||||
ScimListOptions: fleet.ScimListOptions{
|
||||
StartIndex: 1,
|
||||
PerPage: 10,
|
||||
},
|
||||
DisplayNameFilter: &nonExistentName,
|
||||
})
|
||||
require.Nil(t, err)
|
||||
assert.Empty(t, emptyResults, "Should find no groups with a non-existent display name")
|
||||
assert.Equal(t, uint(0), totalEmptyResults)
|
||||
}
|
||||
|
||||
func testScimUserCreateValidation(t *testing.T, ds *Datastore) {
|
||||
|
|
@ -1967,3 +2167,60 @@ func forceSetHostProfileStatus(t *testing.T, ds *Datastore, hostUUID string, pro
|
|||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func testScimUsersExist(t *testing.T, ds *Datastore) {
|
||||
// Create test users
|
||||
users := createTestScimUsers(t, ds)
|
||||
userIDs := make([]uint, len(users))
|
||||
for i, user := range users {
|
||||
userIDs[i] = user.ID
|
||||
}
|
||||
|
||||
// Test 1: Empty slice should return true
|
||||
exist, err := ds.ScimUsersExist(t.Context(), []uint{})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exist, "Empty slice should return true")
|
||||
|
||||
// Test 2: All existing users should return true
|
||||
exist, err = ds.ScimUsersExist(t.Context(), userIDs)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exist, "All existing users should return true")
|
||||
|
||||
// Test 3: Mix of existing and non-existing users should return false
|
||||
nonExistentIDs := userIDs
|
||||
nonExistentIDs = append(nonExistentIDs, 99999)
|
||||
exist, err = ds.ScimUsersExist(t.Context(), nonExistentIDs)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exist, "Mix of existing and non-existing users should return false")
|
||||
|
||||
// Test 4: Only non-existing users should return false
|
||||
exist, err = ds.ScimUsersExist(t.Context(), []uint{99999, 100000})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exist, "Only non-existing users should return false")
|
||||
|
||||
// Test 5: Test with a large number of IDs to verify batching works
|
||||
// First, create a large number of test users
|
||||
largeUserIDs := make([]uint, 0, 25000)
|
||||
largeUserIDs = append(largeUserIDs, userIDs...) // Add existing users
|
||||
|
||||
// Add some non-existent IDs to test batching with mixed results
|
||||
for i := 0; i < 24990; i++ {
|
||||
largeUserIDs = append(largeUserIDs, uint(1000000)+uint(i)) // nolint:gosec // dismiss G115 integer overflow
|
||||
}
|
||||
|
||||
exist, err = ds.ScimUsersExist(t.Context(), largeUserIDs)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exist, "Large batch with non-existing users should return false")
|
||||
|
||||
// Test 6: Test with a large number of existing IDs
|
||||
// This is a bit tricky to test thoroughly without creating thousands of users,
|
||||
// so we'll just verify the function handles a large slice without errors
|
||||
largeExistingIDs := make([]uint, 0, 25000)
|
||||
for i := 0; i < 25000; i++ {
|
||||
largeExistingIDs = append(largeExistingIDs, userIDs[i%len(userIDs)])
|
||||
}
|
||||
|
||||
exist, err = ds.ScimUsersExist(t.Context(), largeExistingIDs)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exist, "Large batch with only existing users should return true")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2067,6 +2067,9 @@ type Datastore interface {
|
|||
ScimUserByUserNameOrEmail(ctx context.Context, userName string, email string) (*ScimUser, error)
|
||||
// ScimUserByHostID retrieves a SCIM user associated with a host ID
|
||||
ScimUserByHostID(ctx context.Context, hostID uint) (*ScimUser, error)
|
||||
// ScimUsersExist checks if all the provided SCIM user IDs exist in the datastore
|
||||
// If the slice is empty, it returns true
|
||||
ScimUsersExist(ctx context.Context, ids []uint) (bool, error)
|
||||
// ReplaceScimUser replaces an existing SCIM user in the database
|
||||
ReplaceScimUser(ctx context.Context, user *ScimUser) error
|
||||
// DeleteScimUser deletes a SCIM user from the database
|
||||
|
|
@ -2084,7 +2087,7 @@ type Datastore interface {
|
|||
// DeleteScimGroup deletes a SCIM group from the database
|
||||
DeleteScimGroup(ctx context.Context, id uint) error
|
||||
// ListScimGroups retrieves a list of SCIM groups with pagination
|
||||
ListScimGroups(ctx context.Context, opts ScimListOptions) (groups []ScimGroup, totalResults uint, err error)
|
||||
ListScimGroups(ctx context.Context, opts ScimGroupsListOptions) (groups []ScimGroup, totalResults uint, err error)
|
||||
// ScimLastRequest retrieves the last SCIM request info
|
||||
ScimLastRequest(ctx context.Context) (*ScimLastRequest, error)
|
||||
// UpdateScimLastRequest updates the last SCIM request info
|
||||
|
|
|
|||
|
|
@ -48,6 +48,28 @@ type ScimUserEmail struct {
|
|||
Type *string `db:"type"`
|
||||
}
|
||||
|
||||
// GenerateComparisonKey generates a unique string representation of the email
|
||||
// that can be used for comparison, properly handling nil values.
|
||||
func (e ScimUserEmail) GenerateComparisonKey() string {
|
||||
// Handle Type field which can be nil
|
||||
typeValue := "nil"
|
||||
if e.Type != nil {
|
||||
typeValue = *e.Type
|
||||
}
|
||||
|
||||
// Handle Primary field which can be nil
|
||||
primaryValue := "nil"
|
||||
if e.Primary != nil {
|
||||
if *e.Primary {
|
||||
primaryValue = "true"
|
||||
} else {
|
||||
primaryValue = "false"
|
||||
}
|
||||
}
|
||||
|
||||
return e.Email + ":" + typeValue + ":" + primaryValue
|
||||
}
|
||||
|
||||
type ScimListOptions struct {
|
||||
// 1-based index of the first result to return (must be positive integer)
|
||||
StartIndex uint
|
||||
|
|
@ -69,6 +91,13 @@ type ScimUsersListOptions struct {
|
|||
EmailValueFilter *string
|
||||
}
|
||||
|
||||
type ScimGroupsListOptions struct {
|
||||
ScimListOptions
|
||||
|
||||
// DisplayNameFilter filters by displayName
|
||||
DisplayNameFilter *string
|
||||
}
|
||||
|
||||
type ScimGroup struct {
|
||||
ID uint `db:"id"`
|
||||
ExternalID *string `db:"external_id"`
|
||||
|
|
|
|||
|
|
@ -1320,6 +1320,8 @@ type ScimUserByUserNameOrEmailFunc func(ctx context.Context, userName string, em
|
|||
|
||||
type ScimUserByHostIDFunc func(ctx context.Context, hostID uint) (*fleet.ScimUser, error)
|
||||
|
||||
type ScimUsersExistFunc func(ctx context.Context, ids []uint) (bool, error)
|
||||
|
||||
type ReplaceScimUserFunc func(ctx context.Context, user *fleet.ScimUser) error
|
||||
|
||||
type DeleteScimUserFunc func(ctx context.Context, id uint) error
|
||||
|
|
@ -1336,7 +1338,7 @@ type ReplaceScimGroupFunc func(ctx context.Context, group *fleet.ScimGroup) erro
|
|||
|
||||
type DeleteScimGroupFunc func(ctx context.Context, id uint) error
|
||||
|
||||
type ListScimGroupsFunc func(ctx context.Context, opts fleet.ScimListOptions) (groups []fleet.ScimGroup, totalResults uint, err error)
|
||||
type ListScimGroupsFunc func(ctx context.Context, opts fleet.ScimGroupsListOptions) (groups []fleet.ScimGroup, totalResults uint, err error)
|
||||
|
||||
type ScimLastRequestFunc func(ctx context.Context) (*fleet.ScimLastRequest, error)
|
||||
|
||||
|
|
@ -3290,6 +3292,9 @@ type DataStore struct {
|
|||
ScimUserByHostIDFunc ScimUserByHostIDFunc
|
||||
ScimUserByHostIDFuncInvoked bool
|
||||
|
||||
ScimUsersExistFunc ScimUsersExistFunc
|
||||
ScimUsersExistFuncInvoked bool
|
||||
|
||||
ReplaceScimUserFunc ReplaceScimUserFunc
|
||||
ReplaceScimUserFuncInvoked bool
|
||||
|
||||
|
|
@ -7869,6 +7874,13 @@ func (s *DataStore) ScimUserByHostID(ctx context.Context, hostID uint) (*fleet.S
|
|||
return s.ScimUserByHostIDFunc(ctx, hostID)
|
||||
}
|
||||
|
||||
func (s *DataStore) ScimUsersExist(ctx context.Context, ids []uint) (bool, error) {
|
||||
s.mu.Lock()
|
||||
s.ScimUsersExistFuncInvoked = true
|
||||
s.mu.Unlock()
|
||||
return s.ScimUsersExistFunc(ctx, ids)
|
||||
}
|
||||
|
||||
func (s *DataStore) ReplaceScimUser(ctx context.Context, user *fleet.ScimUser) error {
|
||||
s.mu.Lock()
|
||||
s.ReplaceScimUserFuncInvoked = true
|
||||
|
|
@ -7925,7 +7937,7 @@ func (s *DataStore) DeleteScimGroup(ctx context.Context, id uint) error {
|
|||
return s.DeleteScimGroupFunc(ctx, id)
|
||||
}
|
||||
|
||||
func (s *DataStore) ListScimGroups(ctx context.Context, opts fleet.ScimListOptions) (groups []fleet.ScimGroup, totalResults uint, err error) {
|
||||
func (s *DataStore) ListScimGroups(ctx context.Context, opts fleet.ScimGroupsListOptions) (groups []fleet.ScimGroup, totalResults uint, err error) {
|
||||
s.mu.Lock()
|
||||
s.ListScimGroupsFuncInvoked = true
|
||||
s.mu.Unlock()
|
||||
|
|
|
|||
Loading…
Reference in a new issue