fleet/ee/server/scim/users.go
Magnus Jensen 4e56d53ccd
Require given and fullname from SCIM (#33336)
<!-- Add the related story/sub-task/bug number, like Resolves #123, or
remove if NA -->
**Related issue:** Follow-up fix for #30888

See
https://github.com/fleetdm/fleet/issues/30888#issuecomment-3321700108

Needs to be cherry-picked into 4.74

# Checklist for submitter

If some of the following don't apply, delete the relevant line.


## Testing

- [x] Added/updated automated tests
- [x] QA'd all new/changed functionality manually
2025-09-23 17:09:42 +03:00

1046 lines
34 KiB
Go

package scim
import (
"fmt"
"net/http"
"net/url"
"slices"
"strconv"
"strings"
"unicode"
"github.com/elimity-com/scim"
"github.com/elimity-com/scim/errors"
"github.com/elimity-com/scim/optional"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
kitlog "github.com/go-kit/log"
"github.com/go-kit/log/level"
"github.com/scim2/filter-parser/v2"
)
// Attribute names shared by all SCIM resources.
// https://datatracker.ietf.org/doc/html/rfc7643#section-3.1
const (
	externalIdAttr = "externalId"
)

// Attribute names of the SCIM User resource, including the enterprise user
// extension schema URN and the one extension attribute Fleet reads from it.
// https://datatracker.ietf.org/doc/html/rfc7643#section-4.1
const (
	userNameAttr   = "userName"
	nameAttr       = "name"
	givenNameAttr  = "givenName"
	familyNameAttr = "familyName"
	activeAttr     = "active"
	emailsAttr     = "emails"
	groupsAttr     = "groups"

	// Sub-attributes of multi-valued attributes such as emails.
	valueAttr   = "value"
	typeAttr    = "type"
	primaryAttr = "primary"

	extensionEnterpriseUserAttributes = "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
	departmentAttr                    = "department"
)
// UserHandler serves the SCIM /Users resource. It persists users through
// the Fleet datastore and reports problems through its logger.
type UserHandler struct {
	ds     fleet.Datastore
	logger kitlog.Logger
}

// Compile-time check that *UserHandler implements scim.ResourceHandler.
var _ scim.ResourceHandler = (*UserHandler)(nil)

// NewUserHandler returns a SCIM user resource handler backed by ds that logs
// to logger.
func NewUserHandler(ds fleet.Datastore, logger kitlog.Logger) scim.ResourceHandler {
	h := &UserHandler{
		ds:     ds,
		logger: logger,
	}
	return h
}
// Create handles POST /Users.
//
// It requires a non-empty userName that no existing user already has. It
// builds the user from the request attributes, stores it, and returns the
// stored user (with its new ID) as a SCIM resource.
//
// Errors:
//   - missing, mistyped, or empty userName: SCIM bad-params error
//   - userName already taken: errors.ScimErrorUniqueness
//   - datastore failures are logged and returned unchanged
func (u *UserHandler) Create(r *http.Request, attributes scim.ResourceAttributes) (scim.Resource, error) {
	// Check for userName uniqueness
	userName, err := getRequiredResource[string](attributes, userNameAttr)
	if err != nil {
		level.Error(u.logger).Log("msg", "failed to get userName", "err", err)
		return scim.Resource{}, err
	}
	// In IETF documents, "non-empty" generally means literally "having at
	// least one character". A value made only of spaces is therefore still
	// considered non-empty.
	if len(userName) == 0 {
		level.Info(u.logger).Log("msg", "userName is empty")
		return scim.Resource{}, errors.ScimErrorBadParams([]string{userNameAttr})
	}
	_, err = u.ds.ScimUserByUserName(r.Context(), userName)
	switch {
	case err != nil && !fleet.IsNotFound(err):
		level.Error(u.logger).Log("msg", "failed to check for userName uniqueness", userNameAttr, userName, "err", err)
		return scim.Resource{}, err
	case err == nil:
		level.Info(u.logger).Log("msg", "user already exists", userNameAttr, userName)
		return scim.Resource{}, errors.ScimErrorUniqueness
	}
	user, err := u.createUserFromAttributes(attributes)
	if err != nil {
		level.Error(u.logger).Log("msg", "failed to create user from attributes", userNameAttr, userName, "err", err)
		return scim.Resource{}, err
	}
	user.ID, err = u.ds.CreateScimUser(r.Context(), user)
	if err != nil {
		// Log like every other datastore failure in this handler, so a failed
		// insert is not returned silently.
		level.Error(u.logger).Log("msg", "failed to create user", userNameAttr, userName, "err", err)
		return scim.Resource{}, err
	}
	return createUserResource(user), nil
}
// createUserFromAttributes builds a fleet.ScimUser from the attributes of a
// SCIM create (POST) or replace (PUT) request. It does not touch the
// datastore, and the returned user has no ID set.
//
// Validation:
//   - userName is required and must be non-empty.
//   - name.familyName and name.givenName must both be present and non-empty.
//   - each email needs a string value that normalizeEmail accepts.
//
// Type mismatches and invalid values are returned as SCIM errors
// (bad-params or invalid-value). Enterprise-extension attributes are read on a
// best-effort basis and never cause an error.
func (u *UserHandler) createUserFromAttributes(attributes scim.ResourceAttributes) (*fleet.ScimUser, error) {
	user := fleet.ScimUser{}
	var err error
	user.UserName, err = getRequiredResource[string](attributes, userNameAttr)
	if err != nil {
		return nil, err
	}
	// Create checks this itself, but Replace relies solely on this function.
	// Enforce it here so a PUT cannot store an empty userName.
	if len(user.UserName) == 0 {
		return nil, errors.ScimErrorBadParams([]string{userNameAttr})
	}
	user.ExternalID, err = getOptionalResource[string](attributes, externalIdAttr)
	if err != nil {
		return nil, err
	}
	user.Active, err = getOptionalResource[bool](attributes, activeAttr)
	if err != nil {
		return nil, err
	}
	// A missing "name" yields a nil map. The lookups below then find nothing,
	// and the required-name checks reject the request.
	name, err := getComplexResource(attributes, nameAttr)
	if err != nil {
		return nil, err
	}
	user.FamilyName, err = getOptionalResource[string](name, familyNameAttr)
	if err != nil {
		return nil, err
	}
	if user.FamilyName == nil || len(*user.FamilyName) == 0 {
		return nil, errors.ScimErrorInvalidValue // Disallow non set field and empty value
	}
	user.GivenName, err = getOptionalResource[string](name, givenNameAttr)
	if err != nil {
		return nil, err
	}
	if user.GivenName == nil || len(*user.GivenName) == 0 {
		return nil, errors.ScimErrorInvalidValue // Disallow non set field and empty value
	}
	emails, err := getComplexResourceSlice(attributes, emailsAttr)
	if err != nil {
		return nil, err
	}
	userEmails := make([]fleet.ScimUserEmail, 0, len(emails))
	for _, email := range emails {
		userEmail := fleet.ScimUserEmail{}
		userEmail.Email, err = getRequiredResource[string](email, valueAttr)
		if err != nil {
			return nil, err
		}
		// Service providers SHOULD canonicalize the value according to [RFC5321]
		// https://datatracker.ietf.org/doc/html/rfc7643#section-4.1.2
		userEmail.Email, err = normalizeEmail(userEmail.Email)
		if err != nil {
			return nil, errors.ScimErrorBadParams([]string{valueAttr})
		}
		userEmail.Type, err = getOptionalResource[string](email, typeAttr)
		if err != nil {
			return nil, err
		}
		userEmail.Primary, err = getOptionalResource[bool](email, primaryAttr)
		if err != nil {
			return nil, err
		}
		userEmails = append(userEmails, userEmail)
	}
	user.Emails = userEmails
	// Attempt to get extension enterprise user attributes.
	extendedAttributes := u.getExtensionEnterpriseUserAttributes(user.UserName, attributes)
	user.Department = extendedAttributes.department
	return &user, nil
}
// extendedAttributes holds the values Fleet reads from the SCIM enterprise
// user extension. A nil field means the value was absent or had an
// unexpected type.
type extendedAttributes struct {
	department *string
}

// getExtensionEnterpriseUserAttributes reads the supported values (currently
// only department) from the enterprise user extension in attributes.
//
// This is best-effort. A missing extension yields an empty result, and an
// extension or department of an unexpected type is logged and skipped rather
// than failing the request. userName is used only as log context.
func (u *UserHandler) getExtensionEnterpriseUserAttributes(userName string, attributes scim.ResourceAttributes) extendedAttributes {
	result := extendedAttributes{}

	rawExtension, present := attributes[extensionEnterpriseUserAttributes]
	if !present {
		return result
	}
	extension, isMap := rawExtension.(map[string]any)
	if !isMap {
		level.Error(u.logger).Log(
			"msg", fmt.Sprintf("unexpected type for %s: %T", extensionEnterpriseUserAttributes, rawExtension),
			userNameAttr, userName,
		)
		return result
	}

	rawDepartment, present := extension[departmentAttr]
	if !present {
		return result
	}
	department, isString := rawDepartment.(string)
	if !isString {
		level.Error(u.logger).Log(
			"msg", fmt.Sprintf("unexpected type for %s.department: %T", extensionEnterpriseUserAttributes, rawDepartment),
			userNameAttr, userName,
		)
		return result
	}
	result.department = &department
	return result
}
// getRequiredResource returns attributes[key] as a T.
//
// A missing key, a nil value, or a value of a different type is reported as
// a SCIM bad-params error naming key. In that case the zero T is returned.
func getRequiredResource[T string | bool](attributes scim.ResourceAttributes, key string) (T, error) {
	if raw, found := attributes[key]; found && raw != nil {
		if typed, isT := raw.(T); isT {
			return typed, nil
		}
	}
	var zero T
	return zero, errors.ScimErrorBadParams([]string{key})
}
// getOptionalResource returns a pointer to attributes[key] as a T.
//
// It returns nil when the key is absent or its value is nil. A present value
// of a different type is reported as a SCIM bad-params error naming key.
func getOptionalResource[T string | bool](attributes scim.ResourceAttributes, key string) (*T, error) {
	raw, found := attributes[key]
	if !found || raw == nil {
		return nil, nil
	}
	typed, isT := raw.(T)
	if !isT {
		return nil, errors.ScimErrorBadParams([]string{key})
	}
	return &typed, nil
}
// getComplexResource returns attributes[key] as a map (a SCIM complex
// attribute).
//
// It returns nil when the key is absent or nil. A present value that is not a
// map[string]interface{} is reported as a SCIM bad-params error naming key.
func getComplexResource(attributes scim.ResourceAttributes, key string) (map[string]interface{}, error) {
	raw, found := attributes[key]
	if !found || raw == nil {
		return nil, nil
	}
	complexValue, isMap := raw.(map[string]interface{})
	if !isMap {
		return nil, errors.ScimErrorBadParams([]string{key})
	}
	return complexValue, nil
}
// getComplexResourceSlice returns attributes[key] as a list of maps (a SCIM
// multi-valued complex attribute such as emails).
//
// It returns nil when the key is absent or nil. Empty maps in the list are
// dropped. A value that is not a list, or a list element that is not a map,
// is reported as a SCIM bad-params error naming key.
func getComplexResourceSlice(attributes scim.ResourceAttributes, key string) ([]map[string]interface{}, error) {
	raw, found := attributes[key]
	if !found || raw == nil {
		return nil, nil
	}
	items, isSlice := raw.([]interface{})
	if !isSlice {
		return nil, errors.ScimErrorBadParams([]string{key})
	}
	result := make([]map[string]interface{}, 0, len(items))
	for i := range items {
		entry, isMap := items[i].(map[string]interface{})
		if !isMap {
			return nil, errors.ScimErrorBadParams([]string{key})
		}
		if len(entry) == 0 {
			// Skip empty entries rather than storing blank records.
			continue
		}
		result = append(result, entry)
	}
	return result, nil
}
// Get handles GET /Users/{id}.
//
// An id that is not a valid numeric ID, or that matches no user, yields a
// SCIM 404. Other datastore errors are logged and returned unchanged.
func (u *UserHandler) Get(r *http.Request, id string) (scim.Resource, error) {
	userID, err := extractUserIDFromValue(id)
	if err != nil {
		level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
		return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
	}

	user, err := u.ds.ScimUserByID(r.Context(), userID)
	if fleet.IsNotFound(err) {
		level.Info(u.logger).Log("msg", "failed to find user", "id", id)
		return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
	}
	if err != nil {
		level.Error(u.logger).Log("msg", "failed to get user", "id", id, "err", err)
		return scim.Resource{}, err
	}
	return createUserResource(user), nil
}
// createUserResource converts a stored Fleet SCIM user into the SCIM
// resource returned to clients.
//
// Optional fields that are nil (externalId, active, the name parts,
// department) are omitted from the output. The emails and groups attributes
// appear only when non-empty. Each group is rendered with its value, a
// "Groups/<id>" $ref, and its display name.
func createUserResource(user *fleet.ScimUser) scim.Resource {
	attrs := scim.ResourceAttributes{
		userNameAttr: user.UserName,
	}

	if user.Active != nil {
		attrs[activeAttr] = *user.Active
	}

	if user.FamilyName != nil || user.GivenName != nil {
		name := scim.ResourceAttributes{}
		if user.FamilyName != nil {
			name[familyNameAttr] = *user.FamilyName
		}
		if user.GivenName != nil {
			name[givenNameAttr] = *user.GivenName
		}
		attrs[nameAttr] = name
	}

	if len(user.Emails) > 0 {
		emails := make([]scim.ResourceAttributes, 0, len(user.Emails))
		for _, e := range user.Emails {
			entry := scim.ResourceAttributes{valueAttr: e.Email}
			if e.Type != nil {
				entry[typeAttr] = *e.Type
			}
			if e.Primary != nil {
				entry[primaryAttr] = *e.Primary
			}
			emails = append(emails, entry)
		}
		attrs[emailsAttr] = emails
	}

	if len(user.Groups) > 0 {
		groups := make([]scim.ResourceAttributes, 0, len(user.Groups))
		for _, g := range user.Groups {
			groupID := scimGroupID(g.ID)
			groups = append(groups, scim.ResourceAttributes{
				valueAttr: groupID,
				"$ref":    "Groups/" + groupID,
				"display": g.DisplayName,
			})
		}
		attrs[groupsAttr] = groups
	}

	if user.Department != nil {
		attrs[extensionEnterpriseUserAttributes] = scim.ResourceAttributes{
			departmentAttr: *user.Department,
		}
	}

	resource := scim.Resource{
		ID:         scimUserID(user.ID),
		Attributes: attrs,
	}
	if user.ExternalID != nil {
		resource.ExternalID = optional.NewString(*user.ExternalID)
	}
	return resource
}
// GetAll handles GET /Users (list and query).
//
// Pagination follows RFC 7644 §3.4.2.4: startIndex is the 1-based index of
// the first resource, not a page number. A startIndex below 1 is treated as 1.
// A count below 1 means maxResults, and a count above maxResults is rejected
// with scimType "tooMany". totalResults reports the full number of matches,
// which may exceed the number of resources on the returned page.
//
// Filtering: only `userName eq "<value>"` is supported. The attribute name is
// compared case-insensitively, and the value is URL-unescaped (required by the
// Microsoft Entra ID SCIM validator). A filter that cannot be parsed is
// rejected with scimType "invalidFilter". A filter that parses but uses any
// other attribute, operator, or value type currently returns an empty page.
// RFC 7644 says unsupported filter operations MUST be declined with a 400
// "invalidFilter" error, so this is a deliberate leniency.
// https://datatracker.ietf.org/doc/html/rfc7644#section-3.4.2
func (u *UserHandler) GetAll(r *http.Request, params scim.ListRequestParams) (scim.Page, error) {
	if params.Count > maxResults {
		return scim.Page{}, errors.ScimErrorTooMany
	}
	count := params.Count
	if count < 1 {
		count = maxResults
	}
	startIndex := params.StartIndex
	if startIndex < 1 {
		startIndex = 1
	}

	opts := fleet.ScimUsersListOptions{}
	opts.ScimListOptions.StartIndex = uint(startIndex) // nolint:gosec // ignore G115
	opts.ScimListOptions.PerPage = uint(count)         // nolint:gosec // ignore G115

	if rawFilter := r.URL.Query().Get("filter"); rawFilter != "" {
		expr, err := filter.ParseAttrExp([]byte(rawFilter))
		if err != nil {
			level.Error(u.logger).Log("msg", "failed to parse filter", "filter", rawFilter, "err", err)
			return scim.Page{}, errors.ScimErrorInvalidFilter
		}
		isUserNameEquality := strings.EqualFold(expr.AttributePath.String(), "userName") && expr.Operator == "eq"
		if !isUserNameEquality {
			level.Info(u.logger).Log("msg", "unsupported filter", "filter", rawFilter)
			return scim.Page{}, nil
		}
		rawUserName, isString := expr.CompareValue.(string)
		if !isString {
			level.Error(u.logger).Log("msg", "unsupported value", "value", expr.CompareValue)
			return scim.Page{}, nil
		}
		// Decode URL-encoded characters; the Microsoft Entra ID SCIM validator
		// sends userName values in encoded form.
		userName, err := url.QueryUnescape(rawUserName)
		if err != nil {
			level.Error(u.logger).Log("msg", "failed to decode userName", "userName", userName, "err", err)
			return scim.Page{}, nil
		}
		opts.UserNameFilter = &userName
	}

	users, totalResults, err := u.ds.ListScimUsers(r.Context(), opts)
	if err != nil {
		level.Error(u.logger).Log("msg", "failed to list users", "err", err)
		return scim.Page{}, err
	}

	resources := make([]scim.Resource, 0, len(users))
	for i := range users {
		resources = append(resources, createUserResource(&users[i]))
	}
	return scim.Page{
		TotalResults: int(totalResults), // nolint:gosec // ignore G115
		Resources:    resources,
	}, nil
}
// Replace handles PUT /Users/{id}.
//
// The stored user is overwritten with a user built entirely from attributes.
// A malformed or unknown id yields a SCIM 404. A userName already held by a
// different user yields errors.ScimErrorUniqueness.
func (u *UserHandler) Replace(r *http.Request, id string, attributes scim.ResourceAttributes) (scim.Resource, error) {
	userID, err := extractUserIDFromValue(id)
	if err != nil {
		level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
		return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
	}

	replacement, err := u.createUserFromAttributes(attributes)
	if err != nil {
		level.Error(u.logger).Log("msg", "failed to create user from attributes", "id", id, "err", err)
		return scim.Resource{}, err
	}
	replacement.ID = userID

	// userName is unique, so look for another holder of it to return a clear
	// error. Finding this same user means the name is unchanged, which is fine.
	owner, err := u.ds.ScimUserByUserName(r.Context(), replacement.UserName)
	if err != nil && !fleet.IsNotFound(err) {
		level.Error(u.logger).Log("msg", "failed to check for userName uniqueness", userNameAttr, replacement.UserName, "err", err)
		return scim.Resource{}, err
	}
	if err == nil && owner.ID != replacement.ID {
		level.Info(u.logger).Log("msg", "user already exists with this username", userNameAttr, replacement.UserName)
		return scim.Resource{}, errors.ScimErrorUniqueness
	}

	err = u.ds.ReplaceScimUser(r.Context(), replacement)
	if err != nil {
		if fleet.IsNotFound(err) {
			level.Info(u.logger).Log("msg", "failed to find user to replace", "id", id)
			return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
		}
		level.Error(u.logger).Log("msg", "failed to replace user", "id", id, "err", err)
		return scim.Resource{}, err
	}
	return createUserResource(replacement), nil
}
// Delete handles DELETE /Users/{id}.
//
// Per RFC 7644 §3.6, operations on a deleted resource MUST return 404. A
// malformed id, or one that matches no user, therefore yields a SCIM 404.
// Other datastore errors are logged and returned unchanged.
// https://datatracker.ietf.org/doc/html/rfc7644#section-3.6
func (u *UserHandler) Delete(r *http.Request, id string) error {
	userID, err := extractUserIDFromValue(id)
	if err != nil {
		level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
		return errors.ScimErrorResourceNotFound(id)
	}

	err = u.ds.DeleteScimUser(r.Context(), userID)
	if err == nil {
		return nil
	}
	if fleet.IsNotFound(err) {
		level.Info(u.logger).Log("msg", "failed to find user to delete", "id", id)
		return errors.ScimErrorResourceNotFound(id)
	}
	level.Error(u.logger).Log("msg", "failed to delete user", "id", id, "err", err)
	return err
}
// Patch handles PATCH /Users/{id}.
// https://datatracker.ietf.org/doc/html/rfc7644#section-3.5.2
//
// Only the add, replace, and remove ops are accepted. An operation without a
// path must carry a map value whose keys name the attributes to change.
// Remove always requires a path. The keys of such a map are visited in Go
// map order.
//
// Operations are applied in order to the user loaded from the datastore. If
// any operation is rejected, the error is returned and nothing is written.
// Otherwise the result is persisted with a single ReplaceScimUser call.
//
// userName is unique. When the patch changes it, the new value is checked
// against other users first, so a collision yields
// errors.ScimErrorUniqueness (as in Replace) instead of a datastore failure.
// A malformed or unknown id yields a SCIM 404.
func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchOperation) (scim.Resource, error) {
	idUint, err := extractUserIDFromValue(id)
	if err != nil {
		level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
		return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
	}
	user, err := u.ds.ScimUserByID(r.Context(), idUint)
	switch {
	case fleet.IsNotFound(err):
		level.Info(u.logger).Log("msg", "failed to find user to patch", "id", id)
		return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
	case err != nil:
		level.Error(u.logger).Log("msg", "failed to get user to patch", "id", id, "err", err)
		return scim.Resource{}, err
	}
	// Remembered so we only run the uniqueness lookup when userName changes.
	originalUserName := user.UserName

	for _, op := range operations {
		if op.Op != scim.PatchOperationAdd && op.Op != scim.PatchOperationReplace && op.Op != scim.PatchOperationRemove {
			level.Info(u.logger).Log("msg", "unsupported patch operation", "op", op.Op)
			return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
		}
		switch {
		// If path is not specified, we look for the path in the value attribute.
		case op.Path == nil:
			if op.Op == scim.PatchOperationRemove {
				level.Info(u.logger).Log("msg", "the 'path' attribute is REQUIRED for 'remove' operations", "op", op.Op)
				return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
			}
			newValues, ok := op.Value.(map[string]interface{})
			if !ok {
				level.Info(u.logger).Log("msg", "unsupported patch value", "value", op.Value)
				return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
			}
			for k, v := range newValues {
				switch k {
				case externalIdAttr:
					err = u.patchExternalId(op.Op, v, user)
					if err != nil {
						return scim.Resource{}, err
					}
				case userNameAttr:
					err = u.patchUserName(op.Op, v, user)
					if err != nil {
						return scim.Resource{}, err
					}
				case activeAttr:
					err = u.patchActive(op.Op, v, user)
					if err != nil {
						return scim.Resource{}, err
					}
				case nameAttr + "." + givenNameAttr:
					err = u.patchGivenName(op.Op, v, user)
					if err != nil {
						return scim.Resource{}, err
					}
				case nameAttr + "." + familyNameAttr:
					err = u.patchFamilyName(op.Op, v, user)
					if err != nil {
						return scim.Resource{}, err
					}
				case nameAttr:
					err = u.patchName(v, op, user)
					if err != nil {
						return scim.Resource{}, err
					}
				case emailsAttr:
					err = u.patchEmails(v, op, user)
					if err != nil {
						return scim.Resource{}, err
					}
				case extensionEnterpriseUserAttributes + ":" + departmentAttr:
					err = u.patchDepartment(op.Op, v, user)
					if err != nil {
						return scim.Resource{}, err
					}
				default:
					level.Info(u.logger).Log("msg", "unsupported patch value field", "field", k)
					return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
				}
			}
		case op.Path.String() == externalIdAttr:
			err = u.patchExternalId(op.Op, op.Value, user)
			if err != nil {
				return scim.Resource{}, err
			}
		case op.Path.String() == userNameAttr:
			err = u.patchUserName(op.Op, op.Value, user)
			if err != nil {
				return scim.Resource{}, err
			}
		case op.Path.String() == activeAttr:
			err = u.patchActive(op.Op, op.Value, user)
			if err != nil {
				return scim.Resource{}, err
			}
		case op.Path.String() == nameAttr+"."+givenNameAttr:
			err = u.patchGivenName(op.Op, op.Value, user)
			if err != nil {
				return scim.Resource{}, err
			}
		case op.Path.String() == nameAttr+"."+familyNameAttr:
			err = u.patchFamilyName(op.Op, op.Value, user)
			if err != nil {
				return scim.Resource{}, err
			}
		case op.Path.String() == nameAttr:
			err = u.patchName(op.Value, op, user)
			if err != nil {
				return scim.Resource{}, err
			}
		// The plain "emails" path must be matched before the filtered form
		// (e.g. emails[type eq "work"]), whose AttributePath is also "emails".
		case op.Path.String() == emailsAttr:
			err = u.patchEmails(op.Value, op, user)
			if err != nil {
				return scim.Resource{}, err
			}
		case op.Path.AttributePath.String() == emailsAttr:
			err = u.patchEmailsWithPathFiltering(op, user)
			if err != nil {
				return scim.Resource{}, err
			}
		case op.Path.AttributePath.String() == extensionEnterpriseUserAttributes+":"+departmentAttr:
			err = u.patchDepartment(op.Op, op.Value, user)
			if err != nil {
				return scim.Resource{}, err
			}
		default:
			level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path)
			return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
		}
	}
	if len(operations) != 0 {
		if user.UserName != originalUserName {
			// Same check Replace performs: another user already holding the new
			// userName is a client error, not a server failure.
			existing, lookupErr := u.ds.ScimUserByUserName(r.Context(), user.UserName)
			switch {
			case lookupErr != nil && !fleet.IsNotFound(lookupErr):
				level.Error(u.logger).Log("msg", "failed to check for userName uniqueness", userNameAttr, user.UserName, "err", lookupErr)
				return scim.Resource{}, lookupErr
			case lookupErr == nil && existing.ID != user.ID:
				level.Info(u.logger).Log("msg", "user already exists with this username", userNameAttr, user.UserName)
				return scim.Resource{}, errors.ScimErrorUniqueness
			}
		}
		err = u.ds.ReplaceScimUser(r.Context(), user)
		switch {
		case fleet.IsNotFound(err):
			level.Info(u.logger).Log("msg", "failed to find user to patch", "id", id)
			return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
		case err != nil:
			level.Error(u.logger).Log("msg", "failed to patch user", "id", id, "err", err)
			return scim.Resource{}, err
		}
	}
	return createUserResource(user), nil
}
// patchEmailsWithPathFiltering applies a PATCH operation whose path selects
// emails by type, e.g. `emails[type eq "work"]` or `emails[type eq "work"].value`.
//
// The first email with a matching type is the target. A missing match is an
// error, except for add, which creates a new email of that type.
//
// Without a sub-attribute:
//   - remove deletes the matched email.
//   - replace swaps it for exactly one supplied email; an empty list deletes it.
//   - add appends the supplied emails, forcing their type to the filtered type.
//
// With a sub-attribute (primary, value, or type), only that field of the
// matched email changes. Setting primary=true clears it on every other email.
// Email values are canonicalized with normalizeEmail, as on every other path
// that stores an email.
func (u *UserHandler) patchEmailsWithPathFiltering(op scim.PatchOperation, user *fleet.ScimUser) error {
	emailType, err := u.getEmailType(op)
	if err != nil {
		return err
	}
	emailFound := false
	var emailIndex int
	for i, email := range user.Emails {
		if email.Type != nil && *email.Type == emailType {
			emailIndex = i
			emailFound = true
			break
		}
	}
	if !emailFound && op.Op != scim.PatchOperationAdd {
		level.Info(u.logger).Log("msg", "email not found", "email_type", emailType, "op", fmt.Sprintf("%v", op))
		return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
	}
	if op.Path.SubAttribute == nil {
		if op.Op == scim.PatchOperationRemove {
			user.Emails = slices.Delete(user.Emails, emailIndex, emailIndex+1)
			return nil
		}
		// For add and replace operations, we need to extract the emails
		var emailsList []interface{}
		// Handle different value formats
		switch val := op.Value.(type) {
		case []interface{}:
			// Direct array of members
			emailsList = val
		case map[string]interface{}:
			// Single member as a map
			emailsList = []interface{}{val}
		default:
			level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", emailsAttr), "value", op.Value)
			return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
		}
		switch op.Op {
		case scim.PatchOperationReplace:
			if len(emailsList) == 0 {
				user.Emails = slices.Delete(user.Emails, emailIndex, emailIndex+1)
				return nil
			}
			if len(emailsList) != 1 {
				level.Info(u.logger).Log("msg", "only 1 email should be present for replacement", "emails", emailsList)
				return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
			}
			userEmail, err := u.extractEmail(emailsList[0], op)
			if err != nil {
				return err
			}
			// If setting primary to true, then unset true from other emails
			if userEmail.Primary != nil && *userEmail.Primary {
				clearPrimaryFlagFromEmails(user)
			}
			user.Emails[emailIndex] = userEmail
		case scim.PatchOperationAdd:
			if len(emailsList) == 0 {
				level.Info(u.logger).Log("msg", "no emails provided to add", "emails", emailsList)
				return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
			}
			var newEmails []fleet.ScimUserEmail
			for e := range emailsList {
				userEmail, err := u.extractEmail(emailsList[e], op)
				if err != nil {
					return err
				}
				userEmail.Type = &emailType
				newEmails = append(newEmails, userEmail)
			}
			primaryExists, err := u.checkEmailPrimary(newEmails)
			if err != nil {
				return err
			}
			if primaryExists {
				clearPrimaryFlagFromEmails(user)
			}
			user.Emails = append(user.Emails, newEmails...)
		}
		return nil
	}
	// Add on a sub-attribute with no matching email creates a new email of
	// the filtered type for the sub-attribute to be set on.
	if op.Op == scim.PatchOperationAdd && !emailFound {
		user.Emails = append(user.Emails, fleet.ScimUserEmail{
			Type: ptr.String(emailType),
		})
		emailIndex = len(user.Emails) - 1
	}
	switch *op.Path.SubAttribute {
	case primaryAttr:
		if op.Op == scim.PatchOperationRemove {
			user.Emails[emailIndex].Primary = nil
			return nil
		}
		if op.Value == nil {
			user.Emails[emailIndex].Primary = nil
			return nil
		}
		primary, err := getConcreteType[bool](u, op.Value, primaryAttr)
		if err != nil {
			return err
		}
		// If setting primary to true, then unset true from other emails
		if primary {
			clearPrimaryFlagFromEmails(user)
		}
		user.Emails[emailIndex].Primary = &primary
	case valueAttr:
		if op.Op == scim.PatchOperationRemove {
			// The operation of removing an email value doesn't make sense, but we allow it.
			user.Emails[emailIndex].Email = ""
			return nil
		}
		value, err := getConcreteType[string](u, op.Value, valueAttr)
		if err != nil {
			return err
		}
		// Canonicalize like create/replace and extractEmail do, so a PATCH
		// cannot store a non-canonical or malformed address.
		// https://datatracker.ietf.org/doc/html/rfc7643#section-4.1.2
		normalizedEmail, err := normalizeEmail(value)
		if err != nil {
			level.Info(u.logger).Log("msg", "failed to normalize email", "email", value, "err", err)
			return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
		}
		user.Emails[emailIndex].Email = normalizedEmail
	case typeAttr:
		if op.Op == scim.PatchOperationRemove {
			user.Emails[emailIndex].Type = nil
			return nil
		}
		if op.Value == nil {
			user.Emails[emailIndex].Type = nil
			return nil
		}
		newEmailType, err := getConcreteType[string](u, op.Value, typeAttr)
		if err != nil {
			return err
		}
		user.Emails[emailIndex].Type = &newEmailType
	default:
		level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path)
		return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
	}
	return nil
}
// getEmailType returns the email type selected by a filtered PATCH path such
// as `emails[type eq "work"]`.
//
// Only an `eq` comparison on the type sub-attribute against a string value
// is supported. Any other filter is logged and reported as a SCIM bad-params
// error.
func (u *UserHandler) getEmailType(op scim.PatchOperation) (string, error) {
	expr, isAttrExpr := op.Path.ValueExpression.(*filter.AttributeExpression)
	if !isAttrExpr {
		level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path)
		return "", errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
	}

	filtersOnTypeEquality := expr.AttributePath.String() == typeAttr && expr.Operator == filter.EQ
	if !filtersOnTypeEquality {
		level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path, "expression", expr.AttributePath.String())
		return "", errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
	}

	emailType, isString := expr.CompareValue.(string)
	if !isString {
		level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path, "compare_value", expr.CompareValue)
		return "", errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
	}
	return emailType, nil
}
// getConcreteType asserts that a PATCH value v has type T.
//
// On mismatch it logs at info level, with name identifying the attribute
// being patched. It then returns the zero T and a SCIM bad-params error that
// quotes the value.
func getConcreteType[T string | bool](u *UserHandler, v interface{}, name string) (T, error) {
	if typed, isT := v.(T); isT {
		return typed, nil
	}
	level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' value", name), "value", v)
	var zero T
	return zero, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
}
func (u *UserHandler) patchFamilyName(op string, v interface{}, user *fleet.ScimUser) error {
if op == scim.PatchOperationRemove {
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", nameAttr+"."+familyNameAttr)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
familyName, err := getConcreteType[string](u, v, nameAttr+"."+familyNameAttr)
if err != nil {
return err
}
user.FamilyName = &familyName
return nil
}
func (u *UserHandler) patchGivenName(op string, v interface{}, user *fleet.ScimUser) error {
if op == scim.PatchOperationRemove {
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", nameAttr+"."+givenNameAttr)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
givenName, err := getConcreteType[string](u, v, nameAttr+"."+givenNameAttr)
if err != nil {
return err
}
user.GivenName = &givenName
return nil
}
// patchActive applies a PATCH of the "active" attribute to user.
//
// A remove op or a nil value clears the attribute. Any other value must be a
// bool, or a SCIM bad-params error is returned.
func (u *UserHandler) patchActive(op string, v interface{}, user *fleet.ScimUser) error {
	switch {
	case op == scim.PatchOperationRemove, v == nil:
		user.Active = nil
		return nil
	}
	active, err := getConcreteType[bool](u, v, activeAttr)
	if err != nil {
		return err
	}
	user.Active = &active
	return nil
}
// patchExternalId applies a PATCH of the "externalId" attribute to user.
//
// A remove op or a nil value clears it. Any other value must be a string,
// or a SCIM bad-params error is returned.
func (u *UserHandler) patchExternalId(op string, v interface{}, user *fleet.ScimUser) error {
	switch {
	case op == scim.PatchOperationRemove, v == nil:
		user.ExternalID = nil
		return nil
	}
	externalID, err := getConcreteType[string](u, v, externalIdAttr)
	if err != nil {
		return err
	}
	user.ExternalID = &externalID
	return nil
}
// patchUserName applies an add/replace of the required "userName" attribute.
//
// A remove op, a non-string value, or an empty string is rejected with a
// SCIM bad-params error. Uniqueness is not checked here.
func (u *UserHandler) patchUserName(op string, v interface{}, user *fleet.ScimUser) error {
	if op == scim.PatchOperationRemove {
		level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", userNameAttr)
		return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
	}

	newUserName, err := getConcreteType[string](u, v, userNameAttr)
	switch {
	case err != nil:
		return err
	case newUserName == "":
		level.Info(u.logger).Log("msg", fmt.Sprintf("'%s' cannot be empty", userNameAttr), "value", v)
		return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
	}

	user.UserName = newUserName
	return nil
}
// patchDepartment applies a PATCH of the enterprise extension "department"
// attribute.
//
// A remove op or a nil value clears it. Any other value must be a string,
// or a SCIM bad-params error is returned.
func (u *UserHandler) patchDepartment(op string, v interface{}, user *fleet.ScimUser) error {
	switch {
	case op == scim.PatchOperationRemove, v == nil:
		user.Department = nil
		return nil
	}
	department, err := getConcreteType[string](u, v, departmentAttr)
	if err != nil {
		return err
	}
	user.Department = &department
	return nil
}
// clearPrimaryFlagFromEmails sets primary=false on every email of user that
// is currently marked primary.
//
// A fresh pointer is assigned rather than writing through the existing one.
// Emails without a primary flag (nil) are left untouched.
func clearPrimaryFlagFromEmails(user *fleet.ScimUser) {
	for i := range user.Emails {
		if p := user.Emails[i].Primary; p != nil && *p {
			user.Emails[i].Primary = ptr.Bool(false)
		}
	}
}
// patchEmails applies a PATCH to the whole "emails" attribute.
//
// v is either a list of email objects or a single email object. The
// operations behave as follows:
//   - remove drops all emails.
//   - replace makes the supplied emails the user's full set (an empty list
//     clears it).
//   - add appends the supplied emails and requires at least one. If a new
//     email is primary, the flag is cleared from the existing ones first.
//
// At most one supplied email may be primary. Malformed values are logged and
// reported as SCIM bad-params errors.
func (u *UserHandler) patchEmails(v interface{}, op scim.PatchOperation, user *fleet.ScimUser) error {
	if op.Op == scim.PatchOperationRemove {
		user.Emails = nil
		return nil
	}

	var rawEmails []interface{}
	switch value := v.(type) {
	case []interface{}:
		rawEmails = value
	case map[string]interface{}:
		// A single email object is treated as a one-element list.
		rawEmails = []interface{}{value}
	default:
		level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", emailsAttr), "value", op.Value)
		return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
	}

	isAdd := op.Op == scim.PatchOperationAdd
	if isAdd && len(rawEmails) == 0 {
		level.Info(u.logger).Log("msg", "no emails provided to add", "emails", rawEmails)
		return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
	}

	parsed := make([]fleet.ScimUserEmail, 0, len(rawEmails))
	for i := range rawEmails {
		email, err := u.extractEmail(rawEmails[i], op)
		if err != nil {
			return err
		}
		parsed = append(parsed, email)
	}

	hasPrimary, err := u.checkEmailPrimary(parsed)
	if err != nil {
		return err
	}

	if !isAdd {
		// Replace: the supplied list becomes the complete set.
		user.Emails = parsed
		return nil
	}
	if hasPrimary {
		// The merged list may hold only one primary, and the new one wins.
		clearPrimaryFlagFromEmails(user)
	}
	user.Emails = append(user.Emails, parsed...)
	return nil
}
// checkEmailPrimary reports whether any of userEmails is marked primary.
//
// It returns a SCIM bad-params error when more than one is, because RFC 7643
// allows the primary value true at most once.
func (u *UserHandler) checkEmailPrimary(userEmails []fleet.ScimUserEmail) (bool, error) {
	seenPrimary := false
	for _, email := range userEmails {
		if email.Primary == nil || !*email.Primary {
			continue
		}
		if seenPrimary {
			level.Info(u.logger).Log("msg", "multiple primary emails found")
			return false, errors.ScimErrorBadParams([]string{"Only one email can be marked as primary"})
		}
		seenPrimary = true
	}
	return seenPrimary, nil
}
// extractEmail converts one email object from a PATCH value into a
// fleet.ScimUserEmail.
//
// The object must be a map with a non-empty string "value" that
// normalizeEmail accepts. "type" (string) and "primary" (bool) are copied
// when present with the right type and ignored otherwise. Failures are
// logged and reported as a SCIM bad-params error quoting op.
func (u *UserHandler) extractEmail(emailIntf interface{}, op scim.PatchOperation) (fleet.ScimUserEmail, error) {
	fields, isMap := emailIntf.(map[string]interface{})
	if !isMap {
		level.Info(u.logger).Log("msg", "email is not a map", "email", emailIntf)
		return fleet.ScimUserEmail{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
	}

	rawValue, isString := fields[valueAttr].(string)
	if !isString || rawValue == "" {
		level.Info(u.logger).Log("msg", "email value is missing or invalid", "email", fields)
		return fleet.ScimUserEmail{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
	}

	canonical, err := normalizeEmail(rawValue)
	if err != nil {
		level.Info(u.logger).Log("msg", "failed to normalize email", "email", rawValue, "err", err)
		return fleet.ScimUserEmail{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
	}

	result := fleet.ScimUserEmail{Email: canonical}
	if emailType, ok := fields[typeAttr].(string); ok {
		result.Type = &emailType
	}
	if primary, ok := fields[primaryAttr].(bool); ok {
		result.Primary = &primary
	}
	return result, nil
}
// patchName applies an add/replace of the complex "name" attribute.
//
// v must be a map whose keys are limited to givenName and familyName. Each
// present key must hold a non-empty string, because both name parts are
// required (createUserFromAttributes enforces the same rule). Removing
// "name", unknown keys, and bad values are SCIM bad-params errors. Keys are
// visited in Go map order, and user may be partly updated before an error
// is returned. The caller (Patch) discards the user on error.
func (u *UserHandler) patchName(v interface{}, op scim.PatchOperation, user *fleet.ScimUser) error {
	if op.Op == scim.PatchOperationRemove {
		level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", nameAttr)
		return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
	}
	name, ok := v.(map[string]interface{})
	if !ok {
		level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", nameAttr), "value", op.Value)
		return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
	}
	for nameKey, nameValue := range name {
		switch nameKey {
		case givenNameAttr:
			givenName, ok := nameValue.(string)
			if !ok {
				level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", nameAttr+"."+givenNameAttr), "value",
					op.Value)
				return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
			}
			if givenName == "" {
				level.Info(u.logger).Log("msg", fmt.Sprintf("'%s' cannot be empty", nameAttr+"."+givenNameAttr), "value", op.Value)
				return errors.ScimErrorBadParams([]string{nameAttr + "." + givenNameAttr})
			}
			user.GivenName = &givenName
		case familyNameAttr:
			familyName, ok := nameValue.(string)
			if !ok {
				level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", nameAttr+"."+familyNameAttr), "value",
					op.Value)
				return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
			}
			if familyName == "" {
				level.Info(u.logger).Log("msg", fmt.Sprintf("'%s' cannot be empty", nameAttr+"."+familyNameAttr), "value", op.Value)
				return errors.ScimErrorBadParams([]string{nameAttr + "." + familyNameAttr})
			}
			user.FamilyName = &familyName
		default:
			level.Info(u.logger).Log("msg", "unsupported patch value field", "field", nameAttr+"."+nameKey)
			return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
		}
	}
	return nil
}
// normalizeEmail canonicalizes an email address. It removes all whitespace
// and lowercases the domain, leaving the local part untouched:
//
//	" John.Doe@EXAMPLE.com " -> "John.Doe@example.com"
//
// The local-part of a mailbox MUST BE treated as case sensitive.
// Mailbox domains follow normal DNS rules and are hence not case sensitive.
// https://datatracker.ietf.org/doc/html/rfc5321#section-2.4
//
// The domain is everything after the LAST "@". A quoted local part may
// itself contain "@" (RFC 5321 §4.1.2), so splitting at the first one would
// lowercase part of the local part. An address with no "@", an empty local
// part, or an empty domain is rejected.
func normalizeEmail(email string) (string, error) {
	email = removeWhitespace(email)
	at := strings.LastIndex(email, "@")
	if at <= 0 || at == len(email)-1 {
		return "", fmt.Errorf("invalid email %s", email)
	}
	return email[:at] + "@" + strings.ToLower(email[at+1:]), nil
}

// removeWhitespace returns str with every Unicode whitespace rune
// (per unicode.IsSpace) removed.
func removeWhitespace(str string) string {
	return strings.Map(func(r rune) rune {
		if unicode.IsSpace(r) {
			return -1
		}
		return r
	}, str)
}
// scimUserID formats a Fleet SCIM user ID as the decimal string id exposed
// over SCIM.
func scimUserID(userID uint) string {
	return strconv.FormatUint(uint64(userID), 10)
}

// extractUserIDFromValue extracts the user ID from a value like "123".
//
// Only plain base-10 digits are accepted (no sign, spaces, or fractions).
// Parsing uses the platform's uint size, so values that do not fit in a uint
// are rejected instead of being silently truncated on 32-bit builds.
func extractUserIDFromValue(value string) (uint, error) {
	id, err := strconv.ParseUint(value, 10, strconv.IntSize)
	if err != nil {
		return 0, err
	}
	return uint(id), nil // nolint:gosec // ignore G115: bitSize strconv.IntSize guarantees id fits in uint
}