mirror of
https://github.com/fleetdm/fleet
synced 2026-05-06 06:48:54 +00:00
<!-- Add the related story/sub-task/bug number, like Resolves #123, or remove if NA --> **Related issue:** Followup fix for #30888 See https://github.com/fleetdm/fleet/issues/30888#issuecomment-3321700108 Needs to be cherry-picked into 4.74 # Checklist for submitter If some of the following don't apply, delete the relevant line. ## Testing - [x] Added/updated automated tests - [x] QA'd all new/changed functionality manually
1046 lines
34 KiB
Go
1046 lines
34 KiB
Go
package scim
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"unicode"
|
|
|
|
"github.com/elimity-com/scim"
|
|
"github.com/elimity-com/scim/errors"
|
|
"github.com/elimity-com/scim/optional"
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
"github.com/fleetdm/fleet/v4/server/ptr"
|
|
kitlog "github.com/go-kit/log"
|
|
"github.com/go-kit/log/level"
|
|
"github.com/scim2/filter-parser/v2"
|
|
)
|
|
|
|
const (
|
|
// Common attributes: https://datatracker.ietf.org/doc/html/rfc7643#section-3.1
|
|
externalIdAttr = "externalId"
|
|
|
|
// User attributes: https://datatracker.ietf.org/doc/html/rfc7643#section-4.1
|
|
userNameAttr = "userName"
|
|
nameAttr = "name"
|
|
givenNameAttr = "givenName"
|
|
familyNameAttr = "familyName"
|
|
activeAttr = "active"
|
|
emailsAttr = "emails"
|
|
groupsAttr = "groups"
|
|
valueAttr = "value"
|
|
typeAttr = "type"
|
|
primaryAttr = "primary"
|
|
|
|
extensionEnterpriseUserAttributes = "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
|
departmentAttr = "department"
|
|
)
|
|
|
|
type UserHandler struct {
|
|
ds fleet.Datastore
|
|
logger kitlog.Logger
|
|
}
|
|
|
|
// Compile-time check
|
|
var _ scim.ResourceHandler = &UserHandler{}
|
|
|
|
func NewUserHandler(ds fleet.Datastore, logger kitlog.Logger) scim.ResourceHandler {
|
|
return &UserHandler{ds: ds, logger: logger}
|
|
}
|
|
|
|
func (u *UserHandler) Create(r *http.Request, attributes scim.ResourceAttributes) (scim.Resource, error) {
|
|
// Check for userName uniqueness
|
|
userName, err := getRequiredResource[string](attributes, userNameAttr)
|
|
if err != nil {
|
|
level.Error(u.logger).Log("msg", "failed to get userName", "err", err)
|
|
return scim.Resource{}, err
|
|
}
|
|
// In IETF documents, “non-empty” is generally used in the literal sense of “having at least one character.” That means if a value contains one or more spaces (and nothing else), it is still considered non-empty.
|
|
if len(userName) == 0 {
|
|
level.Info(u.logger).Log("msg", "userName is empty")
|
|
return scim.Resource{}, errors.ScimErrorBadParams([]string{userNameAttr})
|
|
}
|
|
_, err = u.ds.ScimUserByUserName(r.Context(), userName)
|
|
switch {
|
|
case err != nil && !fleet.IsNotFound(err):
|
|
level.Error(u.logger).Log("msg", "failed to check for userName uniqueness", userNameAttr, userName, "err", err)
|
|
return scim.Resource{}, err
|
|
case err == nil:
|
|
level.Info(u.logger).Log("msg", "user already exists", userNameAttr, userName)
|
|
return scim.Resource{}, errors.ScimErrorUniqueness
|
|
}
|
|
|
|
user, err := u.createUserFromAttributes(attributes)
|
|
if err != nil {
|
|
level.Error(u.logger).Log("msg", "failed to create user from attributes", userNameAttr, userName, "err", err)
|
|
return scim.Resource{}, err
|
|
}
|
|
user.ID, err = u.ds.CreateScimUser(r.Context(), user)
|
|
if err != nil {
|
|
return scim.Resource{}, err
|
|
}
|
|
|
|
return createUserResource(user), nil
|
|
}
|
|
|
|
func (u *UserHandler) createUserFromAttributes(attributes scim.ResourceAttributes) (*fleet.ScimUser, error) {
|
|
user := fleet.ScimUser{}
|
|
var err error
|
|
user.UserName, err = getRequiredResource[string](attributes, userNameAttr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
user.ExternalID, err = getOptionalResource[string](attributes, externalIdAttr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
user.Active, err = getOptionalResource[bool](attributes, activeAttr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
name, err := getComplexResource(attributes, nameAttr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
user.FamilyName, err = getOptionalResource[string](name, familyNameAttr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if user.FamilyName == nil || len(*user.FamilyName) == 0 {
|
|
return nil, errors.ScimErrorInvalidValue // Disallow non set field and empty value
|
|
}
|
|
|
|
user.GivenName, err = getOptionalResource[string](name, givenNameAttr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if user.GivenName == nil || len(*user.GivenName) == 0 {
|
|
return nil, errors.ScimErrorInvalidValue // Disallow non set field and empty value
|
|
}
|
|
emails, err := getComplexResourceSlice(attributes, emailsAttr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
userEmails := make([]fleet.ScimUserEmail, 0, len(emails))
|
|
for _, email := range emails {
|
|
userEmail := fleet.ScimUserEmail{}
|
|
userEmail.Email, err = getRequiredResource[string](email, valueAttr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// Service providers SHOULD canonicalize the value according to [RFC5321]
|
|
// https://datatracker.ietf.org/doc/html/rfc7643#section-4.1.2
|
|
userEmail.Email, err = normalizeEmail(userEmail.Email)
|
|
if err != nil {
|
|
return nil, errors.ScimErrorBadParams([]string{valueAttr})
|
|
}
|
|
userEmail.Type, err = getOptionalResource[string](email, typeAttr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
userEmail.Primary, err = getOptionalResource[bool](email, primaryAttr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
userEmails = append(userEmails, userEmail)
|
|
}
|
|
user.Emails = userEmails
|
|
|
|
// Attempt to get extension enterprise user attributes.
|
|
extendedAttributes := u.getExtensionEnterpriseUserAttributes(user.UserName, attributes)
|
|
user.Department = extendedAttributes.department
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
type extendedAttributes struct {
|
|
department *string
|
|
}
|
|
|
|
func (u *UserHandler) getExtensionEnterpriseUserAttributes(userName string, attributes scim.ResourceAttributes) extendedAttributes {
|
|
var attrs extendedAttributes
|
|
m_, ok := attributes[extensionEnterpriseUserAttributes]
|
|
if !ok {
|
|
return attrs
|
|
}
|
|
m, ok := m_.(map[string]any)
|
|
if !ok {
|
|
level.Error(u.logger).Log(
|
|
"msg", fmt.Sprintf("unexpected type for %s: %T", extensionEnterpriseUserAttributes, m_),
|
|
userNameAttr, userName,
|
|
)
|
|
return attrs
|
|
}
|
|
|
|
// Attempt to get department attribute.
|
|
if department_, ok := m[departmentAttr]; ok {
|
|
if department, ok := department_.(string); ok {
|
|
attrs.department = &department
|
|
} else {
|
|
level.Error(u.logger).Log(
|
|
"msg", fmt.Sprintf("unexpected type for %s.department: %T", extensionEnterpriseUserAttributes, department_),
|
|
userNameAttr, userName,
|
|
)
|
|
}
|
|
}
|
|
|
|
return attrs
|
|
}
|
|
|
|
func getRequiredResource[T string | bool](attributes scim.ResourceAttributes, key string) (T, error) {
|
|
var val T
|
|
valIntf, ok := attributes[key]
|
|
if !ok || valIntf == nil {
|
|
return val, errors.ScimErrorBadParams([]string{key})
|
|
}
|
|
val, ok = valIntf.(T)
|
|
if !ok {
|
|
return val, errors.ScimErrorBadParams([]string{key})
|
|
}
|
|
return val, nil
|
|
}
|
|
|
|
func getOptionalResource[T string | bool](attributes scim.ResourceAttributes, key string) (*T, error) {
|
|
var valPtr *T
|
|
valIntf, ok := attributes[key]
|
|
if ok && valIntf != nil {
|
|
val, ok := valIntf.(T)
|
|
if !ok {
|
|
return nil, errors.ScimErrorBadParams([]string{key})
|
|
}
|
|
valPtr = &val
|
|
}
|
|
return valPtr, nil
|
|
}
|
|
|
|
func getComplexResource(attributes scim.ResourceAttributes, key string) (map[string]interface{}, error) {
|
|
valIntf, ok := attributes[key]
|
|
if ok && valIntf != nil {
|
|
val, ok := valIntf.(map[string]interface{})
|
|
if !ok {
|
|
return nil, errors.ScimErrorBadParams([]string{key})
|
|
}
|
|
return val, nil
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func getComplexResourceSlice(attributes scim.ResourceAttributes, key string) ([]map[string]interface{}, error) {
|
|
valIntf, ok := attributes[key]
|
|
if ok && valIntf != nil {
|
|
valSliceIntf, ok := valIntf.([]interface{})
|
|
if !ok {
|
|
return nil, errors.ScimErrorBadParams([]string{key})
|
|
}
|
|
val := make([]map[string]interface{}, 0, len(valSliceIntf))
|
|
for _, v := range valSliceIntf {
|
|
valMap, ok := v.(map[string]interface{})
|
|
if !ok {
|
|
return nil, errors.ScimErrorBadParams([]string{key})
|
|
}
|
|
if len(valMap) > 0 {
|
|
val = append(val, valMap)
|
|
}
|
|
}
|
|
return val, nil
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func (u *UserHandler) Get(r *http.Request, id string) (scim.Resource, error) {
|
|
idUint, err := extractUserIDFromValue(id)
|
|
if err != nil {
|
|
level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
|
|
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
|
}
|
|
|
|
user, err := u.ds.ScimUserByID(r.Context(), idUint)
|
|
switch {
|
|
case fleet.IsNotFound(err):
|
|
level.Info(u.logger).Log("msg", "failed to find user", "id", id)
|
|
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
|
case err != nil:
|
|
level.Error(u.logger).Log("msg", "failed to get user", "id", id, "err", err)
|
|
return scim.Resource{}, err
|
|
}
|
|
|
|
return createUserResource(user), nil
|
|
}
|
|
|
|
func createUserResource(user *fleet.ScimUser) scim.Resource {
|
|
userResource := scim.Resource{}
|
|
userResource.ID = scimUserID(user.ID)
|
|
if user.ExternalID != nil {
|
|
userResource.ExternalID = optional.NewString(*user.ExternalID)
|
|
}
|
|
userResource.Attributes = scim.ResourceAttributes{}
|
|
userResource.Attributes[userNameAttr] = user.UserName
|
|
if user.Active != nil {
|
|
userResource.Attributes[activeAttr] = *user.Active
|
|
}
|
|
if user.FamilyName != nil || user.GivenName != nil {
|
|
userResource.Attributes[nameAttr] = make(scim.ResourceAttributes)
|
|
if user.FamilyName != nil {
|
|
userResource.Attributes[nameAttr].(scim.ResourceAttributes)[familyNameAttr] = *user.FamilyName
|
|
}
|
|
if user.GivenName != nil {
|
|
userResource.Attributes[nameAttr].(scim.ResourceAttributes)[givenNameAttr] = *user.GivenName
|
|
}
|
|
}
|
|
if len(user.Emails) > 0 {
|
|
emails := make([]scim.ResourceAttributes, 0, len(user.Emails))
|
|
for _, email := range user.Emails {
|
|
emailResource := make(scim.ResourceAttributes)
|
|
emailResource[valueAttr] = email.Email
|
|
if email.Type != nil {
|
|
emailResource[typeAttr] = *email.Type
|
|
}
|
|
if email.Primary != nil {
|
|
emailResource[primaryAttr] = *email.Primary
|
|
}
|
|
emails = append(emails, emailResource)
|
|
}
|
|
userResource.Attributes[emailsAttr] = emails
|
|
}
|
|
if len(user.Groups) > 0 {
|
|
groups := make([]scim.ResourceAttributes, 0, len(user.Groups))
|
|
for _, group := range user.Groups {
|
|
groups = append(groups, map[string]interface{}{
|
|
valueAttr: scimGroupID(group.ID),
|
|
"$ref": "Groups/" + scimGroupID(group.ID),
|
|
"display": group.DisplayName,
|
|
})
|
|
}
|
|
userResource.Attributes[groupsAttr] = groups
|
|
}
|
|
if user.Department != nil {
|
|
extensionEnterpriseUserAttributesMap := make(scim.ResourceAttributes)
|
|
extensionEnterpriseUserAttributesMap[departmentAttr] = *user.Department
|
|
userResource.Attributes[extensionEnterpriseUserAttributes] = extensionEnterpriseUserAttributesMap
|
|
}
|
|
return userResource
|
|
}
|
|
|
|
// GetAll
|
|
// Pagination is 1-indexed on the startIndex. The startIndex is the index of the resource (not the index of the page), per RFC7644.
|
|
//
|
|
// Per RFC7644 3.4.2, SHOULD ignore any query parameters they do not recognize instead of rejecting the query for versioning compatibility reasons
|
|
// https://datatracker.ietf.org/doc/html/rfc7644#section-3.4.2
|
|
//
|
|
// Providers MUST decline to filter results if the specified filter operation is not recognized and return an HTTP 400 error with a
|
|
// "scimType" error of "invalidFilter" and an appropriate human-readable response as per Section 3.12. For example, if a client specified an
|
|
// unsupported operator named 'regex', the service provider should specify an error response description identifying the client error,
|
|
// e.g., 'The operator 'regex' is not supported.'
|
|
//
|
|
// If a SCIM service provider determines that too many results would be returned the server base URI, the server SHALL reject the request by
|
|
// returning an HTTP response with HTTP status code 400 (Bad Request) and JSON attribute "scimType" set to "tooMany" (see Table 9).
|
|
//
|
|
// totalResults: The total number of results returned by the list or query operation. The value may be larger than the number of
|
|
// resources returned, such as when returning a single page (see Section 3.4.2.4) of results where multiple pages are available.
|
|
func (u *UserHandler) GetAll(r *http.Request, params scim.ListRequestParams) (scim.Page, error) {
|
|
startIndex := params.StartIndex
|
|
if startIndex < 1 {
|
|
startIndex = 1
|
|
}
|
|
count := params.Count
|
|
if count > maxResults {
|
|
return scim.Page{}, errors.ScimErrorTooMany
|
|
}
|
|
if count < 1 {
|
|
count = maxResults
|
|
}
|
|
|
|
opts := fleet.ScimUsersListOptions{
|
|
ScimListOptions: fleet.ScimListOptions{
|
|
StartIndex: uint(startIndex), // nolint:gosec // ignore G115
|
|
PerPage: uint(count), // nolint:gosec // ignore G115
|
|
},
|
|
}
|
|
resourceFilter := r.URL.Query().Get("filter")
|
|
if resourceFilter != "" {
|
|
expr, err := filter.ParseAttrExp([]byte(resourceFilter))
|
|
if err != nil {
|
|
level.Error(u.logger).Log("msg", "failed to parse filter", "filter", resourceFilter, "err", err)
|
|
return scim.Page{}, errors.ScimErrorInvalidFilter
|
|
}
|
|
if !strings.EqualFold(expr.AttributePath.String(), "userName") || expr.Operator != "eq" {
|
|
level.Info(u.logger).Log("msg", "unsupported filter", "filter", resourceFilter)
|
|
return scim.Page{}, nil
|
|
}
|
|
userName, ok := expr.CompareValue.(string)
|
|
if !ok {
|
|
level.Error(u.logger).Log("msg", "unsupported value", "value", expr.CompareValue)
|
|
return scim.Page{}, nil
|
|
}
|
|
|
|
// Decode URL-encoded characters in userName, which is required to pass Microsoft Entra ID SCIM Validator
|
|
userName, err = url.QueryUnescape(userName)
|
|
if err != nil {
|
|
level.Error(u.logger).Log("msg", "failed to decode userName", "userName", userName, "err", err)
|
|
return scim.Page{}, nil
|
|
}
|
|
opts.UserNameFilter = &userName
|
|
}
|
|
users, totalResults, err := u.ds.ListScimUsers(r.Context(), opts)
|
|
if err != nil {
|
|
level.Error(u.logger).Log("msg", "failed to list users", "err", err)
|
|
return scim.Page{}, err
|
|
}
|
|
|
|
result := scim.Page{
|
|
TotalResults: int(totalResults), // nolint:gosec // ignore G115
|
|
Resources: make([]scim.Resource, 0, len(users)),
|
|
}
|
|
for _, user := range users {
|
|
result.Resources = append(result.Resources, createUserResource(&user))
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func (u *UserHandler) Replace(r *http.Request, id string, attributes scim.ResourceAttributes) (scim.Resource, error) {
|
|
idUint, err := extractUserIDFromValue(id)
|
|
if err != nil {
|
|
level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
|
|
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
|
}
|
|
|
|
user, err := u.createUserFromAttributes(attributes)
|
|
if err != nil {
|
|
level.Error(u.logger).Log("msg", "failed to create user from attributes", "id", id, "err", err)
|
|
return scim.Resource{}, err
|
|
}
|
|
user.ID = idUint
|
|
// Username is unique, so we must check if another user already exists with that username to return a clear error
|
|
userWithSameUsername, err := u.ds.ScimUserByUserName(r.Context(), user.UserName)
|
|
switch {
|
|
case err != nil && !fleet.IsNotFound(err):
|
|
level.Error(u.logger).Log("msg", "failed to check for userName uniqueness", userNameAttr, user.UserName, "err", err)
|
|
return scim.Resource{}, err
|
|
case err == nil && user.ID != userWithSameUsername.ID:
|
|
level.Info(u.logger).Log("msg", "user already exists with this username", userNameAttr, user.UserName)
|
|
return scim.Resource{}, errors.ScimErrorUniqueness
|
|
// Otherwise, we assume that we are replacing the username with this operation.
|
|
}
|
|
|
|
err = u.ds.ReplaceScimUser(r.Context(), user)
|
|
switch {
|
|
case fleet.IsNotFound(err):
|
|
level.Info(u.logger).Log("msg", "failed to find user to replace", "id", id)
|
|
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
|
case err != nil:
|
|
level.Error(u.logger).Log("msg", "failed to replace user", "id", id, "err", err)
|
|
return scim.Resource{}, err
|
|
}
|
|
|
|
return createUserResource(user), nil
|
|
}
|
|
|
|
// Delete
|
|
// https://datatracker.ietf.org/doc/html/rfc7644#section-3.6
|
|
// MUST return a 404 (Not Found) error code for all operations associated with the previously deleted resource
|
|
func (u *UserHandler) Delete(r *http.Request, id string) error {
|
|
idUint, err := extractUserIDFromValue(id)
|
|
if err != nil {
|
|
level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
|
|
return errors.ScimErrorResourceNotFound(id)
|
|
}
|
|
err = u.ds.DeleteScimUser(r.Context(), idUint)
|
|
switch {
|
|
case fleet.IsNotFound(err):
|
|
level.Info(u.logger).Log("msg", "failed to find user to delete", "id", id)
|
|
return errors.ScimErrorResourceNotFound(id)
|
|
case err != nil:
|
|
level.Error(u.logger).Log("msg", "failed to delete user", "id", id, "err", err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Patch - https://datatracker.ietf.org/doc/html/rfc7644#section-3.5.2
|
|
func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchOperation) (scim.Resource, error) {
|
|
idUint, err := extractUserIDFromValue(id)
|
|
if err != nil {
|
|
level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
|
|
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
|
}
|
|
user, err := u.ds.ScimUserByID(r.Context(), idUint)
|
|
switch {
|
|
case fleet.IsNotFound(err):
|
|
level.Info(u.logger).Log("msg", "failed to find user to patch", "id", id)
|
|
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
|
case err != nil:
|
|
level.Error(u.logger).Log("msg", "failed to get user to patch", "id", id, "err", err)
|
|
return scim.Resource{}, err
|
|
}
|
|
|
|
for _, op := range operations {
|
|
if op.Op != scim.PatchOperationAdd && op.Op != scim.PatchOperationReplace && op.Op != scim.PatchOperationRemove {
|
|
level.Info(u.logger).Log("msg", "unsupported patch operation", "op", op.Op)
|
|
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
switch {
|
|
// If path is not specified, we look for the path in the value attribute.
|
|
case op.Path == nil:
|
|
if op.Op == scim.PatchOperationRemove {
|
|
level.Info(u.logger).Log("msg", "the 'path' attribute is REQUIRED for 'remove' operations", "op", op.Op)
|
|
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
newValues, ok := op.Value.(map[string]interface{})
|
|
if !ok {
|
|
level.Info(u.logger).Log("msg", "unsupported patch value", "value", op.Value)
|
|
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
for k, v := range newValues {
|
|
switch k {
|
|
case externalIdAttr:
|
|
err = u.patchExternalId(op.Op, v, user)
|
|
if err != nil {
|
|
return scim.Resource{}, err
|
|
}
|
|
case userNameAttr:
|
|
err = u.patchUserName(op.Op, v, user)
|
|
if err != nil {
|
|
return scim.Resource{}, err
|
|
}
|
|
case activeAttr:
|
|
err = u.patchActive(op.Op, v, user)
|
|
if err != nil {
|
|
return scim.Resource{}, err
|
|
}
|
|
case nameAttr + "." + givenNameAttr:
|
|
err = u.patchGivenName(op.Op, v, user)
|
|
if err != nil {
|
|
return scim.Resource{}, err
|
|
}
|
|
case nameAttr + "." + familyNameAttr:
|
|
err = u.patchFamilyName(op.Op, v, user)
|
|
if err != nil {
|
|
return scim.Resource{}, err
|
|
}
|
|
case nameAttr:
|
|
err = u.patchName(v, op, user)
|
|
if err != nil {
|
|
return scim.Resource{}, err
|
|
}
|
|
case emailsAttr:
|
|
err = u.patchEmails(v, op, user)
|
|
if err != nil {
|
|
return scim.Resource{}, err
|
|
}
|
|
case extensionEnterpriseUserAttributes + ":" + departmentAttr:
|
|
err = u.patchDepartment(op.Op, v, user)
|
|
if err != nil {
|
|
return scim.Resource{}, err
|
|
}
|
|
default:
|
|
level.Info(u.logger).Log("msg", "unsupported patch value field", "field", k)
|
|
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
}
|
|
case op.Path.String() == externalIdAttr:
|
|
err = u.patchExternalId(op.Op, op.Value, user)
|
|
if err != nil {
|
|
return scim.Resource{}, err
|
|
}
|
|
case op.Path.String() == userNameAttr:
|
|
err = u.patchUserName(op.Op, op.Value, user)
|
|
if err != nil {
|
|
return scim.Resource{}, err
|
|
}
|
|
case op.Path.String() == activeAttr:
|
|
err = u.patchActive(op.Op, op.Value, user)
|
|
if err != nil {
|
|
return scim.Resource{}, err
|
|
}
|
|
case op.Path.String() == nameAttr+"."+givenNameAttr:
|
|
err = u.patchGivenName(op.Op, op.Value, user)
|
|
if err != nil {
|
|
return scim.Resource{}, err
|
|
}
|
|
case op.Path.String() == nameAttr+"."+familyNameAttr:
|
|
err = u.patchFamilyName(op.Op, op.Value, user)
|
|
if err != nil {
|
|
return scim.Resource{}, err
|
|
}
|
|
case op.Path.String() == nameAttr:
|
|
err = u.patchName(op.Value, op, user)
|
|
if err != nil {
|
|
return scim.Resource{}, err
|
|
}
|
|
case op.Path.String() == emailsAttr:
|
|
err = u.patchEmails(op.Value, op, user)
|
|
if err != nil {
|
|
return scim.Resource{}, err
|
|
}
|
|
case op.Path.AttributePath.String() == emailsAttr:
|
|
err = u.patchEmailsWithPathFiltering(op, user)
|
|
if err != nil {
|
|
return scim.Resource{}, err
|
|
}
|
|
case op.Path.AttributePath.String() == extensionEnterpriseUserAttributes+":"+departmentAttr:
|
|
err = u.patchDepartment(op.Op, op.Value, user)
|
|
if err != nil {
|
|
return scim.Resource{}, err
|
|
}
|
|
default:
|
|
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path)
|
|
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
}
|
|
|
|
if len(operations) != 0 {
|
|
err = u.ds.ReplaceScimUser(r.Context(), user)
|
|
switch {
|
|
case fleet.IsNotFound(err):
|
|
level.Info(u.logger).Log("msg", "failed to find user to patch", "id", id)
|
|
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
|
case err != nil:
|
|
level.Error(u.logger).Log("msg", "failed to patch user", "id", id, "err", err)
|
|
return scim.Resource{}, err
|
|
}
|
|
}
|
|
|
|
return createUserResource(user), nil
|
|
}
|
|
|
|
func (u *UserHandler) patchEmailsWithPathFiltering(op scim.PatchOperation, user *fleet.ScimUser) error {
|
|
emailType, err := u.getEmailType(op)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
emailFound := false
|
|
var emailIndex int
|
|
for i, email := range user.Emails {
|
|
if email.Type != nil && *email.Type == emailType {
|
|
emailIndex = i
|
|
emailFound = true
|
|
break
|
|
}
|
|
}
|
|
if !emailFound && op.Op != scim.PatchOperationAdd {
|
|
level.Info(u.logger).Log("msg", "email not found", "email_type", emailType, "op", fmt.Sprintf("%v", op))
|
|
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
if op.Path.SubAttribute == nil {
|
|
if op.Op == scim.PatchOperationRemove {
|
|
user.Emails = slices.Delete(user.Emails, emailIndex, emailIndex+1)
|
|
return nil
|
|
}
|
|
|
|
// For add and replace operations, we need to extract the emails
|
|
var emailsList []interface{}
|
|
// Handle different value formats
|
|
switch val := op.Value.(type) {
|
|
case []interface{}:
|
|
// Direct array of members
|
|
emailsList = val
|
|
case map[string]interface{}:
|
|
// Single member as a map
|
|
emailsList = []interface{}{val}
|
|
default:
|
|
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", emailsAttr), "value", op.Value)
|
|
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
|
|
switch op.Op {
|
|
case scim.PatchOperationReplace:
|
|
if len(emailsList) == 0 {
|
|
user.Emails = slices.Delete(user.Emails, emailIndex, emailIndex+1)
|
|
return nil
|
|
}
|
|
if len(emailsList) != 1 {
|
|
level.Info(u.logger).Log("msg", "only 1 email should be present for replacement", "emails", emailsList)
|
|
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
userEmail, err := u.extractEmail(emailsList[0], op)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// If setting primary to true, then unset true from other emails
|
|
if userEmail.Primary != nil && *userEmail.Primary {
|
|
clearPrimaryFlagFromEmails(user)
|
|
}
|
|
user.Emails[emailIndex] = userEmail
|
|
case scim.PatchOperationAdd:
|
|
if len(emailsList) == 0 {
|
|
level.Info(u.logger).Log("msg", "no emails provided to add", "emails", emailsList)
|
|
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
var newEmails []fleet.ScimUserEmail
|
|
for e := range emailsList {
|
|
userEmail, err := u.extractEmail(emailsList[e], op)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
userEmail.Type = &emailType
|
|
newEmails = append(newEmails, userEmail)
|
|
}
|
|
primaryExists, err := u.checkEmailPrimary(newEmails)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if primaryExists {
|
|
clearPrimaryFlagFromEmails(user)
|
|
}
|
|
user.Emails = append(user.Emails, newEmails...)
|
|
}
|
|
return nil
|
|
}
|
|
if op.Op == scim.PatchOperationAdd && !emailFound {
|
|
user.Emails = append(user.Emails, fleet.ScimUserEmail{
|
|
Type: ptr.String(emailType),
|
|
})
|
|
emailIndex = len(user.Emails) - 1
|
|
}
|
|
switch *op.Path.SubAttribute {
|
|
case primaryAttr:
|
|
if op.Op == scim.PatchOperationRemove {
|
|
user.Emails[emailIndex].Primary = nil
|
|
return nil
|
|
}
|
|
if op.Value == nil {
|
|
user.Emails[emailIndex].Primary = nil
|
|
return nil
|
|
}
|
|
primary, err := getConcreteType[bool](u, op.Value, primaryAttr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// If setting primary to true, then unset true from other emails
|
|
if primary {
|
|
clearPrimaryFlagFromEmails(user)
|
|
}
|
|
user.Emails[emailIndex].Primary = &primary
|
|
case valueAttr:
|
|
if op.Op == scim.PatchOperationRemove {
|
|
// The operation of removing an email value doesn't make sense, but we allow it.
|
|
user.Emails[emailIndex].Email = ""
|
|
return nil
|
|
}
|
|
value, err := getConcreteType[string](u, op.Value, valueAttr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
user.Emails[emailIndex].Email = value
|
|
case typeAttr:
|
|
if op.Op == scim.PatchOperationRemove {
|
|
user.Emails[emailIndex].Type = nil
|
|
return nil
|
|
}
|
|
if op.Value == nil {
|
|
user.Emails[emailIndex].Type = nil
|
|
return nil
|
|
}
|
|
newEmailType, err := getConcreteType[string](u, op.Value, typeAttr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
user.Emails[emailIndex].Type = &newEmailType
|
|
default:
|
|
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path)
|
|
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (u *UserHandler) getEmailType(op scim.PatchOperation) (string, error) {
|
|
attrExpression, ok := op.Path.ValueExpression.(*filter.AttributeExpression)
|
|
if !ok {
|
|
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path)
|
|
return "", errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
// Only matching by email type (work, etc.) is supported.
|
|
if attrExpression.AttributePath.String() != typeAttr || attrExpression.Operator != filter.EQ {
|
|
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path, "expression", attrExpression.AttributePath.String())
|
|
return "", errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
emailType, ok := attrExpression.CompareValue.(string)
|
|
if !ok {
|
|
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path, "compare_value", attrExpression.CompareValue)
|
|
return "", errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
return emailType, nil
|
|
}
|
|
|
|
func getConcreteType[T string | bool](u *UserHandler, v interface{}, name string) (T, error) {
|
|
concreteType, ok := v.(T)
|
|
if !ok {
|
|
var zeroValue T
|
|
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' value", name), "value", v)
|
|
return zeroValue, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
|
|
}
|
|
return concreteType, nil
|
|
}
|
|
|
|
func (u *UserHandler) patchFamilyName(op string, v interface{}, user *fleet.ScimUser) error {
|
|
if op == scim.PatchOperationRemove {
|
|
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", nameAttr+"."+familyNameAttr)
|
|
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
familyName, err := getConcreteType[string](u, v, nameAttr+"."+familyNameAttr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
user.FamilyName = &familyName
|
|
return nil
|
|
}
|
|
|
|
func (u *UserHandler) patchGivenName(op string, v interface{}, user *fleet.ScimUser) error {
|
|
if op == scim.PatchOperationRemove {
|
|
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", nameAttr+"."+givenNameAttr)
|
|
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
givenName, err := getConcreteType[string](u, v, nameAttr+"."+givenNameAttr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
user.GivenName = &givenName
|
|
return nil
|
|
}
|
|
|
|
func (u *UserHandler) patchActive(op string, v interface{}, user *fleet.ScimUser) error {
|
|
if op == scim.PatchOperationRemove || v == nil {
|
|
user.Active = nil
|
|
return nil
|
|
}
|
|
active, err := getConcreteType[bool](u, v, activeAttr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
user.Active = &active
|
|
return nil
|
|
}
|
|
|
|
func (u *UserHandler) patchExternalId(op string, v interface{}, user *fleet.ScimUser) error {
|
|
if op == scim.PatchOperationRemove || v == nil {
|
|
user.ExternalID = nil
|
|
return nil
|
|
}
|
|
externalId, err := getConcreteType[string](u, v, externalIdAttr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
user.ExternalID = ptr.String(externalId)
|
|
return nil
|
|
}
|
|
|
|
func (u *UserHandler) patchUserName(op string, v interface{}, user *fleet.ScimUser) error {
|
|
if op == scim.PatchOperationRemove {
|
|
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", userNameAttr)
|
|
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
userName, err := getConcreteType[string](u, v, userNameAttr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if userName == "" {
|
|
level.Info(u.logger).Log("msg", fmt.Sprintf("'%s' cannot be empty", userNameAttr), "value", v)
|
|
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
|
|
}
|
|
user.UserName = userName
|
|
return nil
|
|
}
|
|
|
|
func (u *UserHandler) patchDepartment(op string, v interface{}, user *fleet.ScimUser) error {
|
|
if op == scim.PatchOperationRemove || v == nil {
|
|
user.Department = nil
|
|
return nil
|
|
}
|
|
department, err := getConcreteType[string](u, v, departmentAttr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
user.Department = &department
|
|
return nil
|
|
}
|
|
|
|
func clearPrimaryFlagFromEmails(user *fleet.ScimUser) {
|
|
for i, email := range user.Emails {
|
|
if email.Primary != nil && *email.Primary {
|
|
user.Emails[i].Primary = ptr.Bool(false)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (u *UserHandler) patchEmails(v interface{}, op scim.PatchOperation, user *fleet.ScimUser) error {
|
|
if op.Op == scim.PatchOperationRemove {
|
|
user.Emails = nil
|
|
return nil
|
|
}
|
|
|
|
// For add and replace operations, we need to extract the emails
|
|
var emailsList []interface{}
|
|
// Handle different value formats
|
|
switch val := v.(type) {
|
|
case []interface{}:
|
|
// Direct array of members
|
|
emailsList = val
|
|
case map[string]interface{}:
|
|
// Single member as a map
|
|
emailsList = []interface{}{val}
|
|
default:
|
|
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", emailsAttr), "value", op.Value)
|
|
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
|
|
if op.Op == scim.PatchOperationAdd && len(emailsList) == 0 {
|
|
level.Info(u.logger).Log("msg", "no emails provided to add", "emails", emailsList)
|
|
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
// Convert the emails to the expected format
|
|
userEmails := make([]fleet.ScimUserEmail, 0, len(emailsList))
|
|
for _, emailIntf := range emailsList {
|
|
userEmail, err := u.extractEmail(emailIntf, op)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
userEmails = append(userEmails, userEmail)
|
|
}
|
|
primaryExists, err := u.checkEmailPrimary(userEmails)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if op.Op == scim.PatchOperationAdd {
|
|
if primaryExists {
|
|
// Clear the primary flag from current emails because we are merging the two email lists and a new email has that flag.
|
|
clearPrimaryFlagFromEmails(user)
|
|
}
|
|
userEmails = append(user.Emails, userEmails...)
|
|
}
|
|
|
|
user.Emails = userEmails
|
|
return nil
|
|
}
|
|
|
|
// checkEmailPrimary ensures at most one email is marked as primary
|
|
func (u *UserHandler) checkEmailPrimary(userEmails []fleet.ScimUserEmail) (bool, error) {
|
|
primaryEmailCount := 0
|
|
for _, email := range userEmails {
|
|
if email.Primary != nil && *email.Primary {
|
|
primaryEmailCount++
|
|
if primaryEmailCount > 1 {
|
|
level.Info(u.logger).Log("msg", "multiple primary emails found")
|
|
return false, errors.ScimErrorBadParams([]string{"Only one email can be marked as primary"})
|
|
}
|
|
}
|
|
}
|
|
return primaryEmailCount > 0, nil
|
|
}
|
|
|
|
func (u *UserHandler) extractEmail(emailIntf interface{}, op scim.PatchOperation) (fleet.ScimUserEmail, error) {
|
|
emailMap, ok := emailIntf.(map[string]interface{})
|
|
if !ok {
|
|
level.Info(u.logger).Log("msg", "email is not a map", "email", emailIntf)
|
|
return fleet.ScimUserEmail{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
|
|
// Extract the email value (required)
|
|
emailValue, ok := emailMap[valueAttr].(string)
|
|
if !ok || emailValue == "" {
|
|
level.Info(u.logger).Log("msg", "email value is missing or invalid", "email", emailMap)
|
|
return fleet.ScimUserEmail{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
|
|
// Normalize the email
|
|
normalizedEmail, err := normalizeEmail(emailValue)
|
|
if err != nil {
|
|
level.Info(u.logger).Log("msg", "failed to normalize email", "email", emailValue, "err", err)
|
|
return fleet.ScimUserEmail{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
|
|
// Create the email object
|
|
userEmail := fleet.ScimUserEmail{
|
|
Email: normalizedEmail,
|
|
Type: nil,
|
|
Primary: nil,
|
|
}
|
|
|
|
// Extract the type (optional)
|
|
if typeValue, ok := emailMap[typeAttr].(string); ok {
|
|
userEmail.Type = &typeValue
|
|
}
|
|
|
|
// Extract the primary flag (optional)
|
|
if primaryValue, ok := emailMap[primaryAttr].(bool); ok {
|
|
userEmail.Primary = &primaryValue
|
|
}
|
|
return userEmail, nil
|
|
}
|
|
|
|
func (u *UserHandler) patchName(v interface{}, op scim.PatchOperation, user *fleet.ScimUser) error {
|
|
if op.Op == scim.PatchOperationRemove {
|
|
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", nameAttr)
|
|
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
name, ok := v.(map[string]interface{})
|
|
if !ok {
|
|
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", nameAttr), "value", op.Value)
|
|
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
for nameKey, nameValue := range name {
|
|
switch nameKey {
|
|
case givenNameAttr:
|
|
givenName, ok := nameValue.(string)
|
|
if !ok {
|
|
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", nameAttr+"."+givenNameAttr), "value",
|
|
op.Value)
|
|
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
user.GivenName = &givenName
|
|
case familyNameAttr:
|
|
familyName, ok := nameValue.(string)
|
|
if !ok {
|
|
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", nameAttr+"."+familyNameAttr), "value",
|
|
op.Value)
|
|
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
user.FamilyName = &familyName
|
|
default:
|
|
level.Info(u.logger).Log("msg", "unsupported patch value field", "field", nameAttr+"."+nameKey)
|
|
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// normalizeEmail
|
|
// The local-part of a mailbox MUST BE treated as case sensitive.
|
|
// Mailbox domains follow normal DNS rules and are hence not case sensitive.
|
|
// https://datatracker.ietf.org/doc/html/rfc5321#section-2.4
|
|
func normalizeEmail(email string) (string, error) {
|
|
email = removeWhitespace(email)
|
|
emailParts := strings.SplitN(email, "@", 2)
|
|
if len(emailParts) != 2 {
|
|
return "", fmt.Errorf("invalid email %s", email)
|
|
}
|
|
emailParts[1] = strings.ToLower(emailParts[1])
|
|
return strings.Join(emailParts, "@"), nil
|
|
}
|
|
|
|
func removeWhitespace(str string) string {
|
|
return strings.Map(func(r rune) rune {
|
|
if unicode.IsSpace(r) {
|
|
return -1
|
|
}
|
|
return r
|
|
}, str)
|
|
}
|
|
|
|
func scimUserID(userID uint) string {
|
|
return fmt.Sprintf("%d", userID)
|
|
}
|
|
|
|
// extractUserIDFromValue extracts the user ID from a value like "123"
|
|
func extractUserIDFromValue(value string) (uint, error) {
|
|
id, err := strconv.ParseUint(value, 10, 64)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return uint(id), nil
|
|
}
|