mirror of
https://github.com/fleetdm/fleet
synced 2026-04-21 13:37:30 +00:00
Incremental migration to slog (#40120)
<!-- Add the related story/sub-task/bug number, like Resolves #123, or remove if NA --> **Related issue:** Resolves #40054 # Checklist for submitter If some of the following don't apply, delete the relevant line. - [ ] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`. - Already added in previous PR ## Testing - [x] Added/updated automated tests - [x] QA'd all new/changed functionality manually <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Updated internal logging infrastructure across multiple server components to use standardized logging methods and improved context propagation. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
parent
c303f7f0e6
commit
70ffac6341
13 changed files with 284 additions and 272 deletions
|
|
@ -423,9 +423,9 @@ the way that the Fleet server works.
|
|||
ds = redisWrapperDS
|
||||
|
||||
resultStore := pubsub.NewRedisQueryResults(redisPool, config.Redis.DuplicateResults,
|
||||
logger.With("component", "query-results"),
|
||||
logger.SlogLogger().With("component", "query-results"),
|
||||
)
|
||||
liveQueryStore := live_query.NewRedisLiveQuery(redisPool, logger, liveQueryMemCacheDuration)
|
||||
liveQueryStore := live_query.NewRedisLiveQuery(redisPool, logger.SlogLogger(), liveQueryMemCacheDuration)
|
||||
ssoSessionStore := sso.NewSessionStore(redisPool)
|
||||
|
||||
// Set common configuration for all logging.
|
||||
|
|
@ -557,7 +557,7 @@ the way that the Fleet server works.
|
|||
var geoIP fleet.GeoIP
|
||||
geoIP = &fleet.NoOpGeoIP{}
|
||||
if config.GeoIP.DatabasePath != "" {
|
||||
maxmind, err := fleet.NewMaxMindGeoIP(logger, config.GeoIP.DatabasePath)
|
||||
maxmind, err := fleet.NewMaxMindGeoIP(logger.SlogLogger(), config.GeoIP.DatabasePath)
|
||||
if err != nil {
|
||||
level.Error(logger).Log("msg", "failed to initialize maxmind geoip, check database path", "database_path",
|
||||
config.GeoIP.DatabasePath, "error", err)
|
||||
|
|
@ -853,7 +853,7 @@ the way that the Fleet server works.
|
|||
}
|
||||
}
|
||||
|
||||
eh := errorstore.NewHandler(ctx, redisPool, logger, config.Logging.ErrorRetentionPeriod)
|
||||
eh := errorstore.NewHandler(ctx, redisPool, logger.SlogLogger(), config.Logging.ErrorRetentionPeriod)
|
||||
scepConfigMgr := eeservice.NewSCEPConfigService(logger, nil)
|
||||
digiCertService := digicert.NewService(digicert.WithLogger(logger))
|
||||
ctx = ctxerr.NewContext(ctx, eh)
|
||||
|
|
@ -1497,7 +1497,7 @@ the way that the Fleet server works.
|
|||
if err = service.RegisterSCEPProxy(rootMux, ds, logger, nil, &config); err != nil {
|
||||
initFatal(err, "setup SCEP proxy")
|
||||
}
|
||||
if err = scim.RegisterSCIM(rootMux, ds, svc, logger, &config); err != nil {
|
||||
if err = scim.RegisterSCIM(rootMux, ds, svc, logger.SlogLogger(), &config); err != nil {
|
||||
initFatal(err, "setup SCIM")
|
||||
}
|
||||
// Host identify and conditional access SCEP feature only works if a private key has been set up
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package scim
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
|
@ -12,8 +13,6 @@ import (
|
|||
"github.com/elimity-com/scim/errors"
|
||||
"github.com/elimity-com/scim/optional"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
kitlog "github.com/go-kit/log"
|
||||
"github.com/go-kit/log/level"
|
||||
"github.com/scim2/filter-parser/v2"
|
||||
)
|
||||
|
||||
|
|
@ -25,13 +24,13 @@ const (
|
|||
|
||||
type GroupHandler struct {
|
||||
ds fleet.Datastore
|
||||
logger kitlog.Logger
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// Compile-time check
|
||||
var _ scim.ResourceHandler = &GroupHandler{}
|
||||
|
||||
func NewGroupHandler(ds fleet.Datastore, logger kitlog.Logger) scim.ResourceHandler {
|
||||
func NewGroupHandler(ds fleet.Datastore, logger *slog.Logger) scim.ResourceHandler {
|
||||
return &GroupHandler{ds: ds, logger: logger}
|
||||
}
|
||||
|
||||
|
|
@ -39,7 +38,7 @@ func NewGroupHandler(ds fleet.Datastore, logger kitlog.Logger) scim.ResourceHand
|
|||
func (g *GroupHandler) Create(r *http.Request, attributes scim.ResourceAttributes) (scim.Resource, error) {
|
||||
displayName, err := getRequiredResource[string](attributes, displayNameAttr)
|
||||
if err != nil {
|
||||
level.Error(g.logger).Log("msg", "failed to get displayName", "err", err)
|
||||
g.logger.ErrorContext(r.Context(), "failed to get displayName", "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
|
||||
|
|
@ -50,16 +49,16 @@ func (g *GroupHandler) Create(r *http.Request, attributes scim.ResourceAttribute
|
|||
_, err = g.ds.ScimGroupByDisplayName(r.Context(), displayName)
|
||||
switch {
|
||||
case err != nil && !fleet.IsNotFound(err):
|
||||
level.Error(g.logger).Log("msg", "failed to check for displayName uniqueness", displayNameAttr, displayName, "err", err)
|
||||
g.logger.ErrorContext(r.Context(), "failed to check for displayName uniqueness", displayNameAttr, displayName, "err", err)
|
||||
return scim.Resource{}, err
|
||||
case err == nil:
|
||||
level.Info(g.logger).Log("msg", "group already exists", displayNameAttr, displayName)
|
||||
g.logger.InfoContext(r.Context(), "group already exists", displayNameAttr, displayName)
|
||||
return scim.Resource{}, errors.ScimErrorUniqueness
|
||||
}
|
||||
|
||||
group, err := createGroupFromAttributes(attributes)
|
||||
if err != nil {
|
||||
level.Error(g.logger).Log("msg", "failed to create group from attributes", displayNameAttr, displayName, "err", err)
|
||||
g.logger.ErrorContext(r.Context(), "failed to create group from attributes", displayNameAttr, displayName, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
group.ID, err = g.ds.CreateScimGroup(r.Context(), group)
|
||||
|
|
@ -135,17 +134,17 @@ func areMembersExcluded(r *http.Request) bool {
|
|||
func (g *GroupHandler) Get(r *http.Request, id string) (scim.Resource, error) {
|
||||
idUint, err := extractGroupIDFromValue(id)
|
||||
if err != nil {
|
||||
level.Info(g.logger).Log("msg", "failed to parse id", "id", id, "err", err)
|
||||
g.logger.InfoContext(r.Context(), "failed to parse id", "id", id, "err", err)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
}
|
||||
|
||||
group, err := g.ds.ScimGroupByID(r.Context(), idUint, areMembersExcluded(r))
|
||||
switch {
|
||||
case fleet.IsNotFound(err):
|
||||
level.Info(g.logger).Log("msg", "failed to find group", "id", id)
|
||||
g.logger.InfoContext(r.Context(), "failed to find group", "id", id)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
case err != nil:
|
||||
level.Error(g.logger).Log("msg", "failed to get group", "id", id, "err", err)
|
||||
g.logger.ErrorContext(r.Context(), "failed to get group", "id", id, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
|
||||
|
|
@ -203,23 +202,23 @@ func (g *GroupHandler) GetAll(r *http.Request, params scim.ListRequestParams) (s
|
|||
if resourceFilter != "" {
|
||||
expr, err := filter.ParseAttrExp([]byte(resourceFilter))
|
||||
if err != nil {
|
||||
level.Error(g.logger).Log("msg", "failed to parse filter", "filter", resourceFilter, "err", err)
|
||||
g.logger.ErrorContext(r.Context(), "failed to parse filter", "filter", resourceFilter, "err", err)
|
||||
return scim.Page{}, errors.ScimErrorInvalidFilter
|
||||
}
|
||||
if !strings.EqualFold(expr.AttributePath.String(), "displayName") || expr.Operator != "eq" {
|
||||
level.Info(g.logger).Log("msg", "unsupported filter", "filter", resourceFilter)
|
||||
g.logger.InfoContext(r.Context(), "unsupported filter", "filter", resourceFilter)
|
||||
return scim.Page{}, nil
|
||||
}
|
||||
displayName, ok := expr.CompareValue.(string)
|
||||
if !ok {
|
||||
level.Error(g.logger).Log("msg", "unsupported value", "value", expr.CompareValue)
|
||||
g.logger.ErrorContext(r.Context(), "unsupported value", "value", expr.CompareValue)
|
||||
return scim.Page{}, nil
|
||||
}
|
||||
|
||||
// Decode URL-encoded characters
|
||||
displayName, err = url.QueryUnescape(displayName)
|
||||
if err != nil {
|
||||
level.Error(g.logger).Log("msg", "failed to decode displayName", "displayName", displayName, "err", err)
|
||||
g.logger.ErrorContext(r.Context(), "failed to decode displayName", "displayName", displayName, "err", err)
|
||||
return scim.Page{}, nil
|
||||
}
|
||||
opts.DisplayNameFilter = &displayName
|
||||
|
|
@ -227,7 +226,7 @@ func (g *GroupHandler) GetAll(r *http.Request, params scim.ListRequestParams) (s
|
|||
|
||||
groups, totalResults, err := g.ds.ListScimGroups(r.Context(), opts)
|
||||
if err != nil {
|
||||
level.Error(g.logger).Log("msg", "failed to list groups", "err", err)
|
||||
g.logger.ErrorContext(r.Context(), "failed to list groups", "err", err)
|
||||
return scim.Page{}, err
|
||||
}
|
||||
|
||||
|
|
@ -245,13 +244,13 @@ func (g *GroupHandler) GetAll(r *http.Request, params scim.ListRequestParams) (s
|
|||
func (g *GroupHandler) Replace(r *http.Request, id string, attributes scim.ResourceAttributes) (scim.Resource, error) {
|
||||
idUint, err := extractGroupIDFromValue(id)
|
||||
if err != nil {
|
||||
level.Info(g.logger).Log("msg", "failed to parse id", "id", id, "err", err)
|
||||
g.logger.InfoContext(r.Context(), "failed to parse id", "id", id, "err", err)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
}
|
||||
|
||||
group, err := createGroupFromAttributes(attributes)
|
||||
if err != nil {
|
||||
level.Error(g.logger).Log("msg", "failed to create group from attributes", "id", id, "err", err)
|
||||
g.logger.ErrorContext(r.Context(), "failed to create group from attributes", "id", id, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
group.ID = idUint
|
||||
|
|
@ -260,10 +259,10 @@ func (g *GroupHandler) Replace(r *http.Request, id string, attributes scim.Resou
|
|||
groupWithSameDisplayName, err := g.ds.ScimGroupByDisplayName(r.Context(), group.DisplayName)
|
||||
switch {
|
||||
case err != nil && !fleet.IsNotFound(err):
|
||||
level.Error(g.logger).Log("msg", "failed to check for displayName uniqueness", displayNameAttr, group.DisplayName, "err", err)
|
||||
g.logger.ErrorContext(r.Context(), "failed to check for displayName uniqueness", displayNameAttr, group.DisplayName, "err", err)
|
||||
return scim.Resource{}, err
|
||||
case err == nil && group.ID != groupWithSameDisplayName.ID:
|
||||
level.Info(g.logger).Log("msg", "group already exists with this displayName", displayNameAttr, group.DisplayName)
|
||||
g.logger.InfoContext(r.Context(), "group already exists with this displayName", displayNameAttr, group.DisplayName)
|
||||
return scim.Resource{}, errors.ScimErrorUniqueness
|
||||
// Otherwise, we assume that we are replacing the displayName with this operation.
|
||||
}
|
||||
|
|
@ -271,10 +270,10 @@ func (g *GroupHandler) Replace(r *http.Request, id string, attributes scim.Resou
|
|||
err = g.ds.ReplaceScimGroup(r.Context(), group)
|
||||
switch {
|
||||
case fleet.IsNotFound(err):
|
||||
level.Info(g.logger).Log("msg", "failed to find group to replace", "id", id)
|
||||
g.logger.InfoContext(r.Context(), "failed to find group to replace", "id", id)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
case err != nil:
|
||||
level.Error(g.logger).Log("msg", "failed to replace group", "id", id, "err", err)
|
||||
g.logger.ErrorContext(r.Context(), "failed to replace group", "id", id, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
|
||||
|
|
@ -284,16 +283,16 @@ func (g *GroupHandler) Replace(r *http.Request, id string, attributes scim.Resou
|
|||
func (g *GroupHandler) Delete(r *http.Request, id string) error {
|
||||
idUint, err := extractGroupIDFromValue(id)
|
||||
if err != nil {
|
||||
level.Info(g.logger).Log("msg", "failed to parse id", "id", id, "err", err)
|
||||
g.logger.InfoContext(r.Context(), "failed to parse id", "id", id, "err", err)
|
||||
return errors.ScimErrorResourceNotFound(id)
|
||||
}
|
||||
err = g.ds.DeleteScimGroup(r.Context(), idUint)
|
||||
switch {
|
||||
case fleet.IsNotFound(err):
|
||||
level.Info(g.logger).Log("msg", "failed to find group to delete", "id", id)
|
||||
g.logger.InfoContext(r.Context(), "failed to find group to delete", "id", id)
|
||||
return errors.ScimErrorResourceNotFound(id)
|
||||
case err != nil:
|
||||
level.Error(g.logger).Log("msg", "failed to delete group", "id", id, "err", err)
|
||||
g.logger.ErrorContext(r.Context(), "failed to delete group", "id", id, "err", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
|
@ -302,93 +301,94 @@ func (g *GroupHandler) Delete(r *http.Request, id string) error {
|
|||
// Patch
|
||||
// Supporting add/replace/remove operations for "displayName", "externalId", and "members" attributes.
|
||||
func (g *GroupHandler) Patch(r *http.Request, id string, operations []scim.PatchOperation) (scim.Resource, error) {
|
||||
ctx := r.Context()
|
||||
idUint, err := extractGroupIDFromValue(id)
|
||||
if err != nil {
|
||||
level.Info(g.logger).Log("msg", "failed to parse id", "id", id, "err", err)
|
||||
g.logger.InfoContext(ctx, "failed to parse id", "id", id, "err", err)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
}
|
||||
group, err := g.ds.ScimGroupByID(r.Context(), idUint, false)
|
||||
group, err := g.ds.ScimGroupByID(ctx, idUint, false)
|
||||
switch {
|
||||
case fleet.IsNotFound(err):
|
||||
level.Info(g.logger).Log("msg", "failed to find group to patch", "id", id)
|
||||
g.logger.InfoContext(ctx, "failed to find group to patch", "id", id)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
case err != nil:
|
||||
level.Error(g.logger).Log("msg", "failed to get group to patch", "id", id, "err", err)
|
||||
g.logger.ErrorContext(ctx, "failed to get group to patch", "id", id, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
|
||||
for _, op := range operations {
|
||||
if op.Op != scim.PatchOperationAdd && op.Op != scim.PatchOperationReplace && op.Op != scim.PatchOperationRemove {
|
||||
level.Info(g.logger).Log("msg", "unsupported patch operation", "op", op.Op)
|
||||
g.logger.InfoContext(ctx, "unsupported patch operation", "op", op.Op)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
switch {
|
||||
case op.Path == nil:
|
||||
if op.Op == scim.PatchOperationRemove {
|
||||
level.Info(g.logger).Log("msg", "the 'path' attribute is REQUIRED for 'remove' operations", "op", op.Op)
|
||||
g.logger.InfoContext(ctx, "the 'path' attribute is REQUIRED for 'remove' operations", "op", op.Op)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
newValues, ok := op.Value.(map[string]interface{})
|
||||
if !ok {
|
||||
level.Info(g.logger).Log("msg", "unsupported patch value", "value", op.Value)
|
||||
g.logger.InfoContext(ctx, "unsupported patch value", "value", op.Value)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
for k, v := range newValues {
|
||||
switch k {
|
||||
case externalIdAttr:
|
||||
err = g.patchExternalId(op.Op, v, group)
|
||||
err = g.patchExternalId(ctx, op.Op, v, group)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case displayNameAttr:
|
||||
err = g.patchDisplayName(op.Op, v, group)
|
||||
err = g.patchDisplayName(ctx, op.Op, v, group)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case membersAttr:
|
||||
err = g.patchMembers(r.Context(), op.Op, v, group)
|
||||
err = g.patchMembers(ctx, op.Op, v, group)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
default:
|
||||
level.Info(g.logger).Log("msg", "unsupported patch value field", "field", k)
|
||||
g.logger.InfoContext(ctx, "unsupported patch value field", "field", k)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
}
|
||||
case op.Path.String() == externalIdAttr:
|
||||
err = g.patchExternalId(op.Op, op.Value, group)
|
||||
err = g.patchExternalId(ctx, op.Op, op.Value, group)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case op.Path.String() == displayNameAttr:
|
||||
err = g.patchDisplayName(op.Op, op.Value, group)
|
||||
err = g.patchDisplayName(ctx, op.Op, op.Value, group)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case op.Path.String() == membersAttr:
|
||||
err = g.patchMembers(r.Context(), op.Op, op.Value, group)
|
||||
err = g.patchMembers(ctx, op.Op, op.Value, group)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case op.Path.AttributePath.String() == membersAttr:
|
||||
err = g.patchMembersWithPathFiltering(r.Context(), op, group)
|
||||
err = g.patchMembersWithPathFiltering(ctx, op, group)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
default:
|
||||
level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path)
|
||||
g.logger.InfoContext(ctx, "unsupported patch path", "path", op.Path)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
}
|
||||
|
||||
if len(operations) != 0 {
|
||||
err = g.ds.ReplaceScimGroup(r.Context(), group)
|
||||
err = g.ds.ReplaceScimGroup(ctx, group)
|
||||
switch {
|
||||
case fleet.IsNotFound(err):
|
||||
level.Info(g.logger).Log("msg", "failed to find group to patch", "id", id)
|
||||
g.logger.InfoContext(ctx, "failed to find group to patch", "id", id)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
case err != nil:
|
||||
level.Error(g.logger).Log("msg", "failed to patch group", "id", id, "err", err)
|
||||
g.logger.ErrorContext(ctx, "failed to patch group", "id", id, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
}
|
||||
|
|
@ -396,32 +396,32 @@ func (g *GroupHandler) Patch(r *http.Request, id string, operations []scim.Patch
|
|||
return createGroupResource(group), nil
|
||||
}
|
||||
|
||||
func (g *GroupHandler) patchExternalId(op string, v interface{}, group *fleet.ScimGroup) error {
|
||||
func (g *GroupHandler) patchExternalId(ctx context.Context, op string, v any, group *fleet.ScimGroup) error {
|
||||
if op == scim.PatchOperationRemove || v == nil {
|
||||
group.ExternalID = nil
|
||||
return nil
|
||||
}
|
||||
externalId, ok := v.(string)
|
||||
if !ok {
|
||||
level.Info(g.logger).Log("msg", fmt.Sprintf("unsupported '%s' value", externalIdAttr), "value", v)
|
||||
g.logger.InfoContext(ctx, fmt.Sprintf("unsupported '%s' value", externalIdAttr), "value", v)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
|
||||
}
|
||||
group.ExternalID = &externalId
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *GroupHandler) patchDisplayName(op string, v interface{}, group *fleet.ScimGroup) error {
|
||||
func (g *GroupHandler) patchDisplayName(ctx context.Context, op string, v any, group *fleet.ScimGroup) error {
|
||||
if op == scim.PatchOperationRemove {
|
||||
level.Info(g.logger).Log("msg", "cannot remove required attribute", "attribute", displayNameAttr)
|
||||
g.logger.InfoContext(ctx, "cannot remove required attribute", "attribute", displayNameAttr)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
displayName, ok := v.(string)
|
||||
if !ok {
|
||||
level.Info(g.logger).Log("msg", fmt.Sprintf("unsupported '%s' value", displayNameAttr), "value", v)
|
||||
g.logger.InfoContext(ctx, fmt.Sprintf("unsupported '%s' value", displayNameAttr), "value", v)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
|
||||
}
|
||||
if displayName == "" {
|
||||
level.Info(g.logger).Log("msg", fmt.Sprintf("'%s' cannot be empty", displayNameAttr), "value", v)
|
||||
g.logger.InfoContext(ctx, fmt.Sprintf("'%s' cannot be empty", displayNameAttr), "value", v)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
|
||||
}
|
||||
group.DisplayName = displayName
|
||||
|
|
@ -453,7 +453,7 @@ func (g *GroupHandler) patchMembers(ctx context.Context, op string, v interface{
|
|||
membersList = append(membersList, m)
|
||||
}
|
||||
default:
|
||||
level.Info(g.logger).Log("msg", "unsupported members value format", "value", v)
|
||||
g.logger.InfoContext(ctx, "unsupported members value format", "value", v)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
|
||||
}
|
||||
|
||||
|
|
@ -464,20 +464,20 @@ func (g *GroupHandler) patchMembers(ctx context.Context, op string, v interface{
|
|||
for _, memberIntf := range membersList {
|
||||
member, ok := memberIntf.(map[string]interface{})
|
||||
if !ok {
|
||||
level.Info(g.logger).Log("msg", "member must be an object", "member", memberIntf)
|
||||
g.logger.InfoContext(ctx, "member must be an object", "member", memberIntf)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", memberIntf)})
|
||||
}
|
||||
|
||||
// Get the value attribute which contains the user ID
|
||||
valueIntf, ok := member["value"]
|
||||
if !ok || valueIntf == nil {
|
||||
level.Info(g.logger).Log("msg", "member missing value attribute", "member", member)
|
||||
g.logger.InfoContext(ctx, "member missing value attribute", "member", member)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", member)})
|
||||
}
|
||||
|
||||
valueStr, ok := valueIntf.(string)
|
||||
if !ok {
|
||||
level.Info(g.logger).Log("msg", "member value must be a string", "value", valueIntf)
|
||||
g.logger.InfoContext(ctx, "member value must be a string", "value", valueIntf)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", valueIntf)})
|
||||
}
|
||||
valueStrings = append(valueStrings, valueStr)
|
||||
|
|
@ -485,7 +485,7 @@ func (g *GroupHandler) patchMembers(ctx context.Context, op string, v interface{
|
|||
// Extract user ID from the value
|
||||
userID, err := extractUserIDFromValue(valueStr)
|
||||
if err != nil {
|
||||
level.Info(g.logger).Log("msg", "invalid user ID format", "value", valueStr, "err", err)
|
||||
g.logger.InfoContext(ctx, "invalid user ID format", "value", valueStr, "err", err)
|
||||
return errors.ScimErrorBadParams([]string{valueStr})
|
||||
}
|
||||
|
||||
|
|
@ -496,11 +496,11 @@ func (g *GroupHandler) patchMembers(ctx context.Context, op string, v interface{
|
|||
if len(userIDs) > 0 {
|
||||
allExist, err := g.ds.ScimUsersExist(ctx, userIDs)
|
||||
if err != nil {
|
||||
level.Error(g.logger).Log("msg", "error checking users existence", "err", err)
|
||||
g.logger.ErrorContext(ctx, "error checking users existence", "err", err)
|
||||
return err
|
||||
}
|
||||
if !allExist {
|
||||
level.Info(g.logger).Log("msg", "one or more users not found", "userIDs", userIDs)
|
||||
g.logger.InfoContext(ctx, "one or more users not found", "userIDs", userIDs)
|
||||
return errors.ScimErrorBadParams(valueStrings)
|
||||
}
|
||||
}
|
||||
|
|
@ -531,7 +531,7 @@ func (g *GroupHandler) patchMembers(ctx context.Context, op string, v interface{
|
|||
// patchMembersWithPathFiltering handles patch operations with path filtering for members
|
||||
// This supports paths like members[value eq "422"] for add/replace/remove operations
|
||||
func (g *GroupHandler) patchMembersWithPathFiltering(ctx context.Context, op scim.PatchOperation, group *fleet.ScimGroup) error {
|
||||
memberID, err := g.getMemberID(op)
|
||||
memberID, err := g.getMemberID(ctx, op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -550,7 +550,7 @@ func (g *GroupHandler) patchMembersWithPathFiltering(ctx context.Context, op sci
|
|||
// For remove operations, remove the member if found
|
||||
if op.Op == scim.PatchOperationRemove {
|
||||
if !memberFound {
|
||||
level.Info(g.logger).Log("msg", "member not found in group", "member_id", memberID, "op", fmt.Sprintf("%v", op))
|
||||
g.logger.InfoContext(ctx, "member not found in group", "member_id", memberID, "op", fmt.Sprintf("%v", op))
|
||||
// The member may have been removed already from this group. For example, if the member was deleted.
|
||||
return nil
|
||||
}
|
||||
|
|
@ -563,11 +563,11 @@ func (g *GroupHandler) patchMembersWithPathFiltering(ctx context.Context, op sci
|
|||
// Verify the user exists
|
||||
userExists, err := g.ds.ScimUsersExist(ctx, []uint{memberID})
|
||||
if err != nil {
|
||||
level.Error(g.logger).Log("msg", "error checking user existence", "err", err)
|
||||
g.logger.ErrorContext(ctx, "error checking user existence", "err", err)
|
||||
return err
|
||||
}
|
||||
if !userExists {
|
||||
level.Info(g.logger).Log("msg", "user not found", "user_id", memberID)
|
||||
g.logger.InfoContext(ctx, "user not found", "user_id", memberID)
|
||||
return errors.ScimErrorBadParams([]string{scimUserID(memberID)})
|
||||
}
|
||||
group.ScimUsers = append(group.ScimUsers, memberID)
|
||||
|
|
@ -577,7 +577,9 @@ func (g *GroupHandler) patchMembersWithPathFiltering(ctx context.Context, op sci
|
|||
// For replace operations with a value
|
||||
if op.Op == scim.PatchOperationReplace {
|
||||
if !memberFound {
|
||||
level.Info(g.logger).Log("msg", "member not found for replace operation", "members.value", memberID, "op", fmt.Sprintf("%v", op))
|
||||
g.logger.InfoContext(
|
||||
ctx, "member not found for replace operation", "members.value", memberID, "op", fmt.Sprintf("%v", op),
|
||||
)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
|
||||
|
|
@ -596,29 +598,29 @@ func (g *GroupHandler) patchMembersWithPathFiltering(ctx context.Context, op sci
|
|||
}
|
||||
|
||||
// getMemberID extracts the member ID from a path expression like members[value eq "422"]
|
||||
func (g *GroupHandler) getMemberID(op scim.PatchOperation) (uint, error) {
|
||||
func (g *GroupHandler) getMemberID(ctx context.Context, op scim.PatchOperation) (uint, error) {
|
||||
attrExpression, ok := op.Path.ValueExpression.(*filter.AttributeExpression)
|
||||
if !ok {
|
||||
level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path)
|
||||
g.logger.InfoContext(ctx, "unsupported patch path", "path", op.Path)
|
||||
return 0, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
|
||||
// Only matching by member value (user ID) is supported
|
||||
if attrExpression.AttributePath.String() != valueAttr || attrExpression.Operator != filter.EQ {
|
||||
level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path, "expression", attrExpression.AttributePath.String())
|
||||
g.logger.InfoContext(ctx, "unsupported patch path", "path", op.Path, "expression", attrExpression.AttributePath.String())
|
||||
return 0, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
|
||||
memberIDStr, ok := attrExpression.CompareValue.(string)
|
||||
if !ok {
|
||||
level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path, "compare_value", attrExpression.CompareValue)
|
||||
g.logger.InfoContext(ctx, "unsupported patch path", "path", op.Path, "compare_value", attrExpression.CompareValue)
|
||||
return 0, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
|
||||
// Extract user ID from the value
|
||||
userID, err := extractUserIDFromValue(memberIDStr)
|
||||
if err != nil {
|
||||
level.Info(g.logger).Log("msg", "invalid user ID format", "value", memberIDStr, "err", err)
|
||||
g.logger.InfoContext(ctx, "invalid user ID format", "value", memberIDStr, "err", err)
|
||||
return 0, errors.ScimErrorBadParams([]string{memberIDStr})
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,10 +2,12 @@ package scim
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
|
|
@ -16,11 +18,8 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/authz"
|
||||
"github.com/fleetdm/fleet/v4/server/config"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/platform/logging"
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/auth"
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/log"
|
||||
kitlog "github.com/go-kit/log"
|
||||
"github.com/go-kit/log/level"
|
||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||
)
|
||||
|
||||
|
|
@ -32,7 +31,7 @@ func RegisterSCIM(
|
|||
mux *http.ServeMux,
|
||||
ds fleet.Datastore,
|
||||
svc fleet.Service,
|
||||
logger *logging.Logger,
|
||||
logger *slog.Logger,
|
||||
fleetConfig *config.FleetConfig,
|
||||
) error {
|
||||
if fleetConfig == nil {
|
||||
|
|
@ -201,7 +200,7 @@ func RegisterSCIM(
|
|||
}
|
||||
|
||||
serverOpts := []scim.ServerOption{
|
||||
scim.WithLogger(&scimErrorLogger{Logger: scimLogger}),
|
||||
scim.WithLogger(&scimErrorLogger{logger: scimLogger}),
|
||||
}
|
||||
|
||||
server, err := scim.NewServer(serverArgs, serverOpts...)
|
||||
|
|
@ -223,7 +222,7 @@ func RegisterSCIM(
|
|||
handler = AuthorizationMiddleware(authorizer, scimLogger, handler)
|
||||
handler = auth.AuthenticatedUserMiddleware(svc, scimErrorHandler, handler)
|
||||
handler = LastRequestMiddleware(ds, scimLogger, handler)
|
||||
handler = log.LogResponseEndMiddleware(scimLogger.SlogLogger(), handler)
|
||||
handler = log.LogResponseEndMiddleware(scimLogger, handler)
|
||||
handler = auth.SetRequestsContextMiddleware(svc, handler)
|
||||
return handler
|
||||
}
|
||||
|
|
@ -312,7 +311,7 @@ func scimOTELMiddleware(next http.Handler, prefix string, cfg config.FleetConfig
|
|||
|
||||
// LastRequestMiddleware saves the details of the last request to SCIM endpoints in the datastore.
|
||||
// These details can be used as a debug tool by the Fleet admin to see if SCIM integration is working.
|
||||
func LastRequestMiddleware(ds fleet.Datastore, logger kitlog.Logger, next http.Handler) http.Handler {
|
||||
func LastRequestMiddleware(ds fleet.Datastore, logger *slog.Logger, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
multi := newMultiResponseWriter(w)
|
||||
next.ServeHTTP(multi, r)
|
||||
|
|
@ -323,8 +322,7 @@ func LastRequestMiddleware(ds fleet.Datastore, logger kitlog.Logger, next http.H
|
|||
status = "success"
|
||||
case multi.statusCode == http.StatusUnauthorized:
|
||||
// We do not save unauthenticated error details; we simply log them.
|
||||
level.Info(logger).Log(
|
||||
"msg", "unauthenticated request",
|
||||
logger.InfoContext(r.Context(), "unauthenticated request",
|
||||
"origin", r.Header.Get("Origin"),
|
||||
"ip", r.RemoteAddr,
|
||||
"method", r.Method,
|
||||
|
|
@ -350,7 +348,7 @@ func LastRequestMiddleware(ds fleet.Datastore, logger kitlog.Logger, next http.H
|
|||
default:
|
||||
status = "error"
|
||||
details = fmt.Sprintf("Unhandled status code: %d", multi.statusCode)
|
||||
level.Error(logger).Log("msg", "unhandled status code", "status", multi.statusCode, "body", multi.body.String())
|
||||
logger.ErrorContext(r.Context(), "unhandled status code", "status", multi.statusCode, "body", multi.body.String())
|
||||
}
|
||||
if len(details) > fleet.SCIMMaxFieldLength {
|
||||
details = details[:fleet.SCIMMaxFieldLength]
|
||||
|
|
@ -360,12 +358,12 @@ func LastRequestMiddleware(ds fleet.Datastore, logger kitlog.Logger, next http.H
|
|||
Details: details,
|
||||
})
|
||||
if err != nil {
|
||||
level.Error(logger).Log("msg", "failed to update last scim request", "err", err)
|
||||
logger.ErrorContext(r.Context(), "failed to update last scim request", "err", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func AuthorizationMiddleware(authorizer *authz.Authorizer, logger kitlog.Logger, next http.Handler) http.Handler {
|
||||
func AuthorizationMiddleware(authorizer *authz.Authorizer, logger *slog.Logger, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
err := authorizer.Authorize(r.Context(), &fleet.ScimUser{}, fleet.ActionWrite)
|
||||
if err != nil {
|
||||
|
|
@ -376,14 +374,14 @@ func AuthorizationMiddleware(authorizer *authz.Authorizer, logger kitlog.Logger,
|
|||
})
|
||||
}
|
||||
|
||||
func errorHandler(w http.ResponseWriter, logger kitlog.Logger, detail string, status int) {
|
||||
func errorHandler(w http.ResponseWriter, logger *slog.Logger, detail string, status int) {
|
||||
scimErr := scimerrors.ScimError{
|
||||
Status: status,
|
||||
Detail: detail,
|
||||
}
|
||||
raw, err := json.Marshal(scimErr)
|
||||
if err != nil {
|
||||
level.Error(logger).Log("msg", "failed marshaling scim error", "scimError", scimErr, "err", err)
|
||||
logger.ErrorContext(context.TODO(), "failed marshaling scim error", "scimError", scimErr, "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -391,20 +389,18 @@ func errorHandler(w http.ResponseWriter, logger kitlog.Logger, detail string, st
|
|||
w.WriteHeader(scimErr.Status)
|
||||
_, err = w.Write(raw)
|
||||
if err != nil {
|
||||
level.Error(logger).Log("msg", "failed writing response", "err", err)
|
||||
logger.ErrorContext(context.TODO(), "failed writing response", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
type scimErrorLogger struct {
|
||||
kitlog.Logger
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
var _ scim.Logger = &scimErrorLogger{}
|
||||
|
||||
func (l *scimErrorLogger) Error(args ...interface{}) {
|
||||
level.Error(l.Logger).Log(
|
||||
"error", fmt.Sprint(args...),
|
||||
)
|
||||
l.logger.ErrorContext(context.TODO(), fmt.Sprint(args...))
|
||||
}
|
||||
|
||||
type multiResponseWriter struct {
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package scim
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
|
|
@ -18,8 +19,6 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
"github.com/fleetdm/fleet/v4/server/service/modules/activities"
|
||||
kitlog "github.com/go-kit/log"
|
||||
"github.com/go-kit/log/level"
|
||||
"github.com/scim2/filter-parser/v2"
|
||||
)
|
||||
|
||||
|
|
@ -46,32 +45,34 @@ const (
|
|||
type UserHandler struct {
|
||||
ds fleet.Datastore
|
||||
activityModule activities.ActivityModule
|
||||
logger kitlog.Logger
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// Compile-time check
|
||||
var _ scim.ResourceHandler = &UserHandler{}
|
||||
|
||||
func NewUserHandler(ds fleet.Datastore, activityModule activities.ActivityModule, logger kitlog.Logger) scim.ResourceHandler {
|
||||
func NewUserHandler(ds fleet.Datastore, activityModule activities.ActivityModule, logger *slog.Logger) scim.ResourceHandler {
|
||||
return &UserHandler{ds: ds, activityModule: activityModule, logger: logger}
|
||||
}
|
||||
|
||||
func (u *UserHandler) Create(r *http.Request, attributes scim.ResourceAttributes) (scim.Resource, error) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Check for userName uniqueness
|
||||
userName, err := getRequiredResource[string](attributes, userNameAttr)
|
||||
if err != nil {
|
||||
level.Error(u.logger).Log("msg", "failed to get userName", "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to get userName", "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
// In IETF documents, “non-empty” is generally used in the literal sense of “having at least one character.” That means if a value contains one or more spaces (and nothing else), it is still considered non-empty.
|
||||
if len(userName) == 0 {
|
||||
level.Info(u.logger).Log("msg", "userName is empty")
|
||||
u.logger.InfoContext(ctx, "userName is empty")
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{userNameAttr})
|
||||
}
|
||||
existingUser, err := u.ds.ScimUserByUserName(r.Context(), userName)
|
||||
existingUser, err := u.ds.ScimUserByUserName(ctx, userName)
|
||||
switch {
|
||||
case err != nil && !fleet.IsNotFound(err):
|
||||
level.Error(u.logger).Log("msg", "failed to check for userName uniqueness", userNameAttr, userName, "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to check for userName uniqueness", userNameAttr, userName, "err", err)
|
||||
return scim.Resource{}, err
|
||||
case err == nil:
|
||||
// User exists - check if it's a deactivated user being reactivated
|
||||
|
|
@ -79,30 +80,31 @@ func (u *UserHandler) Create(r *http.Request, attributes scim.ResourceAttributes
|
|||
incomingActive, _ := getOptionalResource[bool](attributes, activeAttr)
|
||||
if existingUser.Active != nil && !*existingUser.Active && incomingActive != nil && *incomingActive {
|
||||
// Reactivate the user by updating their record
|
||||
level.Info(u.logger).Log("msg", "reactivating deactivated user", userNameAttr, userName)
|
||||
user, err := u.createUserFromAttributes(attributes)
|
||||
u.logger.InfoContext(ctx, "reactivating deactivated user", userNameAttr, userName)
|
||||
user, err := u.createUserFromAttributes(ctx, attributes)
|
||||
if err != nil {
|
||||
level.Error(u.logger).Log("msg", "failed to create user from attributes for reactivation", userNameAttr, userName, "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to create user from attributes for reactivation",
|
||||
userNameAttr, userName, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
user.ID = existingUser.ID
|
||||
err = u.ds.ReplaceScimUser(r.Context(), user)
|
||||
err = u.ds.ReplaceScimUser(ctx, user)
|
||||
if err != nil {
|
||||
level.Error(u.logger).Log("msg", "failed to reactivate user", userNameAttr, userName, "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to reactivate user", userNameAttr, userName, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
return createUserResource(user), nil
|
||||
}
|
||||
level.Info(u.logger).Log("msg", "user already exists", userNameAttr, userName)
|
||||
u.logger.InfoContext(ctx, "user already exists", userNameAttr, userName)
|
||||
return scim.Resource{}, errors.ScimErrorUniqueness
|
||||
}
|
||||
|
||||
user, err := u.createUserFromAttributes(attributes)
|
||||
user, err := u.createUserFromAttributes(ctx, attributes)
|
||||
if err != nil {
|
||||
level.Error(u.logger).Log("msg", "failed to create user from attributes", userNameAttr, userName, "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to create user from attributes", userNameAttr, userName, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
user.ID, err = u.ds.CreateScimUser(r.Context(), user)
|
||||
user.ID, err = u.ds.CreateScimUser(ctx, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
|
|
@ -110,7 +112,9 @@ func (u *UserHandler) Create(r *http.Request, attributes scim.ResourceAttributes
|
|||
return createUserResource(user), nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) createUserFromAttributes(attributes scim.ResourceAttributes) (*fleet.ScimUser, error) {
|
||||
func (u *UserHandler) createUserFromAttributes(
|
||||
ctx context.Context, attributes scim.ResourceAttributes,
|
||||
) (*fleet.ScimUser, error) {
|
||||
user := fleet.ScimUser{}
|
||||
var err error
|
||||
user.UserName, err = getRequiredResource[string](attributes, userNameAttr)
|
||||
|
|
@ -174,7 +178,7 @@ func (u *UserHandler) createUserFromAttributes(attributes scim.ResourceAttribute
|
|||
user.Emails = userEmails
|
||||
|
||||
// Attempt to get extension enterprise user attributes.
|
||||
extendedAttributes := u.getExtensionEnterpriseUserAttributes(user.UserName, attributes)
|
||||
extendedAttributes := u.getExtensionEnterpriseUserAttributes(ctx, user.UserName, attributes)
|
||||
user.Department = extendedAttributes.department
|
||||
|
||||
return &user, nil
|
||||
|
|
@ -184,7 +188,9 @@ type extendedAttributes struct {
|
|||
department *string
|
||||
}
|
||||
|
||||
func (u *UserHandler) getExtensionEnterpriseUserAttributes(userName string, attributes scim.ResourceAttributes) extendedAttributes {
|
||||
func (u *UserHandler) getExtensionEnterpriseUserAttributes(
|
||||
ctx context.Context, userName string, attributes scim.ResourceAttributes,
|
||||
) extendedAttributes {
|
||||
var attrs extendedAttributes
|
||||
m_, ok := attributes[extensionEnterpriseUserAttributes]
|
||||
if !ok {
|
||||
|
|
@ -192,8 +198,8 @@ func (u *UserHandler) getExtensionEnterpriseUserAttributes(userName string, attr
|
|||
}
|
||||
m, ok := m_.(map[string]any)
|
||||
if !ok {
|
||||
level.Error(u.logger).Log(
|
||||
"msg", fmt.Sprintf("unexpected type for %s: %T", extensionEnterpriseUserAttributes, m_),
|
||||
u.logger.ErrorContext(ctx,
|
||||
fmt.Sprintf("unexpected type for %s: %T", extensionEnterpriseUserAttributes, m_),
|
||||
userNameAttr, userName,
|
||||
)
|
||||
return attrs
|
||||
|
|
@ -204,8 +210,8 @@ func (u *UserHandler) getExtensionEnterpriseUserAttributes(userName string, attr
|
|||
if department, ok := department_.(string); ok {
|
||||
attrs.department = &department
|
||||
} else {
|
||||
level.Error(u.logger).Log(
|
||||
"msg", fmt.Sprintf("unexpected type for %s.department: %T", extensionEnterpriseUserAttributes, department_),
|
||||
u.logger.ErrorContext(ctx,
|
||||
fmt.Sprintf("unexpected type for %s.department: %T", extensionEnterpriseUserAttributes, department_),
|
||||
userNameAttr, userName,
|
||||
)
|
||||
}
|
||||
|
|
@ -275,19 +281,21 @@ func getComplexResourceSlice(attributes scim.ResourceAttributes, key string) ([]
|
|||
}
|
||||
|
||||
func (u *UserHandler) Get(r *http.Request, id string) (scim.Resource, error) {
|
||||
ctx := r.Context()
|
||||
|
||||
idUint, err := extractUserIDFromValue(id)
|
||||
if err != nil {
|
||||
level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
|
||||
u.logger.InfoContext(ctx, "failed to parse id", "id", id, "err", err)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
}
|
||||
|
||||
user, err := u.ds.ScimUserByID(r.Context(), idUint)
|
||||
user, err := u.ds.ScimUserByID(ctx, idUint)
|
||||
switch {
|
||||
case fleet.IsNotFound(err):
|
||||
level.Info(u.logger).Log("msg", "failed to find user", "id", id)
|
||||
u.logger.InfoContext(ctx, "failed to find user", "id", id)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
case err != nil:
|
||||
level.Error(u.logger).Log("msg", "failed to get user", "id", id, "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to get user", "id", id, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
|
||||
|
|
@ -365,6 +373,8 @@ func createUserResource(user *fleet.ScimUser) scim.Resource {
|
|||
// totalResults: The total number of results returned by the list or query operation. The value may be larger than the number of
|
||||
// resources returned, such as when returning a single page (see Section 3.4.2.4) of results where multiple pages are available.
|
||||
func (u *UserHandler) GetAll(r *http.Request, params scim.ListRequestParams) (scim.Page, error) {
|
||||
ctx := r.Context()
|
||||
|
||||
startIndex := params.StartIndex
|
||||
if startIndex < 1 {
|
||||
startIndex = 1
|
||||
|
|
@ -387,30 +397,30 @@ func (u *UserHandler) GetAll(r *http.Request, params scim.ListRequestParams) (sc
|
|||
if resourceFilter != "" {
|
||||
expr, err := filter.ParseAttrExp([]byte(resourceFilter))
|
||||
if err != nil {
|
||||
level.Error(u.logger).Log("msg", "failed to parse filter", "filter", resourceFilter, "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to parse filter", "filter", resourceFilter, "err", err)
|
||||
return scim.Page{}, errors.ScimErrorInvalidFilter
|
||||
}
|
||||
if !strings.EqualFold(expr.AttributePath.String(), "userName") || expr.Operator != "eq" {
|
||||
level.Info(u.logger).Log("msg", "unsupported filter", "filter", resourceFilter)
|
||||
u.logger.InfoContext(ctx, "unsupported filter", "filter", resourceFilter)
|
||||
return scim.Page{}, nil
|
||||
}
|
||||
userName, ok := expr.CompareValue.(string)
|
||||
if !ok {
|
||||
level.Error(u.logger).Log("msg", "unsupported value", "value", expr.CompareValue)
|
||||
u.logger.ErrorContext(ctx, "unsupported value", "value", expr.CompareValue)
|
||||
return scim.Page{}, nil
|
||||
}
|
||||
|
||||
// Decode URL-encoded characters in userName, which is required to pass Microsoft Entra ID SCIM Validator
|
||||
userName, err = url.QueryUnescape(userName)
|
||||
if err != nil {
|
||||
level.Error(u.logger).Log("msg", "failed to decode userName", "userName", userName, "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to decode userName", "userName", userName, "err", err)
|
||||
return scim.Page{}, nil
|
||||
}
|
||||
opts.UserNameFilter = &userName
|
||||
}
|
||||
users, totalResults, err := u.ds.ListScimUsers(r.Context(), opts)
|
||||
users, totalResults, err := u.ds.ListScimUsers(ctx, opts)
|
||||
if err != nil {
|
||||
level.Error(u.logger).Log("msg", "failed to list users", "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to list users", "err", err)
|
||||
return scim.Page{}, err
|
||||
}
|
||||
|
||||
|
|
@ -430,13 +440,13 @@ func (u *UserHandler) Replace(r *http.Request, id string, attributes scim.Resour
|
|||
|
||||
idUint, err := extractUserIDFromValue(id)
|
||||
if err != nil {
|
||||
level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
|
||||
u.logger.InfoContext(ctx, "failed to parse id", "id", id, "err", err)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
}
|
||||
|
||||
user, err := u.createUserFromAttributes(attributes)
|
||||
user, err := u.createUserFromAttributes(ctx, attributes)
|
||||
if err != nil {
|
||||
level.Error(u.logger).Log("msg", "failed to create user from attributes", "id", id, "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to create user from attributes", "id", id, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
user.ID = idUint
|
||||
|
|
@ -447,10 +457,10 @@ func (u *UserHandler) Replace(r *http.Request, id string, attributes scim.Resour
|
|||
userWithSameUsername, err := u.ds.ScimUserByUserName(ctx, user.UserName)
|
||||
switch {
|
||||
case err != nil && !fleet.IsNotFound(err):
|
||||
level.Error(u.logger).Log("msg", "failed to check for userName uniqueness", userNameAttr, user.UserName, "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to check for userName uniqueness", userNameAttr, user.UserName, "err", err)
|
||||
return scim.Resource{}, err
|
||||
case err == nil && user.ID != userWithSameUsername.ID:
|
||||
level.Info(u.logger).Log("msg", "user already exists with this username", userNameAttr, user.UserName)
|
||||
u.logger.InfoContext(ctx, "user already exists with this username", userNameAttr, user.UserName)
|
||||
return scim.Resource{}, errors.ScimErrorUniqueness
|
||||
case err == nil && user.ID == userWithSameUsername.ID:
|
||||
// Same user, username not changing - use this for previous active state
|
||||
|
|
@ -459,11 +469,11 @@ func (u *UserHandler) Replace(r *http.Request, id string, attributes scim.Resour
|
|||
// Username is being changed - need to fetch existing user by ID for previous active state
|
||||
existingUser, err := u.ds.ScimUserByID(ctx, idUint)
|
||||
if fleet.IsNotFound(err) {
|
||||
level.Info(u.logger).Log("msg", "failed to find scim user by id", "id", id)
|
||||
u.logger.InfoContext(ctx, "failed to find scim user by id", "id", id)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
}
|
||||
if err != nil {
|
||||
level.Error(u.logger).Log("msg", "failed to get existing scim user by id", "id", id, "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to get existing scim user by id", "id", id, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
previousActive = existingUser.Active
|
||||
|
|
@ -472,17 +482,17 @@ func (u *UserHandler) Replace(r *http.Request, id string, attributes scim.Resour
|
|||
err = u.ds.ReplaceScimUser(ctx, user)
|
||||
switch {
|
||||
case fleet.IsNotFound(err):
|
||||
level.Info(u.logger).Log("msg", "failed to find user to replace", "id", id)
|
||||
u.logger.InfoContext(ctx, "failed to find user to replace", "id", id)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
case err != nil:
|
||||
level.Error(u.logger).Log("msg", "failed to replace user", "id", id, "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to replace user", "id", id, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
|
||||
// Check if user was deactivated and delete matching Fleet user if so
|
||||
if wasDeactivated(previousActive, user.Active) {
|
||||
if err := u.deleteMatchingFleetUser(ctx, user); err != nil {
|
||||
level.Error(u.logger).Log("msg", "failed to delete fleet user on deactivation", "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to delete fleet user on deactivation", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -497,33 +507,33 @@ func (u *UserHandler) Delete(r *http.Request, id string) error {
|
|||
|
||||
idUint, err := extractUserIDFromValue(id)
|
||||
if err != nil {
|
||||
level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
|
||||
u.logger.InfoContext(ctx, "failed to parse id", "id", id, "err", err)
|
||||
return errors.ScimErrorResourceNotFound(id)
|
||||
}
|
||||
|
||||
scimUser, err := u.ds.ScimUserByID(ctx, idUint)
|
||||
if fleet.IsNotFound(err) {
|
||||
// proceed with DeleteScimUser call which calls triggerResendProfilesForIDPUserDeleted even before checking if the user exists
|
||||
level.Warn(u.logger).Log("msg", "scim user not found", "id", id)
|
||||
u.logger.WarnContext(ctx, "scim user not found", "id", id)
|
||||
} else if err != nil {
|
||||
level.Error(u.logger).Log("msg", "failed to get scim user", "id", id, "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to get scim user", "id", id, "err", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if scimUser != nil {
|
||||
if err := u.deleteMatchingFleetUser(ctx, scimUser); err != nil {
|
||||
// Log but don't fail - SCIM deletion should still proceed
|
||||
level.Error(u.logger).Log("msg", "failed to delete matching fleet user", "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to delete matching fleet user", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = u.ds.DeleteScimUser(ctx, idUint)
|
||||
switch {
|
||||
case fleet.IsNotFound(err):
|
||||
level.Info(u.logger).Log("msg", "failed to find user to delete", "id", id)
|
||||
u.logger.InfoContext(ctx, "failed to find user to delete", "id", id)
|
||||
return errors.ScimErrorResourceNotFound(id)
|
||||
case err != nil:
|
||||
level.Error(u.logger).Log("msg", "failed to delete user", "id", id, "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to delete user", "id", id, "err", err)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -556,7 +566,7 @@ func (u *UserHandler) deleteMatchingFleetUser(ctx context.Context, scimUser *fle
|
|||
emails = server.RemoveDuplicatesFromSlice(emails)
|
||||
|
||||
if len(emails) == 0 {
|
||||
level.Debug(u.logger).Log("msg", "no emails found for scim user",
|
||||
u.logger.DebugContext(ctx, "no emails found for scim user",
|
||||
"scim_user_id", scimUser.ID, "user_name", scimUser.UserName)
|
||||
return nil
|
||||
}
|
||||
|
|
@ -574,14 +584,14 @@ func (u *UserHandler) deleteMatchingFleetUser(ctx context.Context, scimUser *fle
|
|||
}
|
||||
|
||||
if fleetUser == nil {
|
||||
level.Debug(u.logger).Log("msg", "no matching fleet user found for scim user",
|
||||
u.logger.DebugContext(ctx, "no matching fleet user found for scim user",
|
||||
"scim_user_id", scimUser.ID, "user_name", scimUser.UserName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Skip API-only users or non-SSO users
|
||||
if fleetUser.APIOnly || !fleetUser.SSOEnabled {
|
||||
level.Info(u.logger).Log("msg", "skipping deletion of API-only or non-SSO user",
|
||||
u.logger.InfoContext(ctx, "skipping deletion of API-only or non-SSO user",
|
||||
"user_id", fleetUser.ID, "email", fleetUser.Email)
|
||||
return nil
|
||||
}
|
||||
|
|
@ -594,13 +604,13 @@ func (u *UserHandler) deleteMatchingFleetUser(ctx context.Context, scimUser *fle
|
|||
}
|
||||
|
||||
if count <= 1 {
|
||||
level.Warn(u.logger).Log("msg", "cannot delete last global admin via SCIM",
|
||||
u.logger.WarnContext(ctx, "cannot delete last global admin via SCIM",
|
||||
"user_id", fleetUser.ID, "email", fleetUser.Email)
|
||||
return ctxerr.New(ctx, "cannot delete last global admin")
|
||||
}
|
||||
}
|
||||
|
||||
level.Info(u.logger).Log("msg", "deleting fleet user via SCIM deletion",
|
||||
u.logger.InfoContext(ctx, "deleting fleet user via SCIM deletion",
|
||||
"user_id", fleetUser.ID, "email", fleetUser.Email)
|
||||
|
||||
// TODO: Ideally this should go through a Users service/module instead of directly accessing
|
||||
|
|
@ -620,7 +630,7 @@ func (u *UserHandler) deleteMatchingFleetUser(ctx context.Context, scimUser *fle
|
|||
FromScimUserDeletion: true,
|
||||
},
|
||||
); err != nil {
|
||||
level.Error(u.logger).Log("msg", "failed to create activity for fleet user deletion", "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to create activity for fleet user deletion", "err", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -632,16 +642,16 @@ func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchO
|
|||
|
||||
idUint, err := extractUserIDFromValue(id)
|
||||
if err != nil {
|
||||
level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
|
||||
u.logger.InfoContext(ctx, "failed to parse id", "id", id, "err", err)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
}
|
||||
user, err := u.ds.ScimUserByID(ctx, idUint)
|
||||
switch {
|
||||
case fleet.IsNotFound(err):
|
||||
level.Info(u.logger).Log("msg", "failed to find user to patch", "id", id)
|
||||
u.logger.InfoContext(ctx, "failed to find user to patch", "id", id)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
case err != nil:
|
||||
level.Error(u.logger).Log("msg", "failed to get user to patch", "id", id, "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to get user to patch", "id", id, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
|
||||
|
|
@ -650,115 +660,115 @@ func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchO
|
|||
|
||||
for _, op := range operations {
|
||||
if op.Op != scim.PatchOperationAdd && op.Op != scim.PatchOperationReplace && op.Op != scim.PatchOperationRemove {
|
||||
level.Info(u.logger).Log("msg", "unsupported patch operation", "op", op.Op)
|
||||
u.logger.InfoContext(ctx, "unsupported patch operation", "op", op.Op)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
switch {
|
||||
// If path is not specified, we look for the path in the value attribute.
|
||||
case op.Path == nil:
|
||||
if op.Op == scim.PatchOperationRemove {
|
||||
level.Info(u.logger).Log("msg", "the 'path' attribute is REQUIRED for 'remove' operations", "op", op.Op)
|
||||
u.logger.InfoContext(ctx, "the 'path' attribute is REQUIRED for 'remove' operations", "op", op.Op)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
newValues, ok := op.Value.(map[string]interface{})
|
||||
if !ok {
|
||||
level.Info(u.logger).Log("msg", "unsupported patch value", "value", op.Value)
|
||||
u.logger.InfoContext(ctx, "unsupported patch value", "value", op.Value)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
for k, v := range newValues {
|
||||
switch k {
|
||||
case externalIdAttr:
|
||||
err = u.patchExternalId(op.Op, v, user)
|
||||
err = u.patchExternalId(ctx, op.Op, v, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case userNameAttr:
|
||||
err = u.patchUserName(op.Op, v, user)
|
||||
err = u.patchUserName(ctx, op.Op, v, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case activeAttr:
|
||||
err = u.patchActive(op.Op, v, user)
|
||||
err = u.patchActive(ctx, op.Op, v, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case nameAttr + "." + givenNameAttr:
|
||||
err = u.patchGivenName(op.Op, v, user)
|
||||
err = u.patchGivenName(ctx, op.Op, v, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case nameAttr + "." + familyNameAttr:
|
||||
err = u.patchFamilyName(op.Op, v, user)
|
||||
err = u.patchFamilyName(ctx, op.Op, v, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case nameAttr:
|
||||
err = u.patchName(v, op, user)
|
||||
err = u.patchName(ctx, v, op, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case emailsAttr:
|
||||
err = u.patchEmails(v, op, user)
|
||||
err = u.patchEmails(ctx, v, op, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case extensionEnterpriseUserAttributes + ":" + departmentAttr:
|
||||
err = u.patchDepartment(op.Op, v, user)
|
||||
err = u.patchDepartment(ctx, op.Op, v, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
default:
|
||||
level.Info(u.logger).Log("msg", "unsupported patch value field", "field", k)
|
||||
u.logger.InfoContext(ctx, "unsupported patch value field", "field", k)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
}
|
||||
case op.Path.String() == externalIdAttr:
|
||||
err = u.patchExternalId(op.Op, op.Value, user)
|
||||
err = u.patchExternalId(ctx, op.Op, op.Value, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case op.Path.String() == userNameAttr:
|
||||
err = u.patchUserName(op.Op, op.Value, user)
|
||||
err = u.patchUserName(ctx, op.Op, op.Value, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case op.Path.String() == activeAttr:
|
||||
err = u.patchActive(op.Op, op.Value, user)
|
||||
err = u.patchActive(ctx, op.Op, op.Value, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case op.Path.String() == nameAttr+"."+givenNameAttr:
|
||||
err = u.patchGivenName(op.Op, op.Value, user)
|
||||
err = u.patchGivenName(ctx, op.Op, op.Value, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case op.Path.String() == nameAttr+"."+familyNameAttr:
|
||||
err = u.patchFamilyName(op.Op, op.Value, user)
|
||||
err = u.patchFamilyName(ctx, op.Op, op.Value, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case op.Path.String() == nameAttr:
|
||||
err = u.patchName(op.Value, op, user)
|
||||
err = u.patchName(ctx, op.Value, op, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case op.Path.String() == emailsAttr:
|
||||
err = u.patchEmails(op.Value, op, user)
|
||||
err = u.patchEmails(ctx, op.Value, op, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case op.Path.AttributePath.String() == emailsAttr:
|
||||
err = u.patchEmailsWithPathFiltering(op, user)
|
||||
err = u.patchEmailsWithPathFiltering(ctx, op, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
case op.Path.AttributePath.String() == extensionEnterpriseUserAttributes+":"+departmentAttr:
|
||||
err = u.patchDepartment(op.Op, op.Value, user)
|
||||
err = u.patchDepartment(ctx, op.Op, op.Value, user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
default:
|
||||
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path)
|
||||
u.logger.InfoContext(ctx, "unsupported patch path", "path", op.Path)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
}
|
||||
|
|
@ -767,17 +777,17 @@ func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchO
|
|||
err = u.ds.ReplaceScimUser(ctx, user)
|
||||
switch {
|
||||
case fleet.IsNotFound(err):
|
||||
level.Info(u.logger).Log("msg", "failed to find user to patch", "id", id)
|
||||
u.logger.InfoContext(ctx, "failed to find user to patch", "id", id)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
case err != nil:
|
||||
level.Error(u.logger).Log("msg", "failed to patch user", "id", id, "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to patch user", "id", id, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
|
||||
// Check if user was deactivated and delete matching Fleet user if so
|
||||
if wasDeactivated(previousActive, user.Active) {
|
||||
if err := u.deleteMatchingFleetUser(ctx, user); err != nil {
|
||||
level.Error(u.logger).Log("msg", "failed to delete fleet user on deactivation", "err", err)
|
||||
u.logger.ErrorContext(ctx, "failed to delete fleet user on deactivation", "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -785,8 +795,10 @@ func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchO
|
|||
return createUserResource(user), nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) patchEmailsWithPathFiltering(op scim.PatchOperation, user *fleet.ScimUser) error {
|
||||
emailType, err := u.getEmailType(op)
|
||||
func (u *UserHandler) patchEmailsWithPathFiltering(
|
||||
ctx context.Context, op scim.PatchOperation, user *fleet.ScimUser,
|
||||
) error {
|
||||
emailType, err := u.getEmailType(ctx, op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -800,7 +812,7 @@ func (u *UserHandler) patchEmailsWithPathFiltering(op scim.PatchOperation, user
|
|||
}
|
||||
}
|
||||
if !emailFound && op.Op != scim.PatchOperationAdd {
|
||||
level.Info(u.logger).Log("msg", "email not found", "email_type", emailType, "op", fmt.Sprintf("%v", op))
|
||||
u.logger.InfoContext(ctx, "email not found", "email_type", emailType, "op", fmt.Sprintf("%v", op))
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
if op.Path.SubAttribute == nil {
|
||||
|
|
@ -820,7 +832,7 @@ func (u *UserHandler) patchEmailsWithPathFiltering(op scim.PatchOperation, user
|
|||
// Single member as a map
|
||||
emailsList = []interface{}{val}
|
||||
default:
|
||||
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", emailsAttr), "value", op.Value)
|
||||
u.logger.InfoContext(ctx, fmt.Sprintf("unsupported '%s' patch value", emailsAttr), "value", op.Value)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
|
||||
|
|
@ -831,10 +843,10 @@ func (u *UserHandler) patchEmailsWithPathFiltering(op scim.PatchOperation, user
|
|||
return nil
|
||||
}
|
||||
if len(emailsList) != 1 {
|
||||
level.Info(u.logger).Log("msg", "only 1 email should be present for replacement", "emails", emailsList)
|
||||
u.logger.InfoContext(ctx, "only 1 email should be present for replacement", "emails", emailsList)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
userEmail, err := u.extractEmail(emailsList[0], op)
|
||||
userEmail, err := u.extractEmail(ctx, emailsList[0], op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -845,19 +857,19 @@ func (u *UserHandler) patchEmailsWithPathFiltering(op scim.PatchOperation, user
|
|||
user.Emails[emailIndex] = userEmail
|
||||
case scim.PatchOperationAdd:
|
||||
if len(emailsList) == 0 {
|
||||
level.Info(u.logger).Log("msg", "no emails provided to add", "emails", emailsList)
|
||||
u.logger.InfoContext(ctx, "no emails provided to add", "emails", emailsList)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
var newEmails []fleet.ScimUserEmail
|
||||
for e := range emailsList {
|
||||
userEmail, err := u.extractEmail(emailsList[e], op)
|
||||
userEmail, err := u.extractEmail(ctx, emailsList[e], op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userEmail.Type = &emailType
|
||||
newEmails = append(newEmails, userEmail)
|
||||
}
|
||||
primaryExists, err := u.checkEmailPrimary(newEmails)
|
||||
primaryExists, err := u.checkEmailPrimary(ctx, newEmails)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -884,7 +896,7 @@ func (u *UserHandler) patchEmailsWithPathFiltering(op scim.PatchOperation, user
|
|||
user.Emails[emailIndex].Primary = nil
|
||||
return nil
|
||||
}
|
||||
primary, err := getConcreteType[bool](u, op.Value, primaryAttr)
|
||||
primary, err := getConcreteType[bool](ctx, u, op.Value, primaryAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -899,7 +911,7 @@ func (u *UserHandler) patchEmailsWithPathFiltering(op scim.PatchOperation, user
|
|||
user.Emails[emailIndex].Email = ""
|
||||
return nil
|
||||
}
|
||||
value, err := getConcreteType[string](u, op.Value, valueAttr)
|
||||
value, err := getConcreteType[string](ctx, u, op.Value, valueAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -913,53 +925,55 @@ func (u *UserHandler) patchEmailsWithPathFiltering(op scim.PatchOperation, user
|
|||
user.Emails[emailIndex].Type = nil
|
||||
return nil
|
||||
}
|
||||
newEmailType, err := getConcreteType[string](u, op.Value, typeAttr)
|
||||
newEmailType, err := getConcreteType[string](ctx, u, op.Value, typeAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user.Emails[emailIndex].Type = &newEmailType
|
||||
default:
|
||||
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path)
|
||||
u.logger.InfoContext(ctx, "unsupported patch path", "path", op.Path)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) getEmailType(op scim.PatchOperation) (string, error) {
|
||||
func (u *UserHandler) getEmailType(ctx context.Context, op scim.PatchOperation) (string, error) {
|
||||
attrExpression, ok := op.Path.ValueExpression.(*filter.AttributeExpression)
|
||||
if !ok {
|
||||
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path)
|
||||
u.logger.InfoContext(ctx, "unsupported patch path", "path", op.Path)
|
||||
return "", errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
// Only matching by email type (work, etc.) is supported.
|
||||
if attrExpression.AttributePath.String() != typeAttr || attrExpression.Operator != filter.EQ {
|
||||
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path, "expression", attrExpression.AttributePath.String())
|
||||
u.logger.InfoContext(ctx, "unsupported patch path",
|
||||
"path", op.Path, "expression", attrExpression.AttributePath.String())
|
||||
return "", errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
emailType, ok := attrExpression.CompareValue.(string)
|
||||
if !ok {
|
||||
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path, "compare_value", attrExpression.CompareValue)
|
||||
u.logger.InfoContext(ctx, "unsupported patch path",
|
||||
"path", op.Path, "compare_value", attrExpression.CompareValue)
|
||||
return "", errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
return emailType, nil
|
||||
}
|
||||
|
||||
func getConcreteType[T string | bool](u *UserHandler, v interface{}, name string) (T, error) {
|
||||
func getConcreteType[T string | bool](ctx context.Context, u *UserHandler, v any, name string) (T, error) {
|
||||
concreteType, ok := v.(T)
|
||||
if !ok {
|
||||
var zeroValue T
|
||||
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' value", name), "value", v)
|
||||
u.logger.InfoContext(ctx, fmt.Sprintf("unsupported '%s' value", name), "value", v)
|
||||
return zeroValue, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
|
||||
}
|
||||
return concreteType, nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) patchFamilyName(op string, v interface{}, user *fleet.ScimUser) error {
|
||||
func (u *UserHandler) patchFamilyName(ctx context.Context, op string, v any, user *fleet.ScimUser) error {
|
||||
if op == scim.PatchOperationRemove {
|
||||
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", nameAttr+"."+familyNameAttr)
|
||||
u.logger.InfoContext(ctx, "cannot remove required attribute", "attribute", nameAttr+"."+familyNameAttr)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
familyName, err := getConcreteType[string](u, v, nameAttr+"."+familyNameAttr)
|
||||
familyName, err := getConcreteType[string](ctx, u, v, nameAttr+"."+familyNameAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -967,12 +981,12 @@ func (u *UserHandler) patchFamilyName(op string, v interface{}, user *fleet.Scim
|
|||
return nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) patchGivenName(op string, v interface{}, user *fleet.ScimUser) error {
|
||||
func (u *UserHandler) patchGivenName(ctx context.Context, op string, v any, user *fleet.ScimUser) error {
|
||||
if op == scim.PatchOperationRemove {
|
||||
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", nameAttr+"."+givenNameAttr)
|
||||
u.logger.InfoContext(ctx, "cannot remove required attribute", "attribute", nameAttr+"."+givenNameAttr)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
givenName, err := getConcreteType[string](u, v, nameAttr+"."+givenNameAttr)
|
||||
givenName, err := getConcreteType[string](ctx, u, v, nameAttr+"."+givenNameAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -980,12 +994,12 @@ func (u *UserHandler) patchGivenName(op string, v interface{}, user *fleet.ScimU
|
|||
return nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) patchActive(op string, v interface{}, user *fleet.ScimUser) error {
|
||||
func (u *UserHandler) patchActive(ctx context.Context, op string, v any, user *fleet.ScimUser) error {
|
||||
if op == scim.PatchOperationRemove || v == nil {
|
||||
user.Active = nil
|
||||
return nil
|
||||
}
|
||||
active, err := getConcreteType[bool](u, v, activeAttr)
|
||||
active, err := getConcreteType[bool](ctx, u, v, activeAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -993,12 +1007,12 @@ func (u *UserHandler) patchActive(op string, v interface{}, user *fleet.ScimUser
|
|||
return nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) patchExternalId(op string, v interface{}, user *fleet.ScimUser) error {
|
||||
func (u *UserHandler) patchExternalId(ctx context.Context, op string, v any, user *fleet.ScimUser) error {
|
||||
if op == scim.PatchOperationRemove || v == nil {
|
||||
user.ExternalID = nil
|
||||
return nil
|
||||
}
|
||||
externalId, err := getConcreteType[string](u, v, externalIdAttr)
|
||||
externalId, err := getConcreteType[string](ctx, u, v, externalIdAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -1006,29 +1020,29 @@ func (u *UserHandler) patchExternalId(op string, v interface{}, user *fleet.Scim
|
|||
return nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) patchUserName(op string, v interface{}, user *fleet.ScimUser) error {
|
||||
func (u *UserHandler) patchUserName(ctx context.Context, op string, v any, user *fleet.ScimUser) error {
|
||||
if op == scim.PatchOperationRemove {
|
||||
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", userNameAttr)
|
||||
u.logger.InfoContext(ctx, "cannot remove required attribute", "attribute", userNameAttr)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
userName, err := getConcreteType[string](u, v, userNameAttr)
|
||||
userName, err := getConcreteType[string](ctx, u, v, userNameAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if userName == "" {
|
||||
level.Info(u.logger).Log("msg", fmt.Sprintf("'%s' cannot be empty", userNameAttr), "value", v)
|
||||
u.logger.InfoContext(ctx, fmt.Sprintf("'%s' cannot be empty", userNameAttr), "value", v)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
|
||||
}
|
||||
user.UserName = userName
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) patchDepartment(op string, v interface{}, user *fleet.ScimUser) error {
|
||||
func (u *UserHandler) patchDepartment(ctx context.Context, op string, v any, user *fleet.ScimUser) error {
|
||||
if op == scim.PatchOperationRemove || v == nil {
|
||||
user.Department = nil
|
||||
return nil
|
||||
}
|
||||
department, err := getConcreteType[string](u, v, departmentAttr)
|
||||
department, err := getConcreteType[string](ctx, u, v, departmentAttr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -1044,7 +1058,9 @@ func clearPrimaryFlagFromEmails(user *fleet.ScimUser) {
|
|||
}
|
||||
}
|
||||
|
||||
func (u *UserHandler) patchEmails(v interface{}, op scim.PatchOperation, user *fleet.ScimUser) error {
|
||||
func (u *UserHandler) patchEmails(
|
||||
ctx context.Context, v any, op scim.PatchOperation, user *fleet.ScimUser,
|
||||
) error {
|
||||
if op.Op == scim.PatchOperationRemove {
|
||||
user.Emails = nil
|
||||
return nil
|
||||
|
|
@ -1061,24 +1077,24 @@ func (u *UserHandler) patchEmails(v interface{}, op scim.PatchOperation, user *f
|
|||
// Single member as a map
|
||||
emailsList = []interface{}{val}
|
||||
default:
|
||||
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", emailsAttr), "value", op.Value)
|
||||
u.logger.InfoContext(ctx, fmt.Sprintf("unsupported '%s' patch value", emailsAttr), "value", op.Value)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
|
||||
if op.Op == scim.PatchOperationAdd && len(emailsList) == 0 {
|
||||
level.Info(u.logger).Log("msg", "no emails provided to add", "emails", emailsList)
|
||||
u.logger.InfoContext(ctx, "no emails provided to add", "emails", emailsList)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
// Convert the emails to the expected format
|
||||
userEmails := make([]fleet.ScimUserEmail, 0, len(emailsList))
|
||||
for _, emailIntf := range emailsList {
|
||||
userEmail, err := u.extractEmail(emailIntf, op)
|
||||
userEmail, err := u.extractEmail(ctx, emailIntf, op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userEmails = append(userEmails, userEmail)
|
||||
}
|
||||
primaryExists, err := u.checkEmailPrimary(userEmails)
|
||||
primaryExists, err := u.checkEmailPrimary(ctx, userEmails)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -1096,13 +1112,13 @@ func (u *UserHandler) patchEmails(v interface{}, op scim.PatchOperation, user *f
|
|||
}
|
||||
|
||||
// checkEmailPrimary ensures at most one email is marked as primary
|
||||
func (u *UserHandler) checkEmailPrimary(userEmails []fleet.ScimUserEmail) (bool, error) {
|
||||
func (u *UserHandler) checkEmailPrimary(ctx context.Context, userEmails []fleet.ScimUserEmail) (bool, error) {
|
||||
primaryEmailCount := 0
|
||||
for _, email := range userEmails {
|
||||
if email.Primary != nil && *email.Primary {
|
||||
primaryEmailCount++
|
||||
if primaryEmailCount > 1 {
|
||||
level.Info(u.logger).Log("msg", "multiple primary emails found")
|
||||
u.logger.InfoContext(ctx, "multiple primary emails found")
|
||||
return false, errors.ScimErrorBadParams([]string{"Only one email can be marked as primary"})
|
||||
}
|
||||
}
|
||||
|
|
@ -1110,24 +1126,26 @@ func (u *UserHandler) checkEmailPrimary(userEmails []fleet.ScimUserEmail) (bool,
|
|||
return primaryEmailCount > 0, nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) extractEmail(emailIntf interface{}, op scim.PatchOperation) (fleet.ScimUserEmail, error) {
|
||||
func (u *UserHandler) extractEmail(
|
||||
ctx context.Context, emailIntf any, op scim.PatchOperation,
|
||||
) (fleet.ScimUserEmail, error) {
|
||||
emailMap, ok := emailIntf.(map[string]interface{})
|
||||
if !ok {
|
||||
level.Info(u.logger).Log("msg", "email is not a map", "email", emailIntf)
|
||||
u.logger.InfoContext(ctx, "email is not a map", "email", emailIntf)
|
||||
return fleet.ScimUserEmail{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
|
||||
// Extract the email value (required)
|
||||
emailValue, ok := emailMap[valueAttr].(string)
|
||||
if !ok || emailValue == "" {
|
||||
level.Info(u.logger).Log("msg", "email value is missing or invalid", "email", emailMap)
|
||||
u.logger.InfoContext(ctx, "email value is missing or invalid", "email", emailMap)
|
||||
return fleet.ScimUserEmail{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
|
||||
// Normalize the email
|
||||
normalizedEmail, err := normalizeEmail(emailValue)
|
||||
if err != nil {
|
||||
level.Info(u.logger).Log("msg", "failed to normalize email", "email", emailValue, "err", err)
|
||||
u.logger.InfoContext(ctx, "failed to normalize email", "email", emailValue, "err", err)
|
||||
return fleet.ScimUserEmail{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
|
||||
|
|
@ -1150,14 +1168,14 @@ func (u *UserHandler) extractEmail(emailIntf interface{}, op scim.PatchOperation
|
|||
return userEmail, nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) patchName(v interface{}, op scim.PatchOperation, user *fleet.ScimUser) error {
|
||||
func (u *UserHandler) patchName(ctx context.Context, v any, op scim.PatchOperation, user *fleet.ScimUser) error {
|
||||
if op.Op == scim.PatchOperationRemove {
|
||||
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", nameAttr)
|
||||
u.logger.InfoContext(ctx, "cannot remove required attribute", "attribute", nameAttr)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
name, ok := v.(map[string]interface{})
|
||||
if !ok {
|
||||
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", nameAttr), "value", op.Value)
|
||||
u.logger.InfoContext(ctx, fmt.Sprintf("unsupported '%s' patch value", nameAttr), "value", op.Value)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
for nameKey, nameValue := range name {
|
||||
|
|
@ -1165,21 +1183,21 @@ func (u *UserHandler) patchName(v interface{}, op scim.PatchOperation, user *fle
|
|||
case givenNameAttr:
|
||||
givenName, ok := nameValue.(string)
|
||||
if !ok {
|
||||
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", nameAttr+"."+givenNameAttr), "value",
|
||||
op.Value)
|
||||
u.logger.InfoContext(ctx,
|
||||
fmt.Sprintf("unsupported '%s' patch value", nameAttr+"."+givenNameAttr), "value", op.Value)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
user.GivenName = &givenName
|
||||
case familyNameAttr:
|
||||
familyName, ok := nameValue.(string)
|
||||
if !ok {
|
||||
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", nameAttr+"."+familyNameAttr), "value",
|
||||
op.Value)
|
||||
u.logger.InfoContext(ctx,
|
||||
fmt.Sprintf("unsupported '%s' patch value", nameAttr+"."+familyNameAttr), "value", op.Value)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
user.FamilyName = &familyName
|
||||
default:
|
||||
level.Info(u.logger).Log("msg", "unsupported patch value field", "field", nameAttr+"."+nameKey)
|
||||
u.logger.InfoContext(ctx, "unsupported patch value field", "field", nameAttr+"."+nameKey)
|
||||
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package scim
|
|||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
|
@ -12,7 +13,6 @@ import (
|
|||
mockservice "github.com/fleetdm/fleet/v4/server/mock/service"
|
||||
platform_mysql "github.com/fleetdm/fleet/v4/server/platform/mysql"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
kitlog "github.com/go-kit/log"
|
||||
"github.com/scim2/filter-parser/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
|
@ -34,7 +34,7 @@ func (m *testMocks) newTestHandler() *UserHandler {
|
|||
return &UserHandler{
|
||||
ds: m.ds,
|
||||
activityModule: m.svc,
|
||||
logger: kitlog.NewNopLogger(),
|
||||
logger: slog.New(slog.DiscardHandler),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
|
@ -23,8 +24,6 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/redis"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
kitlog "github.com/go-kit/log"
|
||||
"github.com/go-kit/log/level"
|
||||
redigo "github.com/gomodule/redigo/redis"
|
||||
)
|
||||
|
||||
|
|
@ -33,7 +32,7 @@ import (
|
|||
// from the store. It is safe to call those methods concurrently.
|
||||
type Handler struct {
|
||||
pool fleet.RedisPool
|
||||
logger kitlog.Logger
|
||||
logger *slog.Logger
|
||||
ttl time.Duration
|
||||
running int32 // accessed atomically
|
||||
errCh chan error
|
||||
|
|
@ -47,7 +46,7 @@ type Handler struct {
|
|||
// NewHandler creates an error handler using the provided pool and logger,
|
||||
// storing unique instances of errors in Redis using the pool. It stops storing
|
||||
// errors when ctx is cancelled. Errors are kept for the duration of ttl.
|
||||
func NewHandler(ctx context.Context, pool fleet.RedisPool, logger kitlog.Logger, ttl time.Duration) *Handler {
|
||||
func NewHandler(ctx context.Context, pool fleet.RedisPool, logger *slog.Logger, ttl time.Duration) *Handler {
|
||||
eh := &Handler{
|
||||
pool: pool,
|
||||
logger: logger,
|
||||
|
|
@ -60,7 +59,7 @@ func NewHandler(ctx context.Context, pool fleet.RedisPool, logger kitlog.Logger,
|
|||
return eh
|
||||
}
|
||||
|
||||
func newTestHandler(ctx context.Context, pool fleet.RedisPool, logger kitlog.Logger, ttl time.Duration, onStart func(), onStore func(error)) *Handler {
|
||||
func newTestHandler(ctx context.Context, pool fleet.RedisPool, logger *slog.Logger, ttl time.Duration, onStart func(), onStore func(error)) *Handler {
|
||||
eh := &Handler{
|
||||
pool: pool,
|
||||
logger: logger,
|
||||
|
|
@ -202,7 +201,7 @@ func (h *Handler) handleErrors(ctx context.Context) {
|
|||
func (h *Handler) storeError(ctx context.Context, err error) {
|
||||
errorHash, errorJson, err := hashAndMarshalError(err)
|
||||
if err != nil {
|
||||
level.Error(h.logger).Log("err", err, "msg", "hashErr failed")
|
||||
h.logger.ErrorContext(ctx, "hashErr failed", "err", err)
|
||||
if h.testOnStore != nil {
|
||||
h.testOnStore(err)
|
||||
}
|
||||
|
|
@ -238,7 +237,7 @@ func (h *Handler) storeError(ctx context.Context, err error) {
|
|||
}
|
||||
|
||||
if _, err := conn.Do(""); err != nil {
|
||||
level.Error(h.logger).Log("err", err, "msg", "redis SET failed")
|
||||
h.logger.ErrorContext(ctx, "redis SET failed", "err", err)
|
||||
if h.testOnStore != nil {
|
||||
h.testOnStore(err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"regexp"
|
||||
|
|
@ -18,7 +19,6 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/redis/redistest"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
kitlog "github.com/go-kit/log"
|
||||
pkgErrors "github.com/pkg/errors" //nolint:depguard
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
|
@ -150,7 +150,7 @@ func TestErrorHandler(t *testing.T) {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // cancel immediately
|
||||
|
||||
eh := newTestHandler(ctx, nil, kitlog.NewNopLogger(), time.Minute, nil, nil)
|
||||
eh := newTestHandler(ctx, nil, slog.New(slog.DiscardHandler), time.Minute, nil, nil)
|
||||
|
||||
doneCh := make(chan struct{})
|
||||
go func() {
|
||||
|
|
@ -168,7 +168,7 @@ func TestErrorHandler(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("works if the error storage is disabled", func(t *testing.T) {
|
||||
eh := newTestHandler(context.Background(), nil, kitlog.NewNopLogger(), -1, nil, nil)
|
||||
eh := newTestHandler(context.Background(), nil, slog.New(slog.DiscardHandler), -1, nil, nil)
|
||||
|
||||
doneCh := make(chan struct{})
|
||||
go func() {
|
||||
|
|
@ -218,7 +218,7 @@ func testErrorHandlerCollectsErrors(t *testing.T, pool fleet.RedisPool, wd strin
|
|||
close(chDone)
|
||||
}
|
||||
}
|
||||
eh := newTestHandler(ctx, pool, kitlog.NewNopLogger(), time.Minute, testOnStart, testOnStore)
|
||||
eh := newTestHandler(ctx, pool, slog.New(slog.DiscardHandler), time.Minute, testOnStart, testOnStore)
|
||||
|
||||
<-chGo
|
||||
|
||||
|
|
@ -279,7 +279,7 @@ func testErrorHandlerCollectsDifferentErrors(t *testing.T, pool fleet.RedisPool,
|
|||
}
|
||||
}
|
||||
|
||||
eh := newTestHandler(ctx, pool, kitlog.NewNopLogger(), time.Minute, testOnStart, testOnStore)
|
||||
eh := newTestHandler(ctx, pool, slog.New(slog.DiscardHandler), time.Minute, testOnStart, testOnStore)
|
||||
|
||||
<-chGo
|
||||
|
||||
|
|
@ -363,7 +363,7 @@ func TestHttpHandler(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
eh := newTestHandler(ctx, pool, kitlog.NewNopLogger(), time.Minute, testOnStart, testOnStore)
|
||||
eh := newTestHandler(ctx, pool, slog.New(slog.DiscardHandler), time.Minute, testOnStart, testOnStore)
|
||||
|
||||
<-chGo
|
||||
// simulate two errors, one happening twice
|
||||
|
|
|
|||
|
|
@ -3,10 +3,9 @@ package fleet
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net"
|
||||
|
||||
"github.com/go-kit/log"
|
||||
"github.com/go-kit/log/level"
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
)
|
||||
|
||||
|
|
@ -29,7 +28,7 @@ type GeoIP interface {
|
|||
|
||||
type MaxMindGeoIP struct {
|
||||
reader *geoip2.Reader
|
||||
l log.Logger
|
||||
l *slog.Logger
|
||||
}
|
||||
|
||||
type NoOpGeoIP struct{}
|
||||
|
|
@ -38,7 +37,7 @@ func (n *NoOpGeoIP) Lookup(ctx context.Context, ip string) *GeoLocation {
|
|||
return nil
|
||||
}
|
||||
|
||||
func NewMaxMindGeoIP(logger log.Logger, path string) (*MaxMindGeoIP, error) {
|
||||
func NewMaxMindGeoIP(logger *slog.Logger, path string) (*MaxMindGeoIP, error) {
|
||||
r, err := geoip2.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -59,7 +58,7 @@ func (m *MaxMindGeoIP) Lookup(ctx context.Context, ip string) *GeoLocation {
|
|||
if err != nil && errors.Is(err, notCityDBError) {
|
||||
resp, err := m.reader.Country(parseIP)
|
||||
if err != nil {
|
||||
level.Debug(m.l).Log("err", err, "msg", "failed to lookup location from mmdb file")
|
||||
m.l.DebugContext(ctx, "failed to lookup location from mmdb file", "err", err)
|
||||
return nil
|
||||
}
|
||||
if resp == nil {
|
||||
|
|
@ -69,7 +68,7 @@ func (m *MaxMindGeoIP) Lookup(ctx context.Context, ip string) *GeoLocation {
|
|||
return &GeoLocation{CountryISO: resp.Country.IsoCode}
|
||||
}
|
||||
if err != nil {
|
||||
level.Debug(m.l).Log("err", err, "msg", "failed to lookup location from mmdb file")
|
||||
m.l.DebugContext(ctx, "failed to lookup location from mmdb file", "err", err)
|
||||
return nil
|
||||
}
|
||||
return parseCity(resp)
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
|
@ -56,8 +57,6 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/redis"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
kitlog "github.com/go-kit/log"
|
||||
"github.com/go-kit/log/level"
|
||||
redigo "github.com/gomodule/redigo/redis"
|
||||
)
|
||||
|
||||
|
|
@ -78,7 +77,7 @@ type redisLiveQuery struct {
|
|||
// in memory cache expiration
|
||||
cacheExpiration time.Duration
|
||||
|
||||
logger kitlog.Logger
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// memCache is an in-memory cache for live queries. It stores the SQL of the
|
||||
|
|
@ -109,7 +108,7 @@ func (r *redisLiveQuery) getSQLByCampaignID(campaignID string) (string, bool) {
|
|||
|
||||
// NewRedisQueryResults creates a new Redis implementation of the
|
||||
// QueryResultStore interface using the provided Redis connection pool.
|
||||
func NewRedisLiveQuery(pool fleet.RedisPool, logger kitlog.Logger, memCacheExp time.Duration) *redisLiveQuery {
|
||||
func NewRedisLiveQuery(pool fleet.RedisPool, logger *slog.Logger, memCacheExp time.Duration) *redisLiveQuery {
|
||||
return &redisLiveQuery{
|
||||
pool: pool,
|
||||
cache: newMemCache(),
|
||||
|
|
@ -251,7 +250,7 @@ func (r *redisLiveQuery) collectBatchQueriesForHost(hostID uint, queryKeys []str
|
|||
if sql, found := r.getSQLByCampaignID(name); found {
|
||||
queriesByHost[name] = sql
|
||||
} else {
|
||||
level.Warn(r.logger).Log("msg", "live query not found in cache", "name", name)
|
||||
r.logger.WarnContext(context.TODO(), "live query not found in cache", "name", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -426,7 +425,7 @@ func (r *redisLiveQuery) loadCache() error {
|
|||
go func() {
|
||||
err = r.removeQueryNames(names...)
|
||||
if err != nil {
|
||||
level.Warn(r.logger).Log("msg", "removing expired live queries", "err", err)
|
||||
r.logger.WarnContext(context.TODO(), "removing expired live queries", "err", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
package live_query
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"testing"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/redis/redistest"
|
||||
"github.com/fleetdm/fleet/v4/server/test"
|
||||
"github.com/go-kit/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
|
@ -27,7 +27,7 @@ func TestRedisLiveQuery(t *testing.T) {
|
|||
|
||||
func setupRedisLiveQuery(t *testing.T, cluster bool) *redisLiveQuery {
|
||||
pool := redistest.SetupRedis(t, "*livequery", cluster, true, true)
|
||||
return NewRedisLiveQuery(pool, log.NewNopLogger(), 0)
|
||||
return NewRedisLiveQuery(pool, slog.New(slog.DiscardHandler), 0)
|
||||
}
|
||||
|
||||
func TestMapBitfield(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -4,28 +4,27 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/redis"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/platform/logging"
|
||||
"github.com/go-kit/log/level"
|
||||
redigo "github.com/gomodule/redigo/redis"
|
||||
)
|
||||
|
||||
type redisQueryResults struct {
|
||||
pool fleet.RedisPool
|
||||
duplicateResults bool
|
||||
logger *logging.Logger
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
var _ fleet.QueryResultStore = &redisQueryResults{}
|
||||
|
||||
// NewRedisQueryResults creats a new Redis implementation of the
|
||||
// QueryResultStore interface using the provided Redis connection pool.
|
||||
func NewRedisQueryResults(pool fleet.RedisPool, duplicateResults bool, logger *logging.Logger) *redisQueryResults {
|
||||
func NewRedisQueryResults(pool fleet.RedisPool, duplicateResults bool, logger *slog.Logger) *redisQueryResults {
|
||||
return &redisQueryResults{
|
||||
pool: pool,
|
||||
duplicateResults: duplicateResults,
|
||||
|
|
@ -86,7 +85,7 @@ func writeOrDone(ctx context.Context, ch chan<- interface{}, item interface{}) b
|
|||
// connection over the provided channel. This effectively allows a select
|
||||
// statement to run on conn.Receive() (by selecting on outChan that is
|
||||
// passed into this function)
|
||||
func receiveMessages(ctx context.Context, conn *redigo.PubSubConn, outChan chan<- any, logger *logging.Logger) {
|
||||
func receiveMessages(ctx context.Context, conn *redigo.PubSubConn, outChan chan<- any, logger *slog.Logger) {
|
||||
defer close(outChan)
|
||||
|
||||
for {
|
||||
|
|
@ -97,7 +96,7 @@ func receiveMessages(ctx context.Context, conn *redigo.PubSubConn, outChan chan<
|
|||
msg := conn.ReceiveWithTimeout(1 * time.Hour)
|
||||
|
||||
if recvTime := time.Since(beforeReceive); recvTime > time.Minute {
|
||||
level.Info(logger).Log("msg", "conn.ReceiveWithTimeout connection was blocked for significant time", "duration", recvTime, "connection", fmt.Sprintf("%p", conn))
|
||||
logger.InfoContext(ctx, "conn.ReceiveWithTimeout connection was blocked for significant time", "duration", recvTime, "connection", fmt.Sprintf("%p", conn))
|
||||
}
|
||||
|
||||
// Pass the message back to ReadChannel.
|
||||
|
|
@ -108,7 +107,7 @@ func receiveMessages(ctx context.Context, conn *redigo.PubSubConn, outChan chan<
|
|||
switch msg := msg.(type) {
|
||||
case error:
|
||||
// If an error occurred (i.e. connection was closed), then we should exit.
|
||||
level.Error(logger).Log("msg", "conn.ReceiveWithTimeout failed", "err", msg)
|
||||
logger.ErrorContext(ctx, "conn.ReceiveWithTimeout failed", "err", msg)
|
||||
return
|
||||
case redigo.Subscription:
|
||||
// If the subscription count is 0, the ReadChannel call that invoked this goroutine has unsubscribed,
|
||||
|
|
@ -156,7 +155,7 @@ func (r *redisQueryResults) ReadChannel(ctx context.Context, query fleet.Distrib
|
|||
select {
|
||||
case msg, ok := <-msgChannel:
|
||||
if !ok {
|
||||
level.Error(logger).Log("msg", "unexpected exit in receiveMessages")
|
||||
logger.ErrorContext(ctx, "unexpected exit in receiveMessages")
|
||||
// NOTE(lucas): The below error string should not be modified. The UI is relying on it to detect
|
||||
// when Fleet's connection to Redis has been interrupted unexpectedly.
|
||||
//
|
||||
|
|
@ -179,7 +178,7 @@ func (r *redisQueryResults) ReadChannel(ctx context.Context, query fleet.Distrib
|
|||
return
|
||||
}
|
||||
case error:
|
||||
level.Error(logger).Log("msg", "error received from pubsub channel", "err", msg)
|
||||
logger.ErrorContext(ctx, "error received from pubsub channel", "err", msg)
|
||||
if writeOrDone(ctx, outChannel, ctxerr.Wrap(ctx, msg, "read from redis")) {
|
||||
return
|
||||
}
|
||||
|
|
@ -195,7 +194,7 @@ func (r *redisQueryResults) ReadChannel(ctx context.Context, query fleet.Distrib
|
|||
wg.Wait()
|
||||
psc.Unsubscribe(pubSubName) //nolint:errcheck
|
||||
conn.Close()
|
||||
level.Debug(logger).Log("msg", "proper close of Redis connection in ReadChannel", "connection", fmt.Sprintf("%p", conn))
|
||||
logger.DebugContext(ctx, "proper close of Redis connection in ReadChannel", "connection", fmt.Sprintf("%p", conn))
|
||||
}()
|
||||
|
||||
return outChannel, nil
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
package pubsub
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"testing"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/redis/redistest"
|
||||
"github.com/fleetdm/fleet/v4/server/platform/logging"
|
||||
)
|
||||
|
||||
func SetupRedisForTest(t *testing.T, cluster, readReplica bool) *redisQueryResults {
|
||||
const dupResults = false
|
||||
pool := redistest.SetupRedis(t, "zz", cluster, false, readReplica)
|
||||
return NewRedisQueryResults(pool, dupResults, logging.NewNopLogger())
|
||||
return NewRedisQueryResults(pool, dupResults, slog.New(slog.DiscardHandler))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -174,7 +174,7 @@ func newTestServiceWithConfig(t *testing.T, ds fleet.Datastore, fleetConfig conf
|
|||
var eh *errorstore.Handler
|
||||
if len(opts) > 0 {
|
||||
if opts[0].Pool != nil {
|
||||
eh = errorstore.NewHandler(ctx, opts[0].Pool, logger, time.Minute*10)
|
||||
eh = errorstore.NewHandler(ctx, opts[0].Pool, logger.SlogLogger(), time.Minute*10)
|
||||
ctx = ctxerr.NewContext(ctx, eh)
|
||||
}
|
||||
if opts[0].StartCronSchedules != nil {
|
||||
|
|
@ -582,7 +582,7 @@ func RunServerForTestsWithServiceWithDS(t *testing.T, ctx context.Context, ds fl
|
|||
rootMux.Handle("/enroll", ServeEndUserEnrollOTA(svc, "", ds, logger))
|
||||
|
||||
if len(opts) > 0 && opts[0].EnableSCIM {
|
||||
require.NoError(t, scim.RegisterSCIM(rootMux, ds, svc, logger, &cfg))
|
||||
require.NoError(t, scim.RegisterSCIM(rootMux, ds, svc, logger.SlogLogger(), &cfg))
|
||||
rootMux.Handle("/api/v1/fleet/scim/details", apiHandler)
|
||||
rootMux.Handle("/api/latest/fleet/scim/details", apiHandler)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue