Incremental migration to slog (#40120)

<!-- Add the related story/sub-task/bug number, like Resolves #123, or
remove if NA -->
**Related issue:** Resolves #40054 

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

- [ ] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.
  - Already added in previous PR

## Testing

- [x] Added/updated automated tests
- [x] QA'd all new/changed functionality manually

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Refactor**
* Updated internal logging infrastructure across multiple server
components to use standardized logging methods and improved context
propagation.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
Victor Lyuboslavsky 2026-02-19 15:35:35 -06:00 committed by GitHub
parent c303f7f0e6
commit 70ffac6341
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 284 additions and 272 deletions

View file

@ -423,9 +423,9 @@ the way that the Fleet server works.
ds = redisWrapperDS
resultStore := pubsub.NewRedisQueryResults(redisPool, config.Redis.DuplicateResults,
logger.With("component", "query-results"),
logger.SlogLogger().With("component", "query-results"),
)
liveQueryStore := live_query.NewRedisLiveQuery(redisPool, logger, liveQueryMemCacheDuration)
liveQueryStore := live_query.NewRedisLiveQuery(redisPool, logger.SlogLogger(), liveQueryMemCacheDuration)
ssoSessionStore := sso.NewSessionStore(redisPool)
// Set common configuration for all logging.
@ -557,7 +557,7 @@ the way that the Fleet server works.
var geoIP fleet.GeoIP
geoIP = &fleet.NoOpGeoIP{}
if config.GeoIP.DatabasePath != "" {
maxmind, err := fleet.NewMaxMindGeoIP(logger, config.GeoIP.DatabasePath)
maxmind, err := fleet.NewMaxMindGeoIP(logger.SlogLogger(), config.GeoIP.DatabasePath)
if err != nil {
level.Error(logger).Log("msg", "failed to initialize maxmind geoip, check database path", "database_path",
config.GeoIP.DatabasePath, "error", err)
@ -853,7 +853,7 @@ the way that the Fleet server works.
}
}
eh := errorstore.NewHandler(ctx, redisPool, logger, config.Logging.ErrorRetentionPeriod)
eh := errorstore.NewHandler(ctx, redisPool, logger.SlogLogger(), config.Logging.ErrorRetentionPeriod)
scepConfigMgr := eeservice.NewSCEPConfigService(logger, nil)
digiCertService := digicert.NewService(digicert.WithLogger(logger))
ctx = ctxerr.NewContext(ctx, eh)
@ -1497,7 +1497,7 @@ the way that the Fleet server works.
if err = service.RegisterSCEPProxy(rootMux, ds, logger, nil, &config); err != nil {
initFatal(err, "setup SCEP proxy")
}
if err = scim.RegisterSCIM(rootMux, ds, svc, logger, &config); err != nil {
if err = scim.RegisterSCIM(rootMux, ds, svc, logger.SlogLogger(), &config); err != nil {
initFatal(err, "setup SCIM")
}
// Host identify and conditional access SCEP feature only works if a private key has been set up

View file

@ -3,6 +3,7 @@ package scim
import (
"context"
"fmt"
"log/slog"
"net/http"
"net/url"
"strconv"
@ -12,8 +13,6 @@ import (
"github.com/elimity-com/scim/errors"
"github.com/elimity-com/scim/optional"
"github.com/fleetdm/fleet/v4/server/fleet"
kitlog "github.com/go-kit/log"
"github.com/go-kit/log/level"
"github.com/scim2/filter-parser/v2"
)
@ -25,13 +24,13 @@ const (
type GroupHandler struct {
ds fleet.Datastore
logger kitlog.Logger
logger *slog.Logger
}
// Compile-time check
var _ scim.ResourceHandler = &GroupHandler{}
func NewGroupHandler(ds fleet.Datastore, logger kitlog.Logger) scim.ResourceHandler {
func NewGroupHandler(ds fleet.Datastore, logger *slog.Logger) scim.ResourceHandler {
return &GroupHandler{ds: ds, logger: logger}
}
@ -39,7 +38,7 @@ func NewGroupHandler(ds fleet.Datastore, logger kitlog.Logger) scim.ResourceHand
func (g *GroupHandler) Create(r *http.Request, attributes scim.ResourceAttributes) (scim.Resource, error) {
displayName, err := getRequiredResource[string](attributes, displayNameAttr)
if err != nil {
level.Error(g.logger).Log("msg", "failed to get displayName", "err", err)
g.logger.ErrorContext(r.Context(), "failed to get displayName", "err", err)
return scim.Resource{}, err
}
@ -50,16 +49,16 @@ func (g *GroupHandler) Create(r *http.Request, attributes scim.ResourceAttribute
_, err = g.ds.ScimGroupByDisplayName(r.Context(), displayName)
switch {
case err != nil && !fleet.IsNotFound(err):
level.Error(g.logger).Log("msg", "failed to check for displayName uniqueness", displayNameAttr, displayName, "err", err)
g.logger.ErrorContext(r.Context(), "failed to check for displayName uniqueness", displayNameAttr, displayName, "err", err)
return scim.Resource{}, err
case err == nil:
level.Info(g.logger).Log("msg", "group already exists", displayNameAttr, displayName)
g.logger.InfoContext(r.Context(), "group already exists", displayNameAttr, displayName)
return scim.Resource{}, errors.ScimErrorUniqueness
}
group, err := createGroupFromAttributes(attributes)
if err != nil {
level.Error(g.logger).Log("msg", "failed to create group from attributes", displayNameAttr, displayName, "err", err)
g.logger.ErrorContext(r.Context(), "failed to create group from attributes", displayNameAttr, displayName, "err", err)
return scim.Resource{}, err
}
group.ID, err = g.ds.CreateScimGroup(r.Context(), group)
@ -135,17 +134,17 @@ func areMembersExcluded(r *http.Request) bool {
func (g *GroupHandler) Get(r *http.Request, id string) (scim.Resource, error) {
idUint, err := extractGroupIDFromValue(id)
if err != nil {
level.Info(g.logger).Log("msg", "failed to parse id", "id", id, "err", err)
g.logger.InfoContext(r.Context(), "failed to parse id", "id", id, "err", err)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
}
group, err := g.ds.ScimGroupByID(r.Context(), idUint, areMembersExcluded(r))
switch {
case fleet.IsNotFound(err):
level.Info(g.logger).Log("msg", "failed to find group", "id", id)
g.logger.InfoContext(r.Context(), "failed to find group", "id", id)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
case err != nil:
level.Error(g.logger).Log("msg", "failed to get group", "id", id, "err", err)
g.logger.ErrorContext(r.Context(), "failed to get group", "id", id, "err", err)
return scim.Resource{}, err
}
@ -203,23 +202,23 @@ func (g *GroupHandler) GetAll(r *http.Request, params scim.ListRequestParams) (s
if resourceFilter != "" {
expr, err := filter.ParseAttrExp([]byte(resourceFilter))
if err != nil {
level.Error(g.logger).Log("msg", "failed to parse filter", "filter", resourceFilter, "err", err)
g.logger.ErrorContext(r.Context(), "failed to parse filter", "filter", resourceFilter, "err", err)
return scim.Page{}, errors.ScimErrorInvalidFilter
}
if !strings.EqualFold(expr.AttributePath.String(), "displayName") || expr.Operator != "eq" {
level.Info(g.logger).Log("msg", "unsupported filter", "filter", resourceFilter)
g.logger.InfoContext(r.Context(), "unsupported filter", "filter", resourceFilter)
return scim.Page{}, nil
}
displayName, ok := expr.CompareValue.(string)
if !ok {
level.Error(g.logger).Log("msg", "unsupported value", "value", expr.CompareValue)
g.logger.ErrorContext(r.Context(), "unsupported value", "value", expr.CompareValue)
return scim.Page{}, nil
}
// Decode URL-encoded characters
displayName, err = url.QueryUnescape(displayName)
if err != nil {
level.Error(g.logger).Log("msg", "failed to decode displayName", "displayName", displayName, "err", err)
g.logger.ErrorContext(r.Context(), "failed to decode displayName", "displayName", displayName, "err", err)
return scim.Page{}, nil
}
opts.DisplayNameFilter = &displayName
@ -227,7 +226,7 @@ func (g *GroupHandler) GetAll(r *http.Request, params scim.ListRequestParams) (s
groups, totalResults, err := g.ds.ListScimGroups(r.Context(), opts)
if err != nil {
level.Error(g.logger).Log("msg", "failed to list groups", "err", err)
g.logger.ErrorContext(r.Context(), "failed to list groups", "err", err)
return scim.Page{}, err
}
@ -245,13 +244,13 @@ func (g *GroupHandler) GetAll(r *http.Request, params scim.ListRequestParams) (s
func (g *GroupHandler) Replace(r *http.Request, id string, attributes scim.ResourceAttributes) (scim.Resource, error) {
idUint, err := extractGroupIDFromValue(id)
if err != nil {
level.Info(g.logger).Log("msg", "failed to parse id", "id", id, "err", err)
g.logger.InfoContext(r.Context(), "failed to parse id", "id", id, "err", err)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
}
group, err := createGroupFromAttributes(attributes)
if err != nil {
level.Error(g.logger).Log("msg", "failed to create group from attributes", "id", id, "err", err)
g.logger.ErrorContext(r.Context(), "failed to create group from attributes", "id", id, "err", err)
return scim.Resource{}, err
}
group.ID = idUint
@ -260,10 +259,10 @@ func (g *GroupHandler) Replace(r *http.Request, id string, attributes scim.Resou
groupWithSameDisplayName, err := g.ds.ScimGroupByDisplayName(r.Context(), group.DisplayName)
switch {
case err != nil && !fleet.IsNotFound(err):
level.Error(g.logger).Log("msg", "failed to check for displayName uniqueness", displayNameAttr, group.DisplayName, "err", err)
g.logger.ErrorContext(r.Context(), "failed to check for displayName uniqueness", displayNameAttr, group.DisplayName, "err", err)
return scim.Resource{}, err
case err == nil && group.ID != groupWithSameDisplayName.ID:
level.Info(g.logger).Log("msg", "group already exists with this displayName", displayNameAttr, group.DisplayName)
g.logger.InfoContext(r.Context(), "group already exists with this displayName", displayNameAttr, group.DisplayName)
return scim.Resource{}, errors.ScimErrorUniqueness
// Otherwise, we assume that we are replacing the displayName with this operation.
}
@ -271,10 +270,10 @@ func (g *GroupHandler) Replace(r *http.Request, id string, attributes scim.Resou
err = g.ds.ReplaceScimGroup(r.Context(), group)
switch {
case fleet.IsNotFound(err):
level.Info(g.logger).Log("msg", "failed to find group to replace", "id", id)
g.logger.InfoContext(r.Context(), "failed to find group to replace", "id", id)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
case err != nil:
level.Error(g.logger).Log("msg", "failed to replace group", "id", id, "err", err)
g.logger.ErrorContext(r.Context(), "failed to replace group", "id", id, "err", err)
return scim.Resource{}, err
}
@ -284,16 +283,16 @@ func (g *GroupHandler) Replace(r *http.Request, id string, attributes scim.Resou
func (g *GroupHandler) Delete(r *http.Request, id string) error {
idUint, err := extractGroupIDFromValue(id)
if err != nil {
level.Info(g.logger).Log("msg", "failed to parse id", "id", id, "err", err)
g.logger.InfoContext(r.Context(), "failed to parse id", "id", id, "err", err)
return errors.ScimErrorResourceNotFound(id)
}
err = g.ds.DeleteScimGroup(r.Context(), idUint)
switch {
case fleet.IsNotFound(err):
level.Info(g.logger).Log("msg", "failed to find group to delete", "id", id)
g.logger.InfoContext(r.Context(), "failed to find group to delete", "id", id)
return errors.ScimErrorResourceNotFound(id)
case err != nil:
level.Error(g.logger).Log("msg", "failed to delete group", "id", id, "err", err)
g.logger.ErrorContext(r.Context(), "failed to delete group", "id", id, "err", err)
return err
}
return nil
@ -302,93 +301,94 @@ func (g *GroupHandler) Delete(r *http.Request, id string) error {
// Patch
// Supporting add/replace/remove operations for "displayName", "externalId", and "members" attributes.
func (g *GroupHandler) Patch(r *http.Request, id string, operations []scim.PatchOperation) (scim.Resource, error) {
ctx := r.Context()
idUint, err := extractGroupIDFromValue(id)
if err != nil {
level.Info(g.logger).Log("msg", "failed to parse id", "id", id, "err", err)
g.logger.InfoContext(ctx, "failed to parse id", "id", id, "err", err)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
}
group, err := g.ds.ScimGroupByID(r.Context(), idUint, false)
group, err := g.ds.ScimGroupByID(ctx, idUint, false)
switch {
case fleet.IsNotFound(err):
level.Info(g.logger).Log("msg", "failed to find group to patch", "id", id)
g.logger.InfoContext(ctx, "failed to find group to patch", "id", id)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
case err != nil:
level.Error(g.logger).Log("msg", "failed to get group to patch", "id", id, "err", err)
g.logger.ErrorContext(ctx, "failed to get group to patch", "id", id, "err", err)
return scim.Resource{}, err
}
for _, op := range operations {
if op.Op != scim.PatchOperationAdd && op.Op != scim.PatchOperationReplace && op.Op != scim.PatchOperationRemove {
level.Info(g.logger).Log("msg", "unsupported patch operation", "op", op.Op)
g.logger.InfoContext(ctx, "unsupported patch operation", "op", op.Op)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
switch {
case op.Path == nil:
if op.Op == scim.PatchOperationRemove {
level.Info(g.logger).Log("msg", "the 'path' attribute is REQUIRED for 'remove' operations", "op", op.Op)
g.logger.InfoContext(ctx, "the 'path' attribute is REQUIRED for 'remove' operations", "op", op.Op)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
newValues, ok := op.Value.(map[string]interface{})
if !ok {
level.Info(g.logger).Log("msg", "unsupported patch value", "value", op.Value)
g.logger.InfoContext(ctx, "unsupported patch value", "value", op.Value)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
for k, v := range newValues {
switch k {
case externalIdAttr:
err = g.patchExternalId(op.Op, v, group)
err = g.patchExternalId(ctx, op.Op, v, group)
if err != nil {
return scim.Resource{}, err
}
case displayNameAttr:
err = g.patchDisplayName(op.Op, v, group)
err = g.patchDisplayName(ctx, op.Op, v, group)
if err != nil {
return scim.Resource{}, err
}
case membersAttr:
err = g.patchMembers(r.Context(), op.Op, v, group)
err = g.patchMembers(ctx, op.Op, v, group)
if err != nil {
return scim.Resource{}, err
}
default:
level.Info(g.logger).Log("msg", "unsupported patch value field", "field", k)
g.logger.InfoContext(ctx, "unsupported patch value field", "field", k)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
}
case op.Path.String() == externalIdAttr:
err = g.patchExternalId(op.Op, op.Value, group)
err = g.patchExternalId(ctx, op.Op, op.Value, group)
if err != nil {
return scim.Resource{}, err
}
case op.Path.String() == displayNameAttr:
err = g.patchDisplayName(op.Op, op.Value, group)
err = g.patchDisplayName(ctx, op.Op, op.Value, group)
if err != nil {
return scim.Resource{}, err
}
case op.Path.String() == membersAttr:
err = g.patchMembers(r.Context(), op.Op, op.Value, group)
err = g.patchMembers(ctx, op.Op, op.Value, group)
if err != nil {
return scim.Resource{}, err
}
case op.Path.AttributePath.String() == membersAttr:
err = g.patchMembersWithPathFiltering(r.Context(), op, group)
err = g.patchMembersWithPathFiltering(ctx, op, group)
if err != nil {
return scim.Resource{}, err
}
default:
level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path)
g.logger.InfoContext(ctx, "unsupported patch path", "path", op.Path)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
}
if len(operations) != 0 {
err = g.ds.ReplaceScimGroup(r.Context(), group)
err = g.ds.ReplaceScimGroup(ctx, group)
switch {
case fleet.IsNotFound(err):
level.Info(g.logger).Log("msg", "failed to find group to patch", "id", id)
g.logger.InfoContext(ctx, "failed to find group to patch", "id", id)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
case err != nil:
level.Error(g.logger).Log("msg", "failed to patch group", "id", id, "err", err)
g.logger.ErrorContext(ctx, "failed to patch group", "id", id, "err", err)
return scim.Resource{}, err
}
}
@ -396,32 +396,32 @@ func (g *GroupHandler) Patch(r *http.Request, id string, operations []scim.Patch
return createGroupResource(group), nil
}
func (g *GroupHandler) patchExternalId(op string, v interface{}, group *fleet.ScimGroup) error {
func (g *GroupHandler) patchExternalId(ctx context.Context, op string, v any, group *fleet.ScimGroup) error {
if op == scim.PatchOperationRemove || v == nil {
group.ExternalID = nil
return nil
}
externalId, ok := v.(string)
if !ok {
level.Info(g.logger).Log("msg", fmt.Sprintf("unsupported '%s' value", externalIdAttr), "value", v)
g.logger.InfoContext(ctx, fmt.Sprintf("unsupported '%s' value", externalIdAttr), "value", v)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
}
group.ExternalID = &externalId
return nil
}
func (g *GroupHandler) patchDisplayName(op string, v interface{}, group *fleet.ScimGroup) error {
func (g *GroupHandler) patchDisplayName(ctx context.Context, op string, v any, group *fleet.ScimGroup) error {
if op == scim.PatchOperationRemove {
level.Info(g.logger).Log("msg", "cannot remove required attribute", "attribute", displayNameAttr)
g.logger.InfoContext(ctx, "cannot remove required attribute", "attribute", displayNameAttr)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
displayName, ok := v.(string)
if !ok {
level.Info(g.logger).Log("msg", fmt.Sprintf("unsupported '%s' value", displayNameAttr), "value", v)
g.logger.InfoContext(ctx, fmt.Sprintf("unsupported '%s' value", displayNameAttr), "value", v)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
}
if displayName == "" {
level.Info(g.logger).Log("msg", fmt.Sprintf("'%s' cannot be empty", displayNameAttr), "value", v)
g.logger.InfoContext(ctx, fmt.Sprintf("'%s' cannot be empty", displayNameAttr), "value", v)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
}
group.DisplayName = displayName
@ -453,7 +453,7 @@ func (g *GroupHandler) patchMembers(ctx context.Context, op string, v interface{
membersList = append(membersList, m)
}
default:
level.Info(g.logger).Log("msg", "unsupported members value format", "value", v)
g.logger.InfoContext(ctx, "unsupported members value format", "value", v)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
}
@ -464,20 +464,20 @@ func (g *GroupHandler) patchMembers(ctx context.Context, op string, v interface{
for _, memberIntf := range membersList {
member, ok := memberIntf.(map[string]interface{})
if !ok {
level.Info(g.logger).Log("msg", "member must be an object", "member", memberIntf)
g.logger.InfoContext(ctx, "member must be an object", "member", memberIntf)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", memberIntf)})
}
// Get the value attribute which contains the user ID
valueIntf, ok := member["value"]
if !ok || valueIntf == nil {
level.Info(g.logger).Log("msg", "member missing value attribute", "member", member)
g.logger.InfoContext(ctx, "member missing value attribute", "member", member)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", member)})
}
valueStr, ok := valueIntf.(string)
if !ok {
level.Info(g.logger).Log("msg", "member value must be a string", "value", valueIntf)
g.logger.InfoContext(ctx, "member value must be a string", "value", valueIntf)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", valueIntf)})
}
valueStrings = append(valueStrings, valueStr)
@ -485,7 +485,7 @@ func (g *GroupHandler) patchMembers(ctx context.Context, op string, v interface{
// Extract user ID from the value
userID, err := extractUserIDFromValue(valueStr)
if err != nil {
level.Info(g.logger).Log("msg", "invalid user ID format", "value", valueStr, "err", err)
g.logger.InfoContext(ctx, "invalid user ID format", "value", valueStr, "err", err)
return errors.ScimErrorBadParams([]string{valueStr})
}
@ -496,11 +496,11 @@ func (g *GroupHandler) patchMembers(ctx context.Context, op string, v interface{
if len(userIDs) > 0 {
allExist, err := g.ds.ScimUsersExist(ctx, userIDs)
if err != nil {
level.Error(g.logger).Log("msg", "error checking users existence", "err", err)
g.logger.ErrorContext(ctx, "error checking users existence", "err", err)
return err
}
if !allExist {
level.Info(g.logger).Log("msg", "one or more users not found", "userIDs", userIDs)
g.logger.InfoContext(ctx, "one or more users not found", "userIDs", userIDs)
return errors.ScimErrorBadParams(valueStrings)
}
}
@ -531,7 +531,7 @@ func (g *GroupHandler) patchMembers(ctx context.Context, op string, v interface{
// patchMembersWithPathFiltering handles patch operations with path filtering for members
// This supports paths like members[value eq "422"] for add/replace/remove operations
func (g *GroupHandler) patchMembersWithPathFiltering(ctx context.Context, op scim.PatchOperation, group *fleet.ScimGroup) error {
memberID, err := g.getMemberID(op)
memberID, err := g.getMemberID(ctx, op)
if err != nil {
return err
}
@ -550,7 +550,7 @@ func (g *GroupHandler) patchMembersWithPathFiltering(ctx context.Context, op sci
// For remove operations, remove the member if found
if op.Op == scim.PatchOperationRemove {
if !memberFound {
level.Info(g.logger).Log("msg", "member not found in group", "member_id", memberID, "op", fmt.Sprintf("%v", op))
g.logger.InfoContext(ctx, "member not found in group", "member_id", memberID, "op", fmt.Sprintf("%v", op))
// The member may have been removed already from this group. For example, if the member was deleted.
return nil
}
@ -563,11 +563,11 @@ func (g *GroupHandler) patchMembersWithPathFiltering(ctx context.Context, op sci
// Verify the user exists
userExists, err := g.ds.ScimUsersExist(ctx, []uint{memberID})
if err != nil {
level.Error(g.logger).Log("msg", "error checking user existence", "err", err)
g.logger.ErrorContext(ctx, "error checking user existence", "err", err)
return err
}
if !userExists {
level.Info(g.logger).Log("msg", "user not found", "user_id", memberID)
g.logger.InfoContext(ctx, "user not found", "user_id", memberID)
return errors.ScimErrorBadParams([]string{scimUserID(memberID)})
}
group.ScimUsers = append(group.ScimUsers, memberID)
@ -577,7 +577,9 @@ func (g *GroupHandler) patchMembersWithPathFiltering(ctx context.Context, op sci
// For replace operations with a value
if op.Op == scim.PatchOperationReplace {
if !memberFound {
level.Info(g.logger).Log("msg", "member not found for replace operation", "members.value", memberID, "op", fmt.Sprintf("%v", op))
g.logger.InfoContext(
ctx, "member not found for replace operation", "members.value", memberID, "op", fmt.Sprintf("%v", op),
)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
@ -596,29 +598,29 @@ func (g *GroupHandler) patchMembersWithPathFiltering(ctx context.Context, op sci
}
// getMemberID extracts the member ID from a path expression like members[value eq "422"]
func (g *GroupHandler) getMemberID(op scim.PatchOperation) (uint, error) {
func (g *GroupHandler) getMemberID(ctx context.Context, op scim.PatchOperation) (uint, error) {
attrExpression, ok := op.Path.ValueExpression.(*filter.AttributeExpression)
if !ok {
level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path)
g.logger.InfoContext(ctx, "unsupported patch path", "path", op.Path)
return 0, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
// Only matching by member value (user ID) is supported
if attrExpression.AttributePath.String() != valueAttr || attrExpression.Operator != filter.EQ {
level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path, "expression", attrExpression.AttributePath.String())
g.logger.InfoContext(ctx, "unsupported patch path", "path", op.Path, "expression", attrExpression.AttributePath.String())
return 0, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
memberIDStr, ok := attrExpression.CompareValue.(string)
if !ok {
level.Info(g.logger).Log("msg", "unsupported patch path", "path", op.Path, "compare_value", attrExpression.CompareValue)
g.logger.InfoContext(ctx, "unsupported patch path", "path", op.Path, "compare_value", attrExpression.CompareValue)
return 0, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
// Extract user ID from the value
userID, err := extractUserIDFromValue(memberIDStr)
if err != nil {
level.Info(g.logger).Log("msg", "invalid user ID format", "value", memberIDStr, "err", err)
g.logger.InfoContext(ctx, "invalid user ID format", "value", memberIDStr, "err", err)
return 0, errors.ScimErrorBadParams([]string{memberIDStr})
}

View file

@ -2,10 +2,12 @@ package scim
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
@ -16,11 +18,8 @@ import (
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/platform/logging"
"github.com/fleetdm/fleet/v4/server/service/middleware/auth"
"github.com/fleetdm/fleet/v4/server/service/middleware/log"
kitlog "github.com/go-kit/log"
"github.com/go-kit/log/level"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
)
@ -32,7 +31,7 @@ func RegisterSCIM(
mux *http.ServeMux,
ds fleet.Datastore,
svc fleet.Service,
logger *logging.Logger,
logger *slog.Logger,
fleetConfig *config.FleetConfig,
) error {
if fleetConfig == nil {
@ -201,7 +200,7 @@ func RegisterSCIM(
}
serverOpts := []scim.ServerOption{
scim.WithLogger(&scimErrorLogger{Logger: scimLogger}),
scim.WithLogger(&scimErrorLogger{logger: scimLogger}),
}
server, err := scim.NewServer(serverArgs, serverOpts...)
@ -223,7 +222,7 @@ func RegisterSCIM(
handler = AuthorizationMiddleware(authorizer, scimLogger, handler)
handler = auth.AuthenticatedUserMiddleware(svc, scimErrorHandler, handler)
handler = LastRequestMiddleware(ds, scimLogger, handler)
handler = log.LogResponseEndMiddleware(scimLogger.SlogLogger(), handler)
handler = log.LogResponseEndMiddleware(scimLogger, handler)
handler = auth.SetRequestsContextMiddleware(svc, handler)
return handler
}
@ -312,7 +311,7 @@ func scimOTELMiddleware(next http.Handler, prefix string, cfg config.FleetConfig
// LastRequestMiddleware saves the details of the last request to SCIM endpoints in the datastore.
// These details can be used as a debug tool by the Fleet admin to see if SCIM integration is working.
func LastRequestMiddleware(ds fleet.Datastore, logger kitlog.Logger, next http.Handler) http.Handler {
func LastRequestMiddleware(ds fleet.Datastore, logger *slog.Logger, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
multi := newMultiResponseWriter(w)
next.ServeHTTP(multi, r)
@ -323,8 +322,7 @@ func LastRequestMiddleware(ds fleet.Datastore, logger kitlog.Logger, next http.H
status = "success"
case multi.statusCode == http.StatusUnauthorized:
// We do not save unauthenticated error details; we simply log them.
level.Info(logger).Log(
"msg", "unauthenticated request",
logger.InfoContext(r.Context(), "unauthenticated request",
"origin", r.Header.Get("Origin"),
"ip", r.RemoteAddr,
"method", r.Method,
@ -350,7 +348,7 @@ func LastRequestMiddleware(ds fleet.Datastore, logger kitlog.Logger, next http.H
default:
status = "error"
details = fmt.Sprintf("Unhandled status code: %d", multi.statusCode)
level.Error(logger).Log("msg", "unhandled status code", "status", multi.statusCode, "body", multi.body.String())
logger.ErrorContext(r.Context(), "unhandled status code", "status", multi.statusCode, "body", multi.body.String())
}
if len(details) > fleet.SCIMMaxFieldLength {
details = details[:fleet.SCIMMaxFieldLength]
@ -360,12 +358,12 @@ func LastRequestMiddleware(ds fleet.Datastore, logger kitlog.Logger, next http.H
Details: details,
})
if err != nil {
level.Error(logger).Log("msg", "failed to update last scim request", "err", err)
logger.ErrorContext(r.Context(), "failed to update last scim request", "err", err)
}
})
}
func AuthorizationMiddleware(authorizer *authz.Authorizer, logger kitlog.Logger, next http.Handler) http.Handler {
func AuthorizationMiddleware(authorizer *authz.Authorizer, logger *slog.Logger, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := authorizer.Authorize(r.Context(), &fleet.ScimUser{}, fleet.ActionWrite)
if err != nil {
@ -376,14 +374,14 @@ func AuthorizationMiddleware(authorizer *authz.Authorizer, logger kitlog.Logger,
})
}
func errorHandler(w http.ResponseWriter, logger kitlog.Logger, detail string, status int) {
func errorHandler(w http.ResponseWriter, logger *slog.Logger, detail string, status int) {
scimErr := scimerrors.ScimError{
Status: status,
Detail: detail,
}
raw, err := json.Marshal(scimErr)
if err != nil {
level.Error(logger).Log("msg", "failed marshaling scim error", "scimError", scimErr, "err", err)
logger.ErrorContext(context.TODO(), "failed marshaling scim error", "scimError", scimErr, "err", err)
return
}
@ -391,20 +389,18 @@ func errorHandler(w http.ResponseWriter, logger kitlog.Logger, detail string, st
w.WriteHeader(scimErr.Status)
_, err = w.Write(raw)
if err != nil {
level.Error(logger).Log("msg", "failed writing response", "err", err)
logger.ErrorContext(context.TODO(), "failed writing response", "err", err)
}
}
type scimErrorLogger struct {
kitlog.Logger
logger *slog.Logger
}
var _ scim.Logger = &scimErrorLogger{}
func (l *scimErrorLogger) Error(args ...interface{}) {
level.Error(l.Logger).Log(
"error", fmt.Sprint(args...),
)
l.logger.ErrorContext(context.TODO(), fmt.Sprint(args...))
}
type multiResponseWriter struct {

View file

@ -3,6 +3,7 @@ package scim
import (
"context"
"fmt"
"log/slog"
"net/http"
"net/url"
"slices"
@ -18,8 +19,6 @@ import (
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/service/modules/activities"
kitlog "github.com/go-kit/log"
"github.com/go-kit/log/level"
"github.com/scim2/filter-parser/v2"
)
@ -46,32 +45,34 @@ const (
type UserHandler struct {
ds fleet.Datastore
activityModule activities.ActivityModule
logger kitlog.Logger
logger *slog.Logger
}
// Compile-time check
var _ scim.ResourceHandler = &UserHandler{}
func NewUserHandler(ds fleet.Datastore, activityModule activities.ActivityModule, logger kitlog.Logger) scim.ResourceHandler {
func NewUserHandler(ds fleet.Datastore, activityModule activities.ActivityModule, logger *slog.Logger) scim.ResourceHandler {
return &UserHandler{ds: ds, activityModule: activityModule, logger: logger}
}
func (u *UserHandler) Create(r *http.Request, attributes scim.ResourceAttributes) (scim.Resource, error) {
ctx := r.Context()
// Check for userName uniqueness
userName, err := getRequiredResource[string](attributes, userNameAttr)
if err != nil {
level.Error(u.logger).Log("msg", "failed to get userName", "err", err)
u.logger.ErrorContext(ctx, "failed to get userName", "err", err)
return scim.Resource{}, err
}
// In IETF documents, “non-empty” is generally used in the literal sense of “having at least one character.” That means if a value contains one or more spaces (and nothing else), it is still considered non-empty.
if len(userName) == 0 {
level.Info(u.logger).Log("msg", "userName is empty")
u.logger.InfoContext(ctx, "userName is empty")
return scim.Resource{}, errors.ScimErrorBadParams([]string{userNameAttr})
}
existingUser, err := u.ds.ScimUserByUserName(r.Context(), userName)
existingUser, err := u.ds.ScimUserByUserName(ctx, userName)
switch {
case err != nil && !fleet.IsNotFound(err):
level.Error(u.logger).Log("msg", "failed to check for userName uniqueness", userNameAttr, userName, "err", err)
u.logger.ErrorContext(ctx, "failed to check for userName uniqueness", userNameAttr, userName, "err", err)
return scim.Resource{}, err
case err == nil:
// User exists - check if it's a deactivated user being reactivated
@ -79,30 +80,31 @@ func (u *UserHandler) Create(r *http.Request, attributes scim.ResourceAttributes
incomingActive, _ := getOptionalResource[bool](attributes, activeAttr)
if existingUser.Active != nil && !*existingUser.Active && incomingActive != nil && *incomingActive {
// Reactivate the user by updating their record
level.Info(u.logger).Log("msg", "reactivating deactivated user", userNameAttr, userName)
user, err := u.createUserFromAttributes(attributes)
u.logger.InfoContext(ctx, "reactivating deactivated user", userNameAttr, userName)
user, err := u.createUserFromAttributes(ctx, attributes)
if err != nil {
level.Error(u.logger).Log("msg", "failed to create user from attributes for reactivation", userNameAttr, userName, "err", err)
u.logger.ErrorContext(ctx, "failed to create user from attributes for reactivation",
userNameAttr, userName, "err", err)
return scim.Resource{}, err
}
user.ID = existingUser.ID
err = u.ds.ReplaceScimUser(r.Context(), user)
err = u.ds.ReplaceScimUser(ctx, user)
if err != nil {
level.Error(u.logger).Log("msg", "failed to reactivate user", userNameAttr, userName, "err", err)
u.logger.ErrorContext(ctx, "failed to reactivate user", userNameAttr, userName, "err", err)
return scim.Resource{}, err
}
return createUserResource(user), nil
}
level.Info(u.logger).Log("msg", "user already exists", userNameAttr, userName)
u.logger.InfoContext(ctx, "user already exists", userNameAttr, userName)
return scim.Resource{}, errors.ScimErrorUniqueness
}
user, err := u.createUserFromAttributes(attributes)
user, err := u.createUserFromAttributes(ctx, attributes)
if err != nil {
level.Error(u.logger).Log("msg", "failed to create user from attributes", userNameAttr, userName, "err", err)
u.logger.ErrorContext(ctx, "failed to create user from attributes", userNameAttr, userName, "err", err)
return scim.Resource{}, err
}
user.ID, err = u.ds.CreateScimUser(r.Context(), user)
user.ID, err = u.ds.CreateScimUser(ctx, user)
if err != nil {
return scim.Resource{}, err
}
@ -110,7 +112,9 @@ func (u *UserHandler) Create(r *http.Request, attributes scim.ResourceAttributes
return createUserResource(user), nil
}
func (u *UserHandler) createUserFromAttributes(attributes scim.ResourceAttributes) (*fleet.ScimUser, error) {
func (u *UserHandler) createUserFromAttributes(
ctx context.Context, attributes scim.ResourceAttributes,
) (*fleet.ScimUser, error) {
user := fleet.ScimUser{}
var err error
user.UserName, err = getRequiredResource[string](attributes, userNameAttr)
@ -174,7 +178,7 @@ func (u *UserHandler) createUserFromAttributes(attributes scim.ResourceAttribute
user.Emails = userEmails
// Attempt to get extension enterprise user attributes.
extendedAttributes := u.getExtensionEnterpriseUserAttributes(user.UserName, attributes)
extendedAttributes := u.getExtensionEnterpriseUserAttributes(ctx, user.UserName, attributes)
user.Department = extendedAttributes.department
return &user, nil
@ -184,7 +188,9 @@ type extendedAttributes struct {
department *string
}
func (u *UserHandler) getExtensionEnterpriseUserAttributes(userName string, attributes scim.ResourceAttributes) extendedAttributes {
func (u *UserHandler) getExtensionEnterpriseUserAttributes(
ctx context.Context, userName string, attributes scim.ResourceAttributes,
) extendedAttributes {
var attrs extendedAttributes
m_, ok := attributes[extensionEnterpriseUserAttributes]
if !ok {
@ -192,8 +198,8 @@ func (u *UserHandler) getExtensionEnterpriseUserAttributes(userName string, attr
}
m, ok := m_.(map[string]any)
if !ok {
level.Error(u.logger).Log(
"msg", fmt.Sprintf("unexpected type for %s: %T", extensionEnterpriseUserAttributes, m_),
u.logger.ErrorContext(ctx,
fmt.Sprintf("unexpected type for %s: %T", extensionEnterpriseUserAttributes, m_),
userNameAttr, userName,
)
return attrs
@ -204,8 +210,8 @@ func (u *UserHandler) getExtensionEnterpriseUserAttributes(userName string, attr
if department, ok := department_.(string); ok {
attrs.department = &department
} else {
level.Error(u.logger).Log(
"msg", fmt.Sprintf("unexpected type for %s.department: %T", extensionEnterpriseUserAttributes, department_),
u.logger.ErrorContext(ctx,
fmt.Sprintf("unexpected type for %s.department: %T", extensionEnterpriseUserAttributes, department_),
userNameAttr, userName,
)
}
@ -275,19 +281,21 @@ func getComplexResourceSlice(attributes scim.ResourceAttributes, key string) ([]
}
func (u *UserHandler) Get(r *http.Request, id string) (scim.Resource, error) {
ctx := r.Context()
idUint, err := extractUserIDFromValue(id)
if err != nil {
level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
u.logger.InfoContext(ctx, "failed to parse id", "id", id, "err", err)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
}
user, err := u.ds.ScimUserByID(r.Context(), idUint)
user, err := u.ds.ScimUserByID(ctx, idUint)
switch {
case fleet.IsNotFound(err):
level.Info(u.logger).Log("msg", "failed to find user", "id", id)
u.logger.InfoContext(ctx, "failed to find user", "id", id)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
case err != nil:
level.Error(u.logger).Log("msg", "failed to get user", "id", id, "err", err)
u.logger.ErrorContext(ctx, "failed to get user", "id", id, "err", err)
return scim.Resource{}, err
}
@ -365,6 +373,8 @@ func createUserResource(user *fleet.ScimUser) scim.Resource {
// totalResults: The total number of results returned by the list or query operation. The value may be larger than the number of
// resources returned, such as when returning a single page (see Section 3.4.2.4) of results where multiple pages are available.
func (u *UserHandler) GetAll(r *http.Request, params scim.ListRequestParams) (scim.Page, error) {
ctx := r.Context()
startIndex := params.StartIndex
if startIndex < 1 {
startIndex = 1
@ -387,30 +397,30 @@ func (u *UserHandler) GetAll(r *http.Request, params scim.ListRequestParams) (sc
if resourceFilter != "" {
expr, err := filter.ParseAttrExp([]byte(resourceFilter))
if err != nil {
level.Error(u.logger).Log("msg", "failed to parse filter", "filter", resourceFilter, "err", err)
u.logger.ErrorContext(ctx, "failed to parse filter", "filter", resourceFilter, "err", err)
return scim.Page{}, errors.ScimErrorInvalidFilter
}
if !strings.EqualFold(expr.AttributePath.String(), "userName") || expr.Operator != "eq" {
level.Info(u.logger).Log("msg", "unsupported filter", "filter", resourceFilter)
u.logger.InfoContext(ctx, "unsupported filter", "filter", resourceFilter)
return scim.Page{}, nil
}
userName, ok := expr.CompareValue.(string)
if !ok {
level.Error(u.logger).Log("msg", "unsupported value", "value", expr.CompareValue)
u.logger.ErrorContext(ctx, "unsupported value", "value", expr.CompareValue)
return scim.Page{}, nil
}
// Decode URL-encoded characters in userName, which is required to pass Microsoft Entra ID SCIM Validator
userName, err = url.QueryUnescape(userName)
if err != nil {
level.Error(u.logger).Log("msg", "failed to decode userName", "userName", userName, "err", err)
u.logger.ErrorContext(ctx, "failed to decode userName", "userName", userName, "err", err)
return scim.Page{}, nil
}
opts.UserNameFilter = &userName
}
users, totalResults, err := u.ds.ListScimUsers(r.Context(), opts)
users, totalResults, err := u.ds.ListScimUsers(ctx, opts)
if err != nil {
level.Error(u.logger).Log("msg", "failed to list users", "err", err)
u.logger.ErrorContext(ctx, "failed to list users", "err", err)
return scim.Page{}, err
}
@ -430,13 +440,13 @@ func (u *UserHandler) Replace(r *http.Request, id string, attributes scim.Resour
idUint, err := extractUserIDFromValue(id)
if err != nil {
level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
u.logger.InfoContext(ctx, "failed to parse id", "id", id, "err", err)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
}
user, err := u.createUserFromAttributes(attributes)
user, err := u.createUserFromAttributes(ctx, attributes)
if err != nil {
level.Error(u.logger).Log("msg", "failed to create user from attributes", "id", id, "err", err)
u.logger.ErrorContext(ctx, "failed to create user from attributes", "id", id, "err", err)
return scim.Resource{}, err
}
user.ID = idUint
@ -447,10 +457,10 @@ func (u *UserHandler) Replace(r *http.Request, id string, attributes scim.Resour
userWithSameUsername, err := u.ds.ScimUserByUserName(ctx, user.UserName)
switch {
case err != nil && !fleet.IsNotFound(err):
level.Error(u.logger).Log("msg", "failed to check for userName uniqueness", userNameAttr, user.UserName, "err", err)
u.logger.ErrorContext(ctx, "failed to check for userName uniqueness", userNameAttr, user.UserName, "err", err)
return scim.Resource{}, err
case err == nil && user.ID != userWithSameUsername.ID:
level.Info(u.logger).Log("msg", "user already exists with this username", userNameAttr, user.UserName)
u.logger.InfoContext(ctx, "user already exists with this username", userNameAttr, user.UserName)
return scim.Resource{}, errors.ScimErrorUniqueness
case err == nil && user.ID == userWithSameUsername.ID:
// Same user, username not changing - use this for previous active state
@ -459,11 +469,11 @@ func (u *UserHandler) Replace(r *http.Request, id string, attributes scim.Resour
// Username is being changed - need to fetch existing user by ID for previous active state
existingUser, err := u.ds.ScimUserByID(ctx, idUint)
if fleet.IsNotFound(err) {
level.Info(u.logger).Log("msg", "failed to find scim user by id", "id", id)
u.logger.InfoContext(ctx, "failed to find scim user by id", "id", id)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
}
if err != nil {
level.Error(u.logger).Log("msg", "failed to get existing scim user by id", "id", id, "err", err)
u.logger.ErrorContext(ctx, "failed to get existing scim user by id", "id", id, "err", err)
return scim.Resource{}, err
}
previousActive = existingUser.Active
@ -472,17 +482,17 @@ func (u *UserHandler) Replace(r *http.Request, id string, attributes scim.Resour
err = u.ds.ReplaceScimUser(ctx, user)
switch {
case fleet.IsNotFound(err):
level.Info(u.logger).Log("msg", "failed to find user to replace", "id", id)
u.logger.InfoContext(ctx, "failed to find user to replace", "id", id)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
case err != nil:
level.Error(u.logger).Log("msg", "failed to replace user", "id", id, "err", err)
u.logger.ErrorContext(ctx, "failed to replace user", "id", id, "err", err)
return scim.Resource{}, err
}
// Check if user was deactivated and delete matching Fleet user if so
if wasDeactivated(previousActive, user.Active) {
if err := u.deleteMatchingFleetUser(ctx, user); err != nil {
level.Error(u.logger).Log("msg", "failed to delete fleet user on deactivation", "err", err)
u.logger.ErrorContext(ctx, "failed to delete fleet user on deactivation", "err", err)
}
}
@ -497,33 +507,33 @@ func (u *UserHandler) Delete(r *http.Request, id string) error {
idUint, err := extractUserIDFromValue(id)
if err != nil {
level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
u.logger.InfoContext(ctx, "failed to parse id", "id", id, "err", err)
return errors.ScimErrorResourceNotFound(id)
}
scimUser, err := u.ds.ScimUserByID(ctx, idUint)
if fleet.IsNotFound(err) {
// proceed with DeleteScimUser call which calls triggerResendProfilesForIDPUserDeleted even before checking if the user exists
level.Warn(u.logger).Log("msg", "scim user not found", "id", id)
u.logger.WarnContext(ctx, "scim user not found", "id", id)
} else if err != nil {
level.Error(u.logger).Log("msg", "failed to get scim user", "id", id, "err", err)
u.logger.ErrorContext(ctx, "failed to get scim user", "id", id, "err", err)
return err
}
if scimUser != nil {
if err := u.deleteMatchingFleetUser(ctx, scimUser); err != nil {
// Log but don't fail - SCIM deletion should still proceed
level.Error(u.logger).Log("msg", "failed to delete matching fleet user", "err", err)
u.logger.ErrorContext(ctx, "failed to delete matching fleet user", "err", err)
}
}
err = u.ds.DeleteScimUser(ctx, idUint)
switch {
case fleet.IsNotFound(err):
level.Info(u.logger).Log("msg", "failed to find user to delete", "id", id)
u.logger.InfoContext(ctx, "failed to find user to delete", "id", id)
return errors.ScimErrorResourceNotFound(id)
case err != nil:
level.Error(u.logger).Log("msg", "failed to delete user", "id", id, "err", err)
u.logger.ErrorContext(ctx, "failed to delete user", "id", id, "err", err)
return err
}
@ -556,7 +566,7 @@ func (u *UserHandler) deleteMatchingFleetUser(ctx context.Context, scimUser *fle
emails = server.RemoveDuplicatesFromSlice(emails)
if len(emails) == 0 {
level.Debug(u.logger).Log("msg", "no emails found for scim user",
u.logger.DebugContext(ctx, "no emails found for scim user",
"scim_user_id", scimUser.ID, "user_name", scimUser.UserName)
return nil
}
@ -574,14 +584,14 @@ func (u *UserHandler) deleteMatchingFleetUser(ctx context.Context, scimUser *fle
}
if fleetUser == nil {
level.Debug(u.logger).Log("msg", "no matching fleet user found for scim user",
u.logger.DebugContext(ctx, "no matching fleet user found for scim user",
"scim_user_id", scimUser.ID, "user_name", scimUser.UserName)
return nil
}
// Skip API-only users or non-SSO users
if fleetUser.APIOnly || !fleetUser.SSOEnabled {
level.Info(u.logger).Log("msg", "skipping deletion of API-only or non-SSO user",
u.logger.InfoContext(ctx, "skipping deletion of API-only or non-SSO user",
"user_id", fleetUser.ID, "email", fleetUser.Email)
return nil
}
@ -594,13 +604,13 @@ func (u *UserHandler) deleteMatchingFleetUser(ctx context.Context, scimUser *fle
}
if count <= 1 {
level.Warn(u.logger).Log("msg", "cannot delete last global admin via SCIM",
u.logger.WarnContext(ctx, "cannot delete last global admin via SCIM",
"user_id", fleetUser.ID, "email", fleetUser.Email)
return ctxerr.New(ctx, "cannot delete last global admin")
}
}
level.Info(u.logger).Log("msg", "deleting fleet user via SCIM deletion",
u.logger.InfoContext(ctx, "deleting fleet user via SCIM deletion",
"user_id", fleetUser.ID, "email", fleetUser.Email)
// TODO: Ideally this should go through a Users service/module instead of directly accessing
@ -620,7 +630,7 @@ func (u *UserHandler) deleteMatchingFleetUser(ctx context.Context, scimUser *fle
FromScimUserDeletion: true,
},
); err != nil {
level.Error(u.logger).Log("msg", "failed to create activity for fleet user deletion", "err", err)
u.logger.ErrorContext(ctx, "failed to create activity for fleet user deletion", "err", err)
}
return nil
@ -632,16 +642,16 @@ func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchO
idUint, err := extractUserIDFromValue(id)
if err != nil {
level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
u.logger.InfoContext(ctx, "failed to parse id", "id", id, "err", err)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
}
user, err := u.ds.ScimUserByID(ctx, idUint)
switch {
case fleet.IsNotFound(err):
level.Info(u.logger).Log("msg", "failed to find user to patch", "id", id)
u.logger.InfoContext(ctx, "failed to find user to patch", "id", id)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
case err != nil:
level.Error(u.logger).Log("msg", "failed to get user to patch", "id", id, "err", err)
u.logger.ErrorContext(ctx, "failed to get user to patch", "id", id, "err", err)
return scim.Resource{}, err
}
@ -650,115 +660,115 @@ func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchO
for _, op := range operations {
if op.Op != scim.PatchOperationAdd && op.Op != scim.PatchOperationReplace && op.Op != scim.PatchOperationRemove {
level.Info(u.logger).Log("msg", "unsupported patch operation", "op", op.Op)
u.logger.InfoContext(ctx, "unsupported patch operation", "op", op.Op)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
switch {
// If path is not specified, we look for the path in the value attribute.
case op.Path == nil:
if op.Op == scim.PatchOperationRemove {
level.Info(u.logger).Log("msg", "the 'path' attribute is REQUIRED for 'remove' operations", "op", op.Op)
u.logger.InfoContext(ctx, "the 'path' attribute is REQUIRED for 'remove' operations", "op", op.Op)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
newValues, ok := op.Value.(map[string]interface{})
if !ok {
level.Info(u.logger).Log("msg", "unsupported patch value", "value", op.Value)
u.logger.InfoContext(ctx, "unsupported patch value", "value", op.Value)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
for k, v := range newValues {
switch k {
case externalIdAttr:
err = u.patchExternalId(op.Op, v, user)
err = u.patchExternalId(ctx, op.Op, v, user)
if err != nil {
return scim.Resource{}, err
}
case userNameAttr:
err = u.patchUserName(op.Op, v, user)
err = u.patchUserName(ctx, op.Op, v, user)
if err != nil {
return scim.Resource{}, err
}
case activeAttr:
err = u.patchActive(op.Op, v, user)
err = u.patchActive(ctx, op.Op, v, user)
if err != nil {
return scim.Resource{}, err
}
case nameAttr + "." + givenNameAttr:
err = u.patchGivenName(op.Op, v, user)
err = u.patchGivenName(ctx, op.Op, v, user)
if err != nil {
return scim.Resource{}, err
}
case nameAttr + "." + familyNameAttr:
err = u.patchFamilyName(op.Op, v, user)
err = u.patchFamilyName(ctx, op.Op, v, user)
if err != nil {
return scim.Resource{}, err
}
case nameAttr:
err = u.patchName(v, op, user)
err = u.patchName(ctx, v, op, user)
if err != nil {
return scim.Resource{}, err
}
case emailsAttr:
err = u.patchEmails(v, op, user)
err = u.patchEmails(ctx, v, op, user)
if err != nil {
return scim.Resource{}, err
}
case extensionEnterpriseUserAttributes + ":" + departmentAttr:
err = u.patchDepartment(op.Op, v, user)
err = u.patchDepartment(ctx, op.Op, v, user)
if err != nil {
return scim.Resource{}, err
}
default:
level.Info(u.logger).Log("msg", "unsupported patch value field", "field", k)
u.logger.InfoContext(ctx, "unsupported patch value field", "field", k)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
}
case op.Path.String() == externalIdAttr:
err = u.patchExternalId(op.Op, op.Value, user)
err = u.patchExternalId(ctx, op.Op, op.Value, user)
if err != nil {
return scim.Resource{}, err
}
case op.Path.String() == userNameAttr:
err = u.patchUserName(op.Op, op.Value, user)
err = u.patchUserName(ctx, op.Op, op.Value, user)
if err != nil {
return scim.Resource{}, err
}
case op.Path.String() == activeAttr:
err = u.patchActive(op.Op, op.Value, user)
err = u.patchActive(ctx, op.Op, op.Value, user)
if err != nil {
return scim.Resource{}, err
}
case op.Path.String() == nameAttr+"."+givenNameAttr:
err = u.patchGivenName(op.Op, op.Value, user)
err = u.patchGivenName(ctx, op.Op, op.Value, user)
if err != nil {
return scim.Resource{}, err
}
case op.Path.String() == nameAttr+"."+familyNameAttr:
err = u.patchFamilyName(op.Op, op.Value, user)
err = u.patchFamilyName(ctx, op.Op, op.Value, user)
if err != nil {
return scim.Resource{}, err
}
case op.Path.String() == nameAttr:
err = u.patchName(op.Value, op, user)
err = u.patchName(ctx, op.Value, op, user)
if err != nil {
return scim.Resource{}, err
}
case op.Path.String() == emailsAttr:
err = u.patchEmails(op.Value, op, user)
err = u.patchEmails(ctx, op.Value, op, user)
if err != nil {
return scim.Resource{}, err
}
case op.Path.AttributePath.String() == emailsAttr:
err = u.patchEmailsWithPathFiltering(op, user)
err = u.patchEmailsWithPathFiltering(ctx, op, user)
if err != nil {
return scim.Resource{}, err
}
case op.Path.AttributePath.String() == extensionEnterpriseUserAttributes+":"+departmentAttr:
err = u.patchDepartment(op.Op, op.Value, user)
err = u.patchDepartment(ctx, op.Op, op.Value, user)
if err != nil {
return scim.Resource{}, err
}
default:
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path)
u.logger.InfoContext(ctx, "unsupported patch path", "path", op.Path)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
}
@ -767,17 +777,17 @@ func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchO
err = u.ds.ReplaceScimUser(ctx, user)
switch {
case fleet.IsNotFound(err):
level.Info(u.logger).Log("msg", "failed to find user to patch", "id", id)
u.logger.InfoContext(ctx, "failed to find user to patch", "id", id)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
case err != nil:
level.Error(u.logger).Log("msg", "failed to patch user", "id", id, "err", err)
u.logger.ErrorContext(ctx, "failed to patch user", "id", id, "err", err)
return scim.Resource{}, err
}
// Check if user was deactivated and delete matching Fleet user if so
if wasDeactivated(previousActive, user.Active) {
if err := u.deleteMatchingFleetUser(ctx, user); err != nil {
level.Error(u.logger).Log("msg", "failed to delete fleet user on deactivation", "err", err)
u.logger.ErrorContext(ctx, "failed to delete fleet user on deactivation", "err", err)
}
}
}
@ -785,8 +795,10 @@ func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchO
return createUserResource(user), nil
}
func (u *UserHandler) patchEmailsWithPathFiltering(op scim.PatchOperation, user *fleet.ScimUser) error {
emailType, err := u.getEmailType(op)
func (u *UserHandler) patchEmailsWithPathFiltering(
ctx context.Context, op scim.PatchOperation, user *fleet.ScimUser,
) error {
emailType, err := u.getEmailType(ctx, op)
if err != nil {
return err
}
@ -800,7 +812,7 @@ func (u *UserHandler) patchEmailsWithPathFiltering(op scim.PatchOperation, user
}
}
if !emailFound && op.Op != scim.PatchOperationAdd {
level.Info(u.logger).Log("msg", "email not found", "email_type", emailType, "op", fmt.Sprintf("%v", op))
u.logger.InfoContext(ctx, "email not found", "email_type", emailType, "op", fmt.Sprintf("%v", op))
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
if op.Path.SubAttribute == nil {
@ -820,7 +832,7 @@ func (u *UserHandler) patchEmailsWithPathFiltering(op scim.PatchOperation, user
// Single member as a map
emailsList = []interface{}{val}
default:
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", emailsAttr), "value", op.Value)
u.logger.InfoContext(ctx, fmt.Sprintf("unsupported '%s' patch value", emailsAttr), "value", op.Value)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
@ -831,10 +843,10 @@ func (u *UserHandler) patchEmailsWithPathFiltering(op scim.PatchOperation, user
return nil
}
if len(emailsList) != 1 {
level.Info(u.logger).Log("msg", "only 1 email should be present for replacement", "emails", emailsList)
u.logger.InfoContext(ctx, "only 1 email should be present for replacement", "emails", emailsList)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
userEmail, err := u.extractEmail(emailsList[0], op)
userEmail, err := u.extractEmail(ctx, emailsList[0], op)
if err != nil {
return err
}
@ -845,19 +857,19 @@ func (u *UserHandler) patchEmailsWithPathFiltering(op scim.PatchOperation, user
user.Emails[emailIndex] = userEmail
case scim.PatchOperationAdd:
if len(emailsList) == 0 {
level.Info(u.logger).Log("msg", "no emails provided to add", "emails", emailsList)
u.logger.InfoContext(ctx, "no emails provided to add", "emails", emailsList)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
var newEmails []fleet.ScimUserEmail
for e := range emailsList {
userEmail, err := u.extractEmail(emailsList[e], op)
userEmail, err := u.extractEmail(ctx, emailsList[e], op)
if err != nil {
return err
}
userEmail.Type = &emailType
newEmails = append(newEmails, userEmail)
}
primaryExists, err := u.checkEmailPrimary(newEmails)
primaryExists, err := u.checkEmailPrimary(ctx, newEmails)
if err != nil {
return err
}
@ -884,7 +896,7 @@ func (u *UserHandler) patchEmailsWithPathFiltering(op scim.PatchOperation, user
user.Emails[emailIndex].Primary = nil
return nil
}
primary, err := getConcreteType[bool](u, op.Value, primaryAttr)
primary, err := getConcreteType[bool](ctx, u, op.Value, primaryAttr)
if err != nil {
return err
}
@ -899,7 +911,7 @@ func (u *UserHandler) patchEmailsWithPathFiltering(op scim.PatchOperation, user
user.Emails[emailIndex].Email = ""
return nil
}
value, err := getConcreteType[string](u, op.Value, valueAttr)
value, err := getConcreteType[string](ctx, u, op.Value, valueAttr)
if err != nil {
return err
}
@ -913,53 +925,55 @@ func (u *UserHandler) patchEmailsWithPathFiltering(op scim.PatchOperation, user
user.Emails[emailIndex].Type = nil
return nil
}
newEmailType, err := getConcreteType[string](u, op.Value, typeAttr)
newEmailType, err := getConcreteType[string](ctx, u, op.Value, typeAttr)
if err != nil {
return err
}
user.Emails[emailIndex].Type = &newEmailType
default:
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path)
u.logger.InfoContext(ctx, "unsupported patch path", "path", op.Path)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
return nil
}
func (u *UserHandler) getEmailType(op scim.PatchOperation) (string, error) {
func (u *UserHandler) getEmailType(ctx context.Context, op scim.PatchOperation) (string, error) {
attrExpression, ok := op.Path.ValueExpression.(*filter.AttributeExpression)
if !ok {
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path)
u.logger.InfoContext(ctx, "unsupported patch path", "path", op.Path)
return "", errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
// Only matching by email type (work, etc.) is supported.
if attrExpression.AttributePath.String() != typeAttr || attrExpression.Operator != filter.EQ {
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path, "expression", attrExpression.AttributePath.String())
u.logger.InfoContext(ctx, "unsupported patch path",
"path", op.Path, "expression", attrExpression.AttributePath.String())
return "", errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
emailType, ok := attrExpression.CompareValue.(string)
if !ok {
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path, "compare_value", attrExpression.CompareValue)
u.logger.InfoContext(ctx, "unsupported patch path",
"path", op.Path, "compare_value", attrExpression.CompareValue)
return "", errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
return emailType, nil
}
func getConcreteType[T string | bool](u *UserHandler, v interface{}, name string) (T, error) {
func getConcreteType[T string | bool](ctx context.Context, u *UserHandler, v any, name string) (T, error) {
concreteType, ok := v.(T)
if !ok {
var zeroValue T
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' value", name), "value", v)
u.logger.InfoContext(ctx, fmt.Sprintf("unsupported '%s' value", name), "value", v)
return zeroValue, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
}
return concreteType, nil
}
func (u *UserHandler) patchFamilyName(op string, v interface{}, user *fleet.ScimUser) error {
func (u *UserHandler) patchFamilyName(ctx context.Context, op string, v any, user *fleet.ScimUser) error {
if op == scim.PatchOperationRemove {
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", nameAttr+"."+familyNameAttr)
u.logger.InfoContext(ctx, "cannot remove required attribute", "attribute", nameAttr+"."+familyNameAttr)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
familyName, err := getConcreteType[string](u, v, nameAttr+"."+familyNameAttr)
familyName, err := getConcreteType[string](ctx, u, v, nameAttr+"."+familyNameAttr)
if err != nil {
return err
}
@ -967,12 +981,12 @@ func (u *UserHandler) patchFamilyName(op string, v interface{}, user *fleet.Scim
return nil
}
func (u *UserHandler) patchGivenName(op string, v interface{}, user *fleet.ScimUser) error {
func (u *UserHandler) patchGivenName(ctx context.Context, op string, v any, user *fleet.ScimUser) error {
if op == scim.PatchOperationRemove {
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", nameAttr+"."+givenNameAttr)
u.logger.InfoContext(ctx, "cannot remove required attribute", "attribute", nameAttr+"."+givenNameAttr)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
givenName, err := getConcreteType[string](u, v, nameAttr+"."+givenNameAttr)
givenName, err := getConcreteType[string](ctx, u, v, nameAttr+"."+givenNameAttr)
if err != nil {
return err
}
@ -980,12 +994,12 @@ func (u *UserHandler) patchGivenName(op string, v interface{}, user *fleet.ScimU
return nil
}
func (u *UserHandler) patchActive(op string, v interface{}, user *fleet.ScimUser) error {
func (u *UserHandler) patchActive(ctx context.Context, op string, v any, user *fleet.ScimUser) error {
if op == scim.PatchOperationRemove || v == nil {
user.Active = nil
return nil
}
active, err := getConcreteType[bool](u, v, activeAttr)
active, err := getConcreteType[bool](ctx, u, v, activeAttr)
if err != nil {
return err
}
@ -993,12 +1007,12 @@ func (u *UserHandler) patchActive(op string, v interface{}, user *fleet.ScimUser
return nil
}
func (u *UserHandler) patchExternalId(op string, v interface{}, user *fleet.ScimUser) error {
func (u *UserHandler) patchExternalId(ctx context.Context, op string, v any, user *fleet.ScimUser) error {
if op == scim.PatchOperationRemove || v == nil {
user.ExternalID = nil
return nil
}
externalId, err := getConcreteType[string](u, v, externalIdAttr)
externalId, err := getConcreteType[string](ctx, u, v, externalIdAttr)
if err != nil {
return err
}
@ -1006,29 +1020,29 @@ func (u *UserHandler) patchExternalId(op string, v interface{}, user *fleet.Scim
return nil
}
func (u *UserHandler) patchUserName(op string, v interface{}, user *fleet.ScimUser) error {
func (u *UserHandler) patchUserName(ctx context.Context, op string, v any, user *fleet.ScimUser) error {
if op == scim.PatchOperationRemove {
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", userNameAttr)
u.logger.InfoContext(ctx, "cannot remove required attribute", "attribute", userNameAttr)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
userName, err := getConcreteType[string](u, v, userNameAttr)
userName, err := getConcreteType[string](ctx, u, v, userNameAttr)
if err != nil {
return err
}
if userName == "" {
level.Info(u.logger).Log("msg", fmt.Sprintf("'%s' cannot be empty", userNameAttr), "value", v)
u.logger.InfoContext(ctx, fmt.Sprintf("'%s' cannot be empty", userNameAttr), "value", v)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", v)})
}
user.UserName = userName
return nil
}
func (u *UserHandler) patchDepartment(op string, v interface{}, user *fleet.ScimUser) error {
func (u *UserHandler) patchDepartment(ctx context.Context, op string, v any, user *fleet.ScimUser) error {
if op == scim.PatchOperationRemove || v == nil {
user.Department = nil
return nil
}
department, err := getConcreteType[string](u, v, departmentAttr)
department, err := getConcreteType[string](ctx, u, v, departmentAttr)
if err != nil {
return err
}
@ -1044,7 +1058,9 @@ func clearPrimaryFlagFromEmails(user *fleet.ScimUser) {
}
}
func (u *UserHandler) patchEmails(v interface{}, op scim.PatchOperation, user *fleet.ScimUser) error {
func (u *UserHandler) patchEmails(
ctx context.Context, v any, op scim.PatchOperation, user *fleet.ScimUser,
) error {
if op.Op == scim.PatchOperationRemove {
user.Emails = nil
return nil
@ -1061,24 +1077,24 @@ func (u *UserHandler) patchEmails(v interface{}, op scim.PatchOperation, user *f
// Single member as a map
emailsList = []interface{}{val}
default:
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", emailsAttr), "value", op.Value)
u.logger.InfoContext(ctx, fmt.Sprintf("unsupported '%s' patch value", emailsAttr), "value", op.Value)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
if op.Op == scim.PatchOperationAdd && len(emailsList) == 0 {
level.Info(u.logger).Log("msg", "no emails provided to add", "emails", emailsList)
u.logger.InfoContext(ctx, "no emails provided to add", "emails", emailsList)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
// Convert the emails to the expected format
userEmails := make([]fleet.ScimUserEmail, 0, len(emailsList))
for _, emailIntf := range emailsList {
userEmail, err := u.extractEmail(emailIntf, op)
userEmail, err := u.extractEmail(ctx, emailIntf, op)
if err != nil {
return err
}
userEmails = append(userEmails, userEmail)
}
primaryExists, err := u.checkEmailPrimary(userEmails)
primaryExists, err := u.checkEmailPrimary(ctx, userEmails)
if err != nil {
return err
}
@ -1096,13 +1112,13 @@ func (u *UserHandler) patchEmails(v interface{}, op scim.PatchOperation, user *f
}
// checkEmailPrimary ensures at most one email is marked as primary
func (u *UserHandler) checkEmailPrimary(userEmails []fleet.ScimUserEmail) (bool, error) {
func (u *UserHandler) checkEmailPrimary(ctx context.Context, userEmails []fleet.ScimUserEmail) (bool, error) {
primaryEmailCount := 0
for _, email := range userEmails {
if email.Primary != nil && *email.Primary {
primaryEmailCount++
if primaryEmailCount > 1 {
level.Info(u.logger).Log("msg", "multiple primary emails found")
u.logger.InfoContext(ctx, "multiple primary emails found")
return false, errors.ScimErrorBadParams([]string{"Only one email can be marked as primary"})
}
}
@ -1110,24 +1126,26 @@ func (u *UserHandler) checkEmailPrimary(userEmails []fleet.ScimUserEmail) (bool,
return primaryEmailCount > 0, nil
}
func (u *UserHandler) extractEmail(emailIntf interface{}, op scim.PatchOperation) (fleet.ScimUserEmail, error) {
func (u *UserHandler) extractEmail(
ctx context.Context, emailIntf any, op scim.PatchOperation,
) (fleet.ScimUserEmail, error) {
emailMap, ok := emailIntf.(map[string]interface{})
if !ok {
level.Info(u.logger).Log("msg", "email is not a map", "email", emailIntf)
u.logger.InfoContext(ctx, "email is not a map", "email", emailIntf)
return fleet.ScimUserEmail{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
// Extract the email value (required)
emailValue, ok := emailMap[valueAttr].(string)
if !ok || emailValue == "" {
level.Info(u.logger).Log("msg", "email value is missing or invalid", "email", emailMap)
u.logger.InfoContext(ctx, "email value is missing or invalid", "email", emailMap)
return fleet.ScimUserEmail{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
// Normalize the email
normalizedEmail, err := normalizeEmail(emailValue)
if err != nil {
level.Info(u.logger).Log("msg", "failed to normalize email", "email", emailValue, "err", err)
u.logger.InfoContext(ctx, "failed to normalize email", "email", emailValue, "err", err)
return fleet.ScimUserEmail{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
@ -1150,14 +1168,14 @@ func (u *UserHandler) extractEmail(emailIntf interface{}, op scim.PatchOperation
return userEmail, nil
}
func (u *UserHandler) patchName(v interface{}, op scim.PatchOperation, user *fleet.ScimUser) error {
func (u *UserHandler) patchName(ctx context.Context, v any, op scim.PatchOperation, user *fleet.ScimUser) error {
if op.Op == scim.PatchOperationRemove {
level.Info(u.logger).Log("msg", "cannot remove required attribute", "attribute", nameAttr)
u.logger.InfoContext(ctx, "cannot remove required attribute", "attribute", nameAttr)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
name, ok := v.(map[string]interface{})
if !ok {
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", nameAttr), "value", op.Value)
u.logger.InfoContext(ctx, fmt.Sprintf("unsupported '%s' patch value", nameAttr), "value", op.Value)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
for nameKey, nameValue := range name {
@ -1165,21 +1183,21 @@ func (u *UserHandler) patchName(v interface{}, op scim.PatchOperation, user *fle
case givenNameAttr:
givenName, ok := nameValue.(string)
if !ok {
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", nameAttr+"."+givenNameAttr), "value",
op.Value)
u.logger.InfoContext(ctx,
fmt.Sprintf("unsupported '%s' patch value", nameAttr+"."+givenNameAttr), "value", op.Value)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
user.GivenName = &givenName
case familyNameAttr:
familyName, ok := nameValue.(string)
if !ok {
level.Info(u.logger).Log("msg", fmt.Sprintf("unsupported '%s' patch value", nameAttr+"."+familyNameAttr), "value",
op.Value)
u.logger.InfoContext(ctx,
fmt.Sprintf("unsupported '%s' patch value", nameAttr+"."+familyNameAttr), "value", op.Value)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
user.FamilyName = &familyName
default:
level.Info(u.logger).Log("msg", "unsupported patch value field", "field", nameAttr+"."+nameKey)
u.logger.InfoContext(ctx, "unsupported patch value field", "field", nameAttr+"."+nameKey)
return errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
}

View file

@ -2,6 +2,7 @@ package scim
import (
"context"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
@ -12,7 +13,6 @@ import (
mockservice "github.com/fleetdm/fleet/v4/server/mock/service"
platform_mysql "github.com/fleetdm/fleet/v4/server/platform/mysql"
"github.com/fleetdm/fleet/v4/server/ptr"
kitlog "github.com/go-kit/log"
"github.com/scim2/filter-parser/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -34,7 +34,7 @@ func (m *testMocks) newTestHandler() *UserHandler {
return &UserHandler{
ds: m.ds,
activityModule: m.svc,
logger: kitlog.NewNopLogger(),
logger: slog.New(slog.DiscardHandler),
}
}

View file

@ -14,6 +14,7 @@ import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strconv"
"strings"
@ -23,8 +24,6 @@ import (
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/datastore/redis"
"github.com/fleetdm/fleet/v4/server/fleet"
kitlog "github.com/go-kit/log"
"github.com/go-kit/log/level"
redigo "github.com/gomodule/redigo/redis"
)
@ -33,7 +32,7 @@ import (
// from the store. It is safe to call those methods concurrently.
type Handler struct {
pool fleet.RedisPool
logger kitlog.Logger
logger *slog.Logger
ttl time.Duration
running int32 // accessed atomically
errCh chan error
@ -47,7 +46,7 @@ type Handler struct {
// NewHandler creates an error handler using the provided pool and logger,
// storing unique instances of errors in Redis using the pool. It stops storing
// errors when ctx is cancelled. Errors are kept for the duration of ttl.
func NewHandler(ctx context.Context, pool fleet.RedisPool, logger kitlog.Logger, ttl time.Duration) *Handler {
func NewHandler(ctx context.Context, pool fleet.RedisPool, logger *slog.Logger, ttl time.Duration) *Handler {
eh := &Handler{
pool: pool,
logger: logger,
@ -60,7 +59,7 @@ func NewHandler(ctx context.Context, pool fleet.RedisPool, logger kitlog.Logger,
return eh
}
func newTestHandler(ctx context.Context, pool fleet.RedisPool, logger kitlog.Logger, ttl time.Duration, onStart func(), onStore func(error)) *Handler {
func newTestHandler(ctx context.Context, pool fleet.RedisPool, logger *slog.Logger, ttl time.Duration, onStart func(), onStore func(error)) *Handler {
eh := &Handler{
pool: pool,
logger: logger,
@ -202,7 +201,7 @@ func (h *Handler) handleErrors(ctx context.Context) {
func (h *Handler) storeError(ctx context.Context, err error) {
errorHash, errorJson, err := hashAndMarshalError(err)
if err != nil {
level.Error(h.logger).Log("err", err, "msg", "hashErr failed")
h.logger.ErrorContext(ctx, "hashErr failed", "err", err)
if h.testOnStore != nil {
h.testOnStore(err)
}
@ -238,7 +237,7 @@ func (h *Handler) storeError(ctx context.Context, err error) {
}
if _, err := conn.Do(""); err != nil {
level.Error(h.logger).Log("err", err, "msg", "redis SET failed")
h.logger.ErrorContext(ctx, "redis SET failed", "err", err)
if h.testOnStore != nil {
h.testOnStore(err)
}

View file

@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http/httptest"
"os"
"regexp"
@ -18,7 +19,6 @@ import (
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/datastore/redis/redistest"
"github.com/fleetdm/fleet/v4/server/fleet"
kitlog "github.com/go-kit/log"
pkgErrors "github.com/pkg/errors" //nolint:depguard
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -150,7 +150,7 @@ func TestErrorHandler(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // cancel immediately
eh := newTestHandler(ctx, nil, kitlog.NewNopLogger(), time.Minute, nil, nil)
eh := newTestHandler(ctx, nil, slog.New(slog.DiscardHandler), time.Minute, nil, nil)
doneCh := make(chan struct{})
go func() {
@ -168,7 +168,7 @@ func TestErrorHandler(t *testing.T) {
})
t.Run("works if the error storage is disabled", func(t *testing.T) {
eh := newTestHandler(context.Background(), nil, kitlog.NewNopLogger(), -1, nil, nil)
eh := newTestHandler(context.Background(), nil, slog.New(slog.DiscardHandler), -1, nil, nil)
doneCh := make(chan struct{})
go func() {
@ -218,7 +218,7 @@ func testErrorHandlerCollectsErrors(t *testing.T, pool fleet.RedisPool, wd strin
close(chDone)
}
}
eh := newTestHandler(ctx, pool, kitlog.NewNopLogger(), time.Minute, testOnStart, testOnStore)
eh := newTestHandler(ctx, pool, slog.New(slog.DiscardHandler), time.Minute, testOnStart, testOnStore)
<-chGo
@ -279,7 +279,7 @@ func testErrorHandlerCollectsDifferentErrors(t *testing.T, pool fleet.RedisPool,
}
}
eh := newTestHandler(ctx, pool, kitlog.NewNopLogger(), time.Minute, testOnStart, testOnStore)
eh := newTestHandler(ctx, pool, slog.New(slog.DiscardHandler), time.Minute, testOnStart, testOnStore)
<-chGo
@ -363,7 +363,7 @@ func TestHttpHandler(t *testing.T) {
}
}
eh := newTestHandler(ctx, pool, kitlog.NewNopLogger(), time.Minute, testOnStart, testOnStore)
eh := newTestHandler(ctx, pool, slog.New(slog.DiscardHandler), time.Minute, testOnStart, testOnStore)
<-chGo
// simulate two errors, one happening twice

View file

@ -3,10 +3,9 @@ package fleet
import (
"context"
"errors"
"log/slog"
"net"
"github.com/go-kit/log"
"github.com/go-kit/log/level"
"github.com/oschwald/geoip2-golang"
)
@ -29,7 +28,7 @@ type GeoIP interface {
type MaxMindGeoIP struct {
reader *geoip2.Reader
l log.Logger
l *slog.Logger
}
type NoOpGeoIP struct{}
@ -38,7 +37,7 @@ func (n *NoOpGeoIP) Lookup(ctx context.Context, ip string) *GeoLocation {
return nil
}
func NewMaxMindGeoIP(logger log.Logger, path string) (*MaxMindGeoIP, error) {
func NewMaxMindGeoIP(logger *slog.Logger, path string) (*MaxMindGeoIP, error) {
r, err := geoip2.Open(path)
if err != nil {
return nil, err
@ -59,7 +58,7 @@ func (m *MaxMindGeoIP) Lookup(ctx context.Context, ip string) *GeoLocation {
if err != nil && errors.Is(err, notCityDBError) {
resp, err := m.reader.Country(parseIP)
if err != nil {
level.Debug(m.l).Log("err", err, "msg", "failed to lookup location from mmdb file")
m.l.DebugContext(ctx, "failed to lookup location from mmdb file", "err", err)
return nil
}
if resp == nil {
@ -69,7 +68,7 @@ func (m *MaxMindGeoIP) Lookup(ctx context.Context, ip string) *GeoLocation {
return &GeoLocation{CountryISO: resp.Country.IsoCode}
}
if err != nil {
level.Debug(m.l).Log("err", err, "msg", "failed to lookup location from mmdb file")
m.l.DebugContext(ctx, "failed to lookup location from mmdb file", "err", err)
return nil
}
return parseCity(resp)

View file

@ -48,6 +48,7 @@ import (
"context"
"errors"
"fmt"
"log/slog"
"strconv"
"strings"
"sync"
@ -56,8 +57,6 @@ import (
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/datastore/redis"
"github.com/fleetdm/fleet/v4/server/fleet"
kitlog "github.com/go-kit/log"
"github.com/go-kit/log/level"
redigo "github.com/gomodule/redigo/redis"
)
@ -78,7 +77,7 @@ type redisLiveQuery struct {
// in memory cache expiration
cacheExpiration time.Duration
logger kitlog.Logger
logger *slog.Logger
}
// memCache is an in-memory cache for live queries. It stores the SQL of the
@ -109,7 +108,7 @@ func (r *redisLiveQuery) getSQLByCampaignID(campaignID string) (string, bool) {
// NewRedisQueryResults creates a new Redis implementation of the
// QueryResultStore interface using the provided Redis connection pool.
func NewRedisLiveQuery(pool fleet.RedisPool, logger kitlog.Logger, memCacheExp time.Duration) *redisLiveQuery {
func NewRedisLiveQuery(pool fleet.RedisPool, logger *slog.Logger, memCacheExp time.Duration) *redisLiveQuery {
return &redisLiveQuery{
pool: pool,
cache: newMemCache(),
@ -251,7 +250,7 @@ func (r *redisLiveQuery) collectBatchQueriesForHost(hostID uint, queryKeys []str
if sql, found := r.getSQLByCampaignID(name); found {
queriesByHost[name] = sql
} else {
level.Warn(r.logger).Log("msg", "live query not found in cache", "name", name)
r.logger.WarnContext(context.TODO(), "live query not found in cache", "name", name)
}
}
}
@ -426,7 +425,7 @@ func (r *redisLiveQuery) loadCache() error {
go func() {
err = r.removeQueryNames(names...)
if err != nil {
level.Warn(r.logger).Log("msg", "removing expired live queries", "err", err)
r.logger.WarnContext(context.TODO(), "removing expired live queries", "err", err)
}
}()
}

View file

@ -1,11 +1,11 @@
package live_query
import (
"log/slog"
"testing"
"github.com/fleetdm/fleet/v4/server/datastore/redis/redistest"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/go-kit/log"
"github.com/stretchr/testify/assert"
)
@ -27,7 +27,7 @@ func TestRedisLiveQuery(t *testing.T) {
func setupRedisLiveQuery(t *testing.T, cluster bool) *redisLiveQuery {
pool := redistest.SetupRedis(t, "*livequery", cluster, true, true)
return NewRedisLiveQuery(pool, log.NewNopLogger(), 0)
return NewRedisLiveQuery(pool, slog.New(slog.DiscardHandler), 0)
}
func TestMapBitfield(t *testing.T) {

View file

@ -4,28 +4,27 @@ import (
"context"
"encoding/json"
"fmt"
"log/slog"
"sync"
"time"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/datastore/redis"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/platform/logging"
"github.com/go-kit/log/level"
redigo "github.com/gomodule/redigo/redis"
)
type redisQueryResults struct {
pool fleet.RedisPool
duplicateResults bool
logger *logging.Logger
logger *slog.Logger
}
var _ fleet.QueryResultStore = &redisQueryResults{}
// NewRedisQueryResults creats a new Redis implementation of the
// QueryResultStore interface using the provided Redis connection pool.
func NewRedisQueryResults(pool fleet.RedisPool, duplicateResults bool, logger *logging.Logger) *redisQueryResults {
func NewRedisQueryResults(pool fleet.RedisPool, duplicateResults bool, logger *slog.Logger) *redisQueryResults {
return &redisQueryResults{
pool: pool,
duplicateResults: duplicateResults,
@ -86,7 +85,7 @@ func writeOrDone(ctx context.Context, ch chan<- interface{}, item interface{}) b
// connection over the provided channel. This effectively allows a select
// statement to run on conn.Receive() (by selecting on outChan that is
// passed into this function)
func receiveMessages(ctx context.Context, conn *redigo.PubSubConn, outChan chan<- any, logger *logging.Logger) {
func receiveMessages(ctx context.Context, conn *redigo.PubSubConn, outChan chan<- any, logger *slog.Logger) {
defer close(outChan)
for {
@ -97,7 +96,7 @@ func receiveMessages(ctx context.Context, conn *redigo.PubSubConn, outChan chan<
msg := conn.ReceiveWithTimeout(1 * time.Hour)
if recvTime := time.Since(beforeReceive); recvTime > time.Minute {
level.Info(logger).Log("msg", "conn.ReceiveWithTimeout connection was blocked for significant time", "duration", recvTime, "connection", fmt.Sprintf("%p", conn))
logger.InfoContext(ctx, "conn.ReceiveWithTimeout connection was blocked for significant time", "duration", recvTime, "connection", fmt.Sprintf("%p", conn))
}
// Pass the message back to ReadChannel.
@ -108,7 +107,7 @@ func receiveMessages(ctx context.Context, conn *redigo.PubSubConn, outChan chan<
switch msg := msg.(type) {
case error:
// If an error occurred (i.e. connection was closed), then we should exit.
level.Error(logger).Log("msg", "conn.ReceiveWithTimeout failed", "err", msg)
logger.ErrorContext(ctx, "conn.ReceiveWithTimeout failed", "err", msg)
return
case redigo.Subscription:
// If the subscription count is 0, the ReadChannel call that invoked this goroutine has unsubscribed,
@ -156,7 +155,7 @@ func (r *redisQueryResults) ReadChannel(ctx context.Context, query fleet.Distrib
select {
case msg, ok := <-msgChannel:
if !ok {
level.Error(logger).Log("msg", "unexpected exit in receiveMessages")
logger.ErrorContext(ctx, "unexpected exit in receiveMessages")
// NOTE(lucas): The below error string should not be modified. The UI is relying on it to detect
// when Fleet's connection to Redis has been interrupted unexpectedly.
//
@ -179,7 +178,7 @@ func (r *redisQueryResults) ReadChannel(ctx context.Context, query fleet.Distrib
return
}
case error:
level.Error(logger).Log("msg", "error received from pubsub channel", "err", msg)
logger.ErrorContext(ctx, "error received from pubsub channel", "err", msg)
if writeOrDone(ctx, outChannel, ctxerr.Wrap(ctx, msg, "read from redis")) {
return
}
@ -195,7 +194,7 @@ func (r *redisQueryResults) ReadChannel(ctx context.Context, query fleet.Distrib
wg.Wait()
psc.Unsubscribe(pubSubName) //nolint:errcheck
conn.Close()
level.Debug(logger).Log("msg", "proper close of Redis connection in ReadChannel", "connection", fmt.Sprintf("%p", conn))
logger.DebugContext(ctx, "proper close of Redis connection in ReadChannel", "connection", fmt.Sprintf("%p", conn))
}()
return outChannel, nil

View file

@ -1,14 +1,14 @@
package pubsub
import (
"log/slog"
"testing"
"github.com/fleetdm/fleet/v4/server/datastore/redis/redistest"
"github.com/fleetdm/fleet/v4/server/platform/logging"
)
func SetupRedisForTest(t *testing.T, cluster, readReplica bool) *redisQueryResults {
const dupResults = false
pool := redistest.SetupRedis(t, "zz", cluster, false, readReplica)
return NewRedisQueryResults(pool, dupResults, logging.NewNopLogger())
return NewRedisQueryResults(pool, dupResults, slog.New(slog.DiscardHandler))
}

View file

@ -174,7 +174,7 @@ func newTestServiceWithConfig(t *testing.T, ds fleet.Datastore, fleetConfig conf
var eh *errorstore.Handler
if len(opts) > 0 {
if opts[0].Pool != nil {
eh = errorstore.NewHandler(ctx, opts[0].Pool, logger, time.Minute*10)
eh = errorstore.NewHandler(ctx, opts[0].Pool, logger.SlogLogger(), time.Minute*10)
ctx = ctxerr.NewContext(ctx, eh)
}
if opts[0].StartCronSchedules != nil {
@ -582,7 +582,7 @@ func RunServerForTestsWithServiceWithDS(t *testing.T, ctx context.Context, ds fl
rootMux.Handle("/enroll", ServeEndUserEnrollOTA(svc, "", ds, logger))
if len(opts) > 0 && opts[0].EnableSCIM {
require.NoError(t, scim.RegisterSCIM(rootMux, ds, svc, logger, &cfg))
require.NoError(t, scim.RegisterSCIM(rootMux, ds, svc, logger.SlogLogger(), &cfg))
rootMux.Handle("/api/v1/fleet/scim/details", apiHandler)
rootMux.Handle("/api/latest/fleet/scim/details", apiHandler)
}