Add SCIM Users (#27551)

For #27287

Video explaining the PR: https://www.youtube.com/watch?v=ZHgFUAvrPEI

This PR adds SCIM Users support for Okta. The goal is to first add
Users/Groups support so that the remaining backend SCIM work can be done
in parallel.

This PR does not include the following, which will be added in later PRs
- Changes file
- Groups support for Okta
- Full support for Entra ID
- Integration tests

# Checklist for submitter

- [x] If database migrations are included, checked table schema to
confirm autoupdate
- For database migrations:
- [x] Checked schema for all modified table for columns that will
auto-update timestamps during migration.
- [x] Confirmed that updating the timestamps is acceptable, and will not
cause unwanted side effects.
- [x] Ensured the correct collation is explicitly set for character
columns (`COLLATE utf8mb4_unicode_ci`).
- [x] Added/updated automated tests
- [x] A detailed QA plan exists on the associated ticket (if it isn't
there, work with the product group's QA engineer to add it)
- [x] Manual QA for all new/changed functionality
This commit is contained in:
Victor Lyuboslavsky 2025-04-01 11:02:24 -05:00 committed by GitHub
parent 94037e5e56
commit 2198fd8d65
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 1923 additions and 4 deletions

View file

@ -22,6 +22,7 @@ import (
"github.com/WatchBeam/clock"
"github.com/e-dard/netbug"
"github.com/fleetdm/fleet/v4/ee/server/licensing"
"github.com/fleetdm/fleet/v4/ee/server/scim"
eeservice "github.com/fleetdm/fleet/v4/ee/server/service"
"github.com/fleetdm/fleet/v4/ee/server/service/digicert"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
@ -1171,11 +1172,14 @@ the way that the Fleet server works.
}
}
// SCEP proxy (for NDES, etc.)
if license.IsPremium() {
// SCEP proxy (for NDES, etc.)
if err = service.RegisterSCEPProxy(rootMux, ds, logger, nil); err != nil {
initFatal(err, "setup SCEP proxy")
}
if err = scim.RegisterSCIM(rootMux, ds, svc, logger); err != nil {
initFatal(err, "setup SCIM")
}
}
if config.Prometheus.BasicAuth.Username != "" && config.Prometheus.BasicAuth.Password != "" {

View file

@ -0,0 +1,102 @@
# SCIM (System for Cross-domain Identity Management) integration
## Reference docs
- [scim.cloud](https://scim.cloud/)
- [SCIM: Core Schema (RFC7643)](https://datatracker.ietf.org/doc/html/rfc7643)
- [SCIM: Protocol (RFC7644)](https://datatracker.ietf.org/doc/html/rfc7644)
- [scim Go library](https://github.com/elimity-com/scim)
## Okta integration
- https://developer.okta.com/docs/guides/scim-provisioning-integration-prepare/main/
### Testing Okta integration
First, create at least one SCIM user:
```
POST https://localhost:8080/api/latest/fleet/scim/Users
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"userName": "test.user@okta.local",
"name": {
"givenName": "Test",
"familyName": "User"
},
"emails": [{
"primary": true,
"value": "test.user@okta.local",
"type": "work"
}],
"active": true
}
```
Run test using [Runscope](https://www.runscope.com/). See [instructions](https://developer.okta.com/docs/guides/scim-provisioning-integration-prepare/main/#test-your-scim-api).
## Entra ID integration
- [SCIM guide](https://learn.microsoft.com/en-us/entra/identity/app-provisioning/use-scim-to-provision-users-and-groups)
- [SCIM validator](https://scimvalidator.microsoft.com/)
- Only test attributes that we implemented
### Testing Entra ID integration
Use [scimvalidator.microsoft.com](https://scimvalidator.microsoft.com/). Only test the attributes that we have implemented. To see our supported attributes, check the schema:
```
GET https://localhost:8080/api/latest/fleet/scim/Schemas
```
## Authentication
We use same authentication as API. HTTP header: `Authorization: Bearer xyz`
## Diagrams
```mermaid
---
title: Initial DB schema (not kept up to date)
---
erDiagram
HOST_SCIM_USER {
host_id uint PK
scim_user_id uint PK "FK"
}
SCIM_USERS {
id uint PK
external_id *string "Index"
user_name string "Unique"
given_name *string
family_name *string
active *bool
}
SCIM_USER_EMAILS {
id uint PK
scim_user_id uint FK
type *string "Index"
email string "Index"
primary *bool
}
SCIM_USER_GROUP {
scim_user_id string PK "FK"
group_id uint PK "FK"
}
SCIM_GROUPS {
id uint PK
external_id *string "Index"
display_name string "Index"
}
HOST_SCIM_USER }o--|| SCIM_USERS : "multiple hosts can have the same SCIM user"
SCIM_USERS ||--o{ SCIM_USER_GROUP: "zero-to-many"
SCIM_USER_GROUP }o--|| SCIM_GROUPS: "zero-to-many"
SCIM_USERS ||--o{ SCIM_USER_EMAILS: "zero-to-many"
COMMENT {
_ _ "created_at and updated_at columns not shown"
}
```
## Notes
- Okta and Entra ID do not support nested groups

180
ee/server/scim/scim.go Normal file
View file

@ -0,0 +1,180 @@
package scim
import (
"encoding/json"
"fmt"
"net/http"
"github.com/elimity-com/scim"
"github.com/elimity-com/scim/errors"
"github.com/elimity-com/scim/optional"
"github.com/elimity-com/scim/schema"
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/service/middleware/auth"
"github.com/fleetdm/fleet/v4/server/service/middleware/log"
kitlog "github.com/go-kit/log"
"github.com/go-kit/log/level"
)
const (
maxResults = 1000
)
func RegisterSCIM(
mux *http.ServeMux,
ds fleet.Datastore,
svc fleet.Service,
logger kitlog.Logger,
) error {
config := scim.ServiceProviderConfig{
// TODO: DocumentationURI and Authentication scheme
DocumentationURI: optional.NewString("https://fleetdm.com/docs/get-started/why-fleet"),
SupportFiltering: true,
SupportPatch: true,
MaxResults: maxResults,
}
// The common attributes are id, externalId, and meta.
// In practice only meta.resourceType is required, while the other four (created, lastModified, location, and version) are not strictly required.
userSchema := schema.Schema{
ID: "urn:ietf:params:scim:schemas:core:2.0:User",
Name: optional.NewString("User"),
Description: optional.NewString("SCIM User"),
Attributes: []schema.CoreAttribute{
schema.SimpleCoreAttribute(schema.SimpleStringParams(schema.StringParams{
Name: "userName",
Required: true,
Uniqueness: schema.AttributeUniquenessServer(),
})),
schema.ComplexCoreAttribute(schema.ComplexParams{
Description: optional.NewString("The components of the user's real name. Providers MAY return just the full name as a single string in the formatted sub-attribute, or they MAY return just the individual component attributes using the other sub-attributes, or they MAY return both. If both variants are returned, they SHOULD be describing the same name, with the formatted name indicating how the component attributes should be combined."),
Name: "name",
SubAttributes: []schema.SimpleParams{
schema.SimpleStringParams(schema.StringParams{
Description: optional.NewString("The family name of the User, or last name in most Western languages (e.g., 'Jensen' given the full name 'Ms. Barbara J Jensen, III')."),
Name: "familyName",
}),
schema.SimpleStringParams(schema.StringParams{
Description: optional.NewString("The given name of the User, or first name in most Western languages (e.g., 'Barbara' given the full name 'Ms. Barbara J Jensen, III')."),
Name: "givenName",
}),
},
}),
schema.ComplexCoreAttribute(schema.ComplexParams{
Description: optional.NewString("Email addresses for the user. The value SHOULD be canonicalized by the service provider, e.g., 'bjensen@example.com' instead of 'bjensen@EXAMPLE.COM'. Canonical type values of 'work', 'home', and 'other'."),
MultiValued: true,
Name: "emails",
SubAttributes: []schema.SimpleParams{
schema.SimpleStringParams(schema.StringParams{
Description: optional.NewString("Email addresses for the user. The value SHOULD be canonicalized by the service provider, e.g., 'bjensen@example.com' instead of 'bjensen@EXAMPLE.COM'. Canonical type values of 'work', 'home', and 'other'."),
Name: "value",
}),
schema.SimpleStringParams(schema.StringParams{
CanonicalValues: []string{"work", "home", "other"},
Description: optional.NewString("A label indicating the attribute's function, e.g., 'work' or 'home'."),
Name: "type",
}),
schema.SimpleBooleanParams(schema.BooleanParams{
Description: optional.NewString("A Boolean value indicating the 'primary' or preferred attribute value for this attribute, e.g., the preferred mailing address or primary email address. The primary attribute value 'true' MUST appear no more than once."),
Name: "primary",
}),
},
}),
schema.SimpleCoreAttribute(schema.SimpleBooleanParams(schema.BooleanParams{
Description: optional.NewString("A Boolean value indicating the User's administrative status."),
Name: "active",
})),
},
}
scimLogger := kitlog.With(logger, "component", "SCIM")
resourceTypes := []scim.ResourceType{
{
ID: optional.NewString("User"),
Name: "User",
Endpoint: "/Users",
Description: optional.NewString("User Account"),
Schema: userSchema,
Handler: NewUserHandler(ds, scimLogger),
},
}
serverArgs := &scim.ServerArgs{
ServiceProviderConfig: &config,
ResourceTypes: resourceTypes,
}
serverOpts := []scim.ServerOption{
scim.WithLogger(&scimErrorLogger{Logger: scimLogger}),
}
server, err := scim.NewServer(serverArgs, serverOpts...)
if err != nil {
return err
}
scimErrorHandler := func(w http.ResponseWriter, detail string, status int) {
errorHandler(w, scimLogger, detail, status)
}
authorizer, err := authz.NewAuthorizer()
if err != nil {
return err
}
// TODO: Add APM/OpenTelemetry tracing and Prometheus middleware
applyMiddleware := func(prefix string, server http.Handler) http.Handler {
handler := http.StripPrefix(prefix, server)
handler = AuthorizationMiddleware(authorizer, scimLogger, handler)
handler = auth.AuthenticatedUserMiddleware(svc, scimErrorHandler, handler)
handler = log.LogResponseEndMiddleware(scimLogger, handler)
handler = auth.SetRequestsContextMiddleware(svc, handler)
return handler
}
mux.Handle("/api/v1/fleet/scim/", applyMiddleware("/api/v1/fleet/scim", server))
mux.Handle("/api/latest/fleet/scim/", applyMiddleware("/api/latest/fleet/scim", server))
return nil
}
func AuthorizationMiddleware(authorizer *authz.Authorizer, logger kitlog.Logger, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := authorizer.Authorize(r.Context(), &fleet.ScimUser{}, fleet.ActionWrite)
if err != nil {
errorHandler(w, logger, err.Error(), http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}
func errorHandler(w http.ResponseWriter, logger kitlog.Logger, detail string, status int) {
scimErr := errors.ScimError{
Status: status,
Detail: detail,
}
raw, err := json.Marshal(scimErr)
if err != nil {
level.Error(logger).Log("msg", "failed marshaling scim error", "scimError", scimErr, "err", err)
return
}
w.Header().Set("Content-Type", "application/scim+json")
w.WriteHeader(scimErr.Status)
_, err = w.Write(raw)
if err != nil {
level.Error(logger).Log("msg", "failed writing response", "err", err)
}
}
type scimErrorLogger struct {
kitlog.Logger
}
var _ scim.Logger = &scimErrorLogger{}
func (l *scimErrorLogger) Error(args ...interface{}) {
level.Error(l.Logger).Log(
"error", fmt.Sprint(args...),
)
}

455
ee/server/scim/users.go Normal file
View file

@ -0,0 +1,455 @@
package scim
import (
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"unicode"
"github.com/elimity-com/scim"
"github.com/elimity-com/scim/errors"
"github.com/elimity-com/scim/optional"
"github.com/fleetdm/fleet/v4/server/fleet"
kitlog "github.com/go-kit/log"
"github.com/go-kit/log/level"
"github.com/scim2/filter-parser/v2"
)
const (
// Common attributes: https://datatracker.ietf.org/doc/html/rfc7643#section-3.1
externalIdAttr = "externalId"
// User attributes: https://datatracker.ietf.org/doc/html/rfc7643#section-4.1
userNameAttr = "userName"
nameAttr = "name"
givenNameAttr = "givenName"
familyNameAttr = "familyName"
activeAttr = "active"
emailsAttr = "emails"
)
type UserHandler struct {
ds fleet.Datastore
logger kitlog.Logger
}
// Compile-time check
var _ scim.ResourceHandler = &UserHandler{}
func NewUserHandler(ds fleet.Datastore, logger kitlog.Logger) scim.ResourceHandler {
return &UserHandler{ds: ds, logger: logger}
}
func (u *UserHandler) Create(r *http.Request, attributes scim.ResourceAttributes) (scim.Resource, error) {
// Check for userName uniqueness
userName, err := getRequiredResource[string](attributes, userNameAttr)
if err != nil {
level.Error(u.logger).Log("msg", "failed to get userName", "err", err)
return scim.Resource{}, err
}
_, err = u.ds.ScimUserByUserName(r.Context(), userName)
if !fleet.IsNotFound(err) {
level.Info(u.logger).Log("msg", "user already exists", userNameAttr, userName)
return scim.Resource{}, errors.ScimErrorUniqueness
}
user, err := createUserFromAttributes(attributes)
if err != nil {
level.Error(u.logger).Log("msg", "failed to create user from attributes", userNameAttr, userName, "err", err)
return scim.Resource{}, err
}
user.ID, err = u.ds.CreateScimUser(r.Context(), user)
if err != nil {
return scim.Resource{}, err
}
return createUserResource(user), nil
}
func createUserFromAttributes(attributes scim.ResourceAttributes) (*fleet.ScimUser, error) {
user := fleet.ScimUser{}
var err error
user.UserName, err = getRequiredResource[string](attributes, userNameAttr)
if err != nil {
return nil, err
}
user.ExternalID, err = getOptionalResource[string](attributes, externalIdAttr)
if err != nil {
return nil, err
}
user.Active, err = getOptionalResource[bool](attributes, activeAttr)
if err != nil {
return nil, err
}
name, err := getComplexResource(attributes, nameAttr)
if err != nil {
return nil, err
}
user.FamilyName, err = getOptionalResource[string](name, familyNameAttr)
if err != nil {
return nil, err
}
user.GivenName, err = getOptionalResource[string](name, givenNameAttr)
if err != nil {
return nil, err
}
emails, err := getComplexResourceSlice(attributes, emailsAttr)
if err != nil {
return nil, err
}
userEmails := make([]fleet.ScimUserEmail, 0, len(emails))
for _, email := range emails {
userEmail := fleet.ScimUserEmail{}
userEmail.Email, err = getRequiredResource[string](email, "value")
if err != nil {
return nil, err
}
// Service providers SHOULD canonicalize the value according to [RFC5321]
// https://datatracker.ietf.org/doc/html/rfc7643#section-4.1.2
userEmail.Email, err = normalizeEmail(userEmail.Email)
if err != nil {
return nil, errors.ScimErrorBadParams([]string{"value"})
}
userEmail.Type, err = getOptionalResource[string](email, "type")
if err != nil {
return nil, err
}
userEmail.Primary, err = getOptionalResource[bool](email, "primary")
if err != nil {
return nil, err
}
userEmails = append(userEmails, userEmail)
}
user.Emails = userEmails
return &user, nil
}
func getRequiredResource[T string | bool](attributes scim.ResourceAttributes, key string) (T, error) {
var val T
valIntf, ok := attributes[key]
if !ok || valIntf == nil {
return val, errors.ScimErrorBadParams([]string{key})
}
val, ok = valIntf.(T)
if !ok {
return val, errors.ScimErrorBadParams([]string{key})
}
return val, nil
}
func getOptionalResource[T string | bool](attributes scim.ResourceAttributes, key string) (*T, error) {
var valPtr *T
valIntf, ok := attributes[key]
if ok && valIntf != nil {
val, ok := valIntf.(T)
if !ok {
return nil, errors.ScimErrorBadParams([]string{key})
}
valPtr = &val
}
return valPtr, nil
}
func getComplexResource(attributes scim.ResourceAttributes, key string) (map[string]interface{}, error) {
valIntf, ok := attributes[key]
if ok && valIntf != nil {
val, ok := valIntf.(map[string]interface{})
if !ok {
return nil, errors.ScimErrorBadParams([]string{key})
}
return val, nil
}
return nil, nil
}
func getComplexResourceSlice(attributes scim.ResourceAttributes, key string) ([]map[string]interface{}, error) {
valIntf, ok := attributes[key]
if ok && valIntf != nil {
valSliceIntf, ok := valIntf.([]interface{})
if !ok {
return nil, errors.ScimErrorBadParams([]string{key})
}
val := make([]map[string]interface{}, 0, len(valSliceIntf))
for _, v := range valSliceIntf {
valMap, ok := v.(map[string]interface{})
if !ok {
return nil, errors.ScimErrorBadParams([]string{key})
}
if len(valMap) > 0 {
val = append(val, valMap)
}
}
return val, nil
}
return nil, nil
}
func (u *UserHandler) Get(r *http.Request, id string) (scim.Resource, error) {
idUint, err := strconv.ParseUint(id, 10, 64)
if err != nil {
level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
}
user, err := u.ds.ScimUserByID(r.Context(), uint(idUint))
switch {
case fleet.IsNotFound(err):
level.Info(u.logger).Log("msg", "failed to find user", "id", id)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
case err != nil:
level.Error(u.logger).Log("msg", "failed to get user", "id", id, "err", err)
return scim.Resource{}, err
}
return createUserResource(user), nil
}
func createUserResource(user *fleet.ScimUser) scim.Resource {
userResource := scim.Resource{}
userResource.ID = fmt.Sprintf("%d", user.ID)
if user.ExternalID != nil {
userResource.ExternalID = optional.NewString(*user.ExternalID)
}
userResource.Attributes = scim.ResourceAttributes{}
userResource.Attributes[userNameAttr] = user.UserName
if user.Active != nil {
userResource.Attributes[activeAttr] = *user.Active
}
if user.FamilyName != nil || user.GivenName != nil {
userResource.Attributes[nameAttr] = make(scim.ResourceAttributes)
if user.FamilyName != nil {
userResource.Attributes[nameAttr].(scim.ResourceAttributes)[familyNameAttr] = *user.FamilyName
}
if user.GivenName != nil {
userResource.Attributes[nameAttr].(scim.ResourceAttributes)[givenNameAttr] = *user.GivenName
}
}
if len(user.Emails) > 0 {
emails := make([]scim.ResourceAttributes, 0, len(user.Emails))
for _, email := range user.Emails {
emailResource := make(scim.ResourceAttributes)
emailResource["value"] = email.Email
if email.Type != nil {
emailResource["type"] = *email.Type
}
if email.Primary != nil {
emailResource["primary"] = *email.Primary
}
emails = append(emails, emailResource)
}
userResource.Attributes[emailsAttr] = emails
}
return userResource
}
// GetAll
// Per RFC7644 3.4.2, SHOULD ignore any query parameters they do not recognize instead of rejecting the query for versioning compatibility reasons
// https://datatracker.ietf.org/doc/html/rfc7644#section-3.4.2
//
// Providers MUST decline to filter results if the specified filter operation is not recognized and return an HTTP 400 error with a
// "scimType" error of "invalidFilter" and an appropriate human-readable response as per Section 3.12. For example, if a client specified an
// unsupported operator named 'regex', the service provider should specify an error response description identifying the client error,
// e.g., 'The operator 'regex' is not supported.'
//
// If a SCIM service provider determines that too many results would be returned the server base URI, the server SHALL reject the request by
// returning an HTTP response with HTTP status code 400 (Bad Request) and JSON attribute "scimType" set to "tooMany" (see Table 9).
//
// totalResults: The total number of results returned by the list or query operation. The value may be larger than the number of
// resources returned, such as when returning a single page (see Section 3.4.2.4) of results where multiple pages are available.
func (u *UserHandler) GetAll(r *http.Request, params scim.ListRequestParams) (scim.Page, error) {
page := params.StartIndex
if page < 1 {
page = 1
}
count := params.Count
if count > maxResults {
return scim.Page{}, errors.ScimErrorTooMany
}
if count < 1 {
count = maxResults
}
opts := fleet.ScimUsersListOptions{
Page: uint(page), // nolint:gosec // ignore G115
PerPage: uint(count), // nolint:gosec // ignore G115
}
resourceFilter := r.URL.Query().Get("filter")
if resourceFilter != "" {
expr, err := filter.ParseAttrExp([]byte(resourceFilter))
if err != nil {
level.Error(u.logger).Log("msg", "failed to parse filter", "filter", resourceFilter, "err", err)
return scim.Page{}, errors.ScimErrorInvalidFilter
}
if !strings.EqualFold(expr.AttributePath.String(), "userName") || expr.Operator != "eq" {
level.Info(u.logger).Log("msg", "unsupported filter", "filter", resourceFilter)
return scim.Page{}, nil
}
userName, ok := expr.CompareValue.(string)
if !ok {
level.Error(u.logger).Log("msg", "unsupported value", "value", expr.CompareValue)
return scim.Page{}, nil
}
// Decode URL-encoded characters in userName, which is required to pass Microsoft Entra ID SCIM Validator
userName, err = url.QueryUnescape(userName)
if err != nil {
level.Error(u.logger).Log("msg", "failed to decode userName", "userName", userName, "err", err)
return scim.Page{}, nil
}
opts.UserNameFilter = &userName
}
users, totalResults, err := u.ds.ListScimUsers(r.Context(), opts)
if err != nil {
level.Error(u.logger).Log("msg", "failed to list users", "err", err)
return scim.Page{}, err
}
result := scim.Page{
TotalResults: int(totalResults), // nolint:gosec // ignore G115
Resources: make([]scim.Resource, 0, len(users)),
}
for _, user := range users {
result.Resources = append(result.Resources, createUserResource(&user))
}
return result, nil
}
func (u *UserHandler) Replace(r *http.Request, id string, attributes scim.ResourceAttributes) (scim.Resource, error) {
idUint, err := strconv.ParseUint(id, 10, 64)
if err != nil {
level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
}
user, err := createUserFromAttributes(attributes)
if err != nil {
level.Error(u.logger).Log("msg", "failed to create user from attributes", "id", id, "err", err)
return scim.Resource{}, err
}
user.ID = uint(idUint)
err = u.ds.ReplaceScimUser(r.Context(), user)
switch {
case fleet.IsNotFound(err):
level.Info(u.logger).Log("msg", "failed to find user to replace", "id", id)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
case err != nil:
level.Error(u.logger).Log("msg", "failed to replace user", "id", id, "err", err)
return scim.Resource{}, err
}
return createUserResource(user), nil
}
// Delete
// https://datatracker.ietf.org/doc/html/rfc7644#section-3.6
// MUST return a 404 (Not Found) error code for all operations associated with the previously deleted resource
func (u *UserHandler) Delete(r *http.Request, id string) error {
idUint, err := strconv.ParseUint(id, 10, 64)
if err != nil {
level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
return errors.ScimErrorResourceNotFound(id)
}
err = u.ds.DeleteScimUser(r.Context(), uint(idUint))
switch {
case fleet.IsNotFound(err):
level.Info(u.logger).Log("msg", "failed to find user to delete", "id", id)
return errors.ScimErrorResourceNotFound(id)
case err != nil:
level.Error(u.logger).Log("msg", "failed to delete user", "id", id, "err", err)
return err
}
return nil
}
// Patch
// Okta only requires patching the "active" attribute:
// https://developer.okta.com/docs/api/openapi/okta-scim/guides/scim-20/#update-a-specific-user-patch
func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchOperation) (scim.Resource, error) {
idUint, err := strconv.ParseUint(id, 10, 64)
if err != nil {
level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
}
user, err := u.ds.ScimUserByID(r.Context(), uint(idUint))
switch {
case fleet.IsNotFound(err):
level.Info(u.logger).Log("msg", "failed to find user to patch", "id", id)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
case err != nil:
level.Error(u.logger).Log("msg", "failed to get user to patch", "id", id, "err", err)
return scim.Resource{}, err
}
for _, op := range operations {
if op.Op != "replace" {
level.Info(u.logger).Log("msg", "unsupported patch operation", "op", op.Op)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
switch {
case op.Path == nil:
newValues, ok := op.Value.(map[string]interface{})
if !ok {
level.Info(u.logger).Log("msg", "unsupported patch value", "value", op.Value)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
if len(newValues) != 1 {
level.Info(u.logger).Log("msg", "too many patch values", "value", op.Value)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
active, err := getRequiredResource[bool](newValues, activeAttr)
if err != nil {
level.Info(u.logger).Log("msg", "failed to get active value", "value", op.Value)
return scim.Resource{}, err
}
user.Active = &active
case op.Path.String() == activeAttr:
active, ok := op.Value.(bool)
if !ok {
level.Error(u.logger).Log("msg", "unsupported 'active' patch value", "value", op.Value)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
user.Active = &active
default:
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path)
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
}
}
err = u.ds.ReplaceScimUser(r.Context(), user)
switch {
case fleet.IsNotFound(err):
level.Info(u.logger).Log("msg", "failed to find user to patch", "id", id)
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
case err != nil:
level.Error(u.logger).Log("msg", "failed to patch user", "id", id, "err", err)
return scim.Resource{}, err
}
return createUserResource(user), nil
}
// normalizeEmail
// The local-part of a mailbox MUST BE treated as case sensitive.
// Mailbox domains follow normal DNS rules and are hence not case sensitive.
// https://datatracker.ietf.org/doc/html/rfc5321#section-2.4
func normalizeEmail(email string) (string, error) {
email = removeWhitespace(email)
emailParts := strings.SplitN(email, "@", 2)
if len(emailParts) != 2 {
return "", fmt.Errorf("invalid email %s", email)
}
emailParts[1] = strings.ToLower(emailParts[1])
return strings.Join(emailParts, "@"), nil
}
func removeWhitespace(str string) string {
return strings.Map(func(r rune) rune {
if unicode.IsSpace(r) {
return -1
}
return r
}, str)
}

4
go.mod
View file

@ -181,12 +181,15 @@ require (
github.com/cyphar/filepath-securejoin v0.2.5 // indirect
github.com/dgraph-io/ristretto v0.1.0 // indirect
github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 // indirect
github.com/di-wu/parser v0.2.2 // indirect
github.com/di-wu/xsd-datetime v1.0.0 // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/go-connections v0.4.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/edsrzf/mmap-go v1.1.0 // indirect
github.com/elastic/go-sysinfo v1.11.2 // indirect
github.com/elastic/go-windows v1.0.1 // indirect
github.com/elimity-com/scim v0.0.0-20240320110924-172bf2aee9c8 // indirect
github.com/emirpasic/gods v1.18.1 // indirect
github.com/fatih/structs v1.1.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
@ -249,6 +252,7 @@ require (
github.com/prometheus/procfs v0.12.0 // indirect
github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/scim2/filter-parser/v2 v2.2.0 // indirect
github.com/secDre4mer/pkcs7 v0.0.0-20240322103146-665324a4461d // indirect
github.com/secure-systems-lab/go-securesystemslib v0.5.0 // indirect
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect

8
go.sum
View file

@ -219,6 +219,10 @@ github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13/go.mod h1:SqUrOPUn
github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no=
github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+UbP35JkH8yB7MYb4q/qhBarqZE6g=
github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA=
github.com/di-wu/parser v0.2.2 h1:I9oHJ8spBXOeL7Wps0ffkFFFiXJf/pk7NX9lcAMqRMU=
github.com/di-wu/parser v0.2.2/go.mod h1:SLp58pW6WamdmznrVRrw2NTyn4wAvT9rrEFynKX7nYo=
github.com/di-wu/xsd-datetime v1.0.0 h1:vZoGNkbzpBNoc+JyfVLEbutNDNydYV8XwHeV7eUJoxI=
github.com/di-wu/xsd-datetime v1.0.0/go.mod h1:i3iEhrP3WchwseOBeIdW/zxeoleXTOzx1WyDXgdmOww=
github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e h1:vUmf0yezR0y7jJ5pceLHthLaYf4bA5T14B6q39S4q2Q=
github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e/go.mod h1:YTIHhz/QFSYnu/EhlF2SpU2Uk+32abacUYA5ZPljz1A=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
@ -246,6 +250,8 @@ github.com/elazarl/go-bindata-assetfs v1.0.1 h1:m0kkaHRKEu7tUIUFVwhGGGYClXvyl4RE
github.com/elazarl/go-bindata-assetfs v1.0.1/go.mod h1:v+YaWX3bdea5J/mo8dSETolEo7R71Vk1u8bnjau5yw4=
github.com/elazarl/goproxy v1.2.1 h1:njjgvO6cRG9rIqN2ebkqy6cQz2Njkx7Fsfv/zIZqgug=
github.com/elazarl/goproxy v1.2.1/go.mod h1:YfEbZtqP4AetfO6d40vWchF3znWX7C7Vd6ZMfdL8z64=
github.com/elimity-com/scim v0.0.0-20240320110924-172bf2aee9c8 h1:0+BTyxIYgiVAry/P5s8R4dYuLkhB9Nhso8ogFWNr4IQ=
github.com/elimity-com/scim v0.0.0-20240320110924-172bf2aee9c8/go.mod h1:JkjcmqbLW+khwt2fmBPJFBhx2zGZ8XobRZ+O0VhlwWo=
github.com/emirpasic/gods v1.12.0/go.mod h1:YfzfFFoVP/catgzJb4IKIqXjX78Ha8FMSDh3ymbK86o=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
@ -745,6 +751,8 @@ github.com/saferwall/pe v1.5.5 h1:GGbzKjXDm7i+1K6riOgtgblyTdRmTbr3r11IzjovAK8=
github.com/saferwall/pe v1.5.5/go.mod h1:mJx+PuptmNpoPFBNhWs/uDMFL/kTHVZIkg0d4OUJFbQ=
github.com/sassoftware/relic/v8 v8.0.1 h1:uYUoaoTQMs67up8/46NgrSxSftgfY4VWBusDVg56k7I=
github.com/sassoftware/relic/v8 v8.0.1/go.mod h1:s/MwugRcovgYcNJNOyvLfqRHDX7iArHtFtUR9kEodz8=
github.com/scim2/filter-parser/v2 v2.2.0 h1:QGadEcsmypxg8gYChRSM2j1edLyE/2j72j+hdmI4BJM=
github.com/scim2/filter-parser/v2 v2.2.0/go.mod h1:jWnkDToqX/Y0ugz0P5VvpVEUKcWcyHHj+X+je9ce5JA=
github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9 h1:rc/CcqLH3lh8n+csdOuDfP+NuykE0U6AeYSJJHKDgSg=
github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9/go.mod h1:a/83NAfUXvEuLpmxDssAXxgUgrEy12MId3Wd7OTs76s=
github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc=

View file

@ -1046,3 +1046,13 @@ allow {
subject.global_role == admin
action == [read, write][_]
}
##
# SCIM (System for Cross-domain Identity Management)
##
# Global admins and maintainers can access SCIM.
allow {
object.type == "scim_user"
subject.global_role == [admin, maintainer][_]
action == write
}

View file

@ -0,0 +1,78 @@
package tables
import (
"database/sql"
"fmt"
)
func init() {
MigrationClient.AddMigration(Up_20250331042354, Down_20250331042354)
}
func Up_20250331042354(tx *sql.Tx) error {
_, err := tx.Exec(`
CREATE TABLE IF NOT EXISTS scim_users (
id int UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
external_id VARCHAR(255) NULL,
user_name VARCHAR(255) NOT NULL,
given_name VARCHAR(255) NULL,
family_name VARCHAR(255) NULL,
active TINYINT(1) NULL,
created_at DATETIME(6) NOT NULL DEFAULT NOW(6),
updated_at DATETIME(6) NOT NULL DEFAULT NOW(6) ON UPDATE NOW(6),
UNIQUE KEY idx_scim_users_user_name (user_name),
KEY idx_scim_users_external_id (external_id)
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COLLATE = utf8mb4_unicode_ci;
CREATE TABLE IF NOT EXISTS host_scim_user (
host_id INT UNSIGNED NOT NULL,
scim_user_id INT UNSIGNED NOT NULL,
created_at DATETIME(6) NOT NULL DEFAULT NOW(6),
PRIMARY KEY (host_id, scim_user_id),
CONSTRAINT fk_host_scim_scim_user_id FOREIGN KEY (scim_user_id) REFERENCES scim_users (id) ON DELETE CASCADE
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COLLATE = utf8mb4_unicode_ci;
CREATE TABLE if NOT EXISTS scim_user_emails (
-- Using BIGINT because we clear and repopulate the emails frequently (during user update)
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
scim_user_id INT UNSIGNED NOT NULL,
email VARCHAR(255) NOT NULL,
` + "`primary`" + ` TINYINT(1) NULL,
type VARCHAR(31) NULL,
created_at DATETIME(6) NOT NULL DEFAULT NOW(6),
updated_at DATETIME(6) NOT NULL DEFAULT NOW(6) ON UPDATE NOW(6),
KEY idx_scim_user_emails_email_type(type, email),
CONSTRAINT fk_scim_user_emails_scim_user_id FOREIGN KEY (scim_user_id) REFERENCES scim_users (id) ON DELETE CASCADE
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COLLATE = utf8mb4_unicode_ci;
CREATE TABLE IF NOT EXISTS scim_groups (
id int UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
external_id VARCHAR(255) NULL,
display_name VARCHAR(255) NOT NULL,
created_at DATETIME(6) NOT NULL DEFAULT NOW(6),
updated_at DATETIME(6) NOT NULL DEFAULT NOW(6) ON UPDATE NOW(6),
KEY idx_scim_groups_external_id (external_id),
-- Entra ID requires a unique display name
UNIQUE KEY idx_scim_groups_display_name (display_name)
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COLLATE = utf8mb4_unicode_ci;
CREATE TABLE IF NOT EXISTS scim_user_group (
scim_user_id INT UNSIGNED NOT NULL,
group_id INT UNSIGNED NOT NULL,
created_at DATETIME(6) NOT NULL DEFAULT NOW(6),
PRIMARY KEY (scim_user_id, group_id),
CONSTRAINT fk_scim_user_group_scim_user_id FOREIGN KEY (scim_user_id) REFERENCES scim_users (id) ON DELETE CASCADE,
CONSTRAINT fk_scim_user_group_group_id FOREIGN KEY (group_id) REFERENCES scim_groups (id) ON DELETE CASCADE
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COLLATE = utf8mb4_unicode_ci;
`)
if err != nil {
return fmt.Errorf("failed to create scim tables: %s", err)
}
return nil
}
func Down_20250331042354(tx *sql.Tx) error {
return nil
}

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,331 @@
package mysql
import (
"context"
"database/sql"
"errors"
"strings"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/jmoiron/sqlx"
)
// CreateScimUser creates a new SCIM user in the database
func (ds *Datastore) CreateScimUser(ctx context.Context, user *fleet.ScimUser) (uint, error) {
var userID uint
err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
const insertUserQuery = `
INSERT INTO scim_users (
external_id, user_name, given_name, family_name, active
) VALUES (?, ?, ?, ?, ?)`
result, err := tx.ExecContext(
ctx,
insertUserQuery,
user.ExternalID,
user.UserName,
user.GivenName,
user.FamilyName,
user.Active,
)
if err != nil {
return ctxerr.Wrap(ctx, err, "insert scim user")
}
id, err := result.LastInsertId()
if err != nil {
return ctxerr.Wrap(ctx, err, "insert scim user last insert id")
}
user.ID = uint(id) // nolint:gosec // dismiss G115
userID = user.ID
return insertEmails(ctx, tx, user)
})
return userID, err
}
// ScimUserByID retrieves a SCIM user by ID
func (ds *Datastore) ScimUserByID(ctx context.Context, id uint) (*fleet.ScimUser, error) {
const query = `
SELECT
id, external_id, user_name, given_name, family_name, active
FROM scim_users
WHERE id = ?
`
user := &fleet.ScimUser{}
err := sqlx.GetContext(ctx, ds.reader(ctx), user, query, id)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, notFound("scim user").WithID(id)
}
return nil, ctxerr.Wrap(ctx, err, "select scim user")
}
// Get the user's emails
emails, err := ds.getScimUserEmails(ctx, id)
if err != nil {
return nil, err
}
user.Emails = emails
return user, nil
}
// ScimUserByUserName retrieves a SCIM user by username
func (ds *Datastore) ScimUserByUserName(ctx context.Context, userName string) (*fleet.ScimUser, error) {
const query = `
SELECT
id, external_id, user_name, given_name, family_name, active
FROM scim_users
WHERE user_name = ?
`
user := &fleet.ScimUser{}
err := sqlx.GetContext(ctx, ds.reader(ctx), user, query, userName)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, notFound("scim user")
}
return nil, ctxerr.Wrap(ctx, err, "select scim user by userName")
}
// Get the user's emails
emails, err := ds.getScimUserEmails(ctx, user.ID)
if err != nil {
return nil, err
}
user.Emails = emails
return user, nil
}
// ReplaceScimUser replaces an existing SCIM user in the database
func (ds *Datastore) ReplaceScimUser(ctx context.Context, user *fleet.ScimUser) error {
return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
// Update the SCIM user
const updateUserQuery = `
UPDATE scim_users SET
external_id = ?,
user_name = ?,
given_name = ?,
family_name = ?,
active = ?
WHERE id = ?`
result, err := tx.ExecContext(
ctx,
updateUserQuery,
user.ExternalID,
user.UserName,
user.GivenName,
user.FamilyName,
user.Active,
user.ID,
)
if err != nil {
return ctxerr.Wrap(ctx, err, "update scim user")
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return ctxerr.Wrap(ctx, err, "get rows affected for update scim user")
}
if rowsAffected == 0 {
return notFound("scim user").WithID(user.ID)
}
// We assume that email is not blank/null.
// However, we do not assume that email/type are unique for a user. To keep the code simple, we:
// 1. Delete all existing emails
// 2. Insert all new emails
// This is less efficient and can be optimized if we notice a load on these tables in production.
const deleteEmailsQuery = `DELETE FROM scim_user_emails WHERE scim_user_id = ?`
_, err = tx.ExecContext(ctx, deleteEmailsQuery, user.ID)
if err != nil {
return ctxerr.Wrap(ctx, err, "delete scim user emails")
}
return insertEmails(ctx, tx, user)
})
}
func insertEmails(ctx context.Context, tx sqlx.ExtContext, user *fleet.ScimUser) error {
// Insert the user's emails in a single batch if any
if len(user.Emails) > 0 {
// Build the batch insert query
valueStrings := make([]string, 0, len(user.Emails))
valueArgs := make([]interface{}, 0, len(user.Emails)*4)
for i := range user.Emails {
user.Emails[i].ScimUserID = user.ID
valueStrings = append(valueStrings, "(?, ?, ?, ?)")
valueArgs = append(valueArgs,
user.Emails[i].ScimUserID,
user.Emails[i].Email,
user.Emails[i].Primary,
user.Emails[i].Type,
)
}
// Construct the batch insert query
insertEmailQuery := `
INSERT INTO scim_user_emails (
scim_user_id, email, ` + "`primary`" + `, type
) VALUES ` + strings.Join(valueStrings, ",")
// Execute the batch insert
_, err := tx.ExecContext(ctx, insertEmailQuery, valueArgs...)
if err != nil {
return ctxerr.Wrap(ctx, err, "batch insert scim user emails")
}
}
return nil
}
// DeleteScimUser deletes a SCIM user from the database
func (ds *Datastore) DeleteScimUser(ctx context.Context, id uint) error {
return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
// Delete all email entries for the user
const deleteEmailsQuery = `DELETE FROM scim_user_emails WHERE scim_user_id = ?`
_, err := tx.ExecContext(ctx, deleteEmailsQuery, id)
if err != nil {
return ctxerr.Wrap(ctx, err, "delete scim user emails")
}
// Delete the user
const deleteUserQuery = `DELETE FROM scim_users WHERE id = ?`
result, err := tx.ExecContext(ctx, deleteUserQuery, id)
if err != nil {
return ctxerr.Wrap(ctx, err, "delete scim user")
}
// Check if the user existed
rowsAffected, err := result.RowsAffected()
if err != nil {
return ctxerr.Wrap(ctx, err, "get rows affected for delete scim user")
}
if rowsAffected == 0 {
return notFound("scim user").WithID(id)
}
return nil
})
}
// ListScimUsers retrieves a list of SCIM users with optional filtering
func (ds *Datastore) ListScimUsers(ctx context.Context, opts fleet.ScimUsersListOptions) (users []fleet.ScimUser, totalResults uint, err error) {
// Default pagination values if not provided
if opts.Page == 0 {
opts.Page = 1
}
if opts.PerPage == 0 {
opts.PerPage = 100
}
// Calculate offset for pagination
offset := (opts.Page - 1) * opts.PerPage
// Build the base query
baseQuery := `
SELECT DISTINCT
scim_users.id, external_id, user_name, given_name, family_name, active
FROM scim_users
`
// Add joins and where clauses based on filters
var whereClause string
var params []interface{}
if opts.UserNameFilter != nil {
// Filter by username
whereClause = " WHERE scim_users.user_name = ?"
params = append(params, *opts.UserNameFilter)
} else if opts.EmailTypeFilter != nil && opts.EmailValueFilter != nil {
// Filter by email type and value
baseQuery += " LEFT JOIN scim_user_emails ON scim_users.id = scim_user_emails.scim_user_id"
whereClause = " WHERE scim_user_emails.type = ? AND scim_user_emails.email = ?"
params = append(params, *opts.EmailTypeFilter, *opts.EmailValueFilter)
}
// First, get the total count without pagination
countQuery := "SELECT COUNT(DISTINCT id) FROM (" + baseQuery + whereClause + ") AS filtered_users"
err = sqlx.GetContext(ctx, ds.reader(ctx), &totalResults, countQuery, params...)
if err != nil {
return nil, 0, ctxerr.Wrap(ctx, err, "count total scim users")
}
// Add pagination to the main query
query := baseQuery + whereClause + " ORDER BY scim_users.id LIMIT ? OFFSET ?"
params = append(params, opts.PerPage, offset)
// Execute the query
err = sqlx.SelectContext(ctx, ds.reader(ctx), &users, query, params...)
if err != nil {
return nil, 0, ctxerr.Wrap(ctx, err, "list scim users")
}
// Process the results
userIDs := make([]uint, 0, len(users))
userMap := make(map[uint]*fleet.ScimUser, len(users))
for i, user := range users {
userIDs = append(userIDs, user.ID)
userMap[user.ID] = &users[i]
}
// If no users found, return empty slice
if len(users) == 0 {
return users, totalResults, nil
}
// Fetch emails for all users in a single query
emailQuery, args, err := sqlx.In(`
SELECT
scim_user_id, email, `+"`primary`"+`, type
FROM scim_user_emails
WHERE scim_user_id IN (?)
ORDER BY email ASC
`, userIDs)
if err != nil {
return nil, 0, ctxerr.Wrap(ctx, err, "prepare emails query")
}
// Convert query for the specific DB dialect
emailQuery = ds.reader(ctx).Rebind(emailQuery)
// Execute the email query
var allEmails []fleet.ScimUserEmail
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &allEmails, emailQuery, args...); err != nil {
if !errors.Is(err, sql.ErrNoRows) {
return nil, 0, ctxerr.Wrap(ctx, err, "select scim user emails")
}
}
// Associate emails with their users
for i := range allEmails {
email := allEmails[i]
if user, ok := userMap[email.ScimUserID]; ok {
user.Emails = append(user.Emails, email)
}
}
return users, totalResults, nil
}
// getScimUserEmails retrieves all emails for a SCIM user
func (ds *Datastore) getScimUserEmails(ctx context.Context, userID uint) ([]fleet.ScimUserEmail, error) {
const query = `
SELECT
scim_user_id, email, ` + "`primary`" + `, type
FROM scim_user_emails
WHERE scim_user_id = ? ORDER BY email ASC
`
var emails []fleet.ScimUserEmail
err := sqlx.SelectContext(ctx, ds.reader(ctx), &emails, query, userID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, ctxerr.Wrap(ctx, err, "select scim user emails")
}
return emails, nil
}

View file

@ -0,0 +1,493 @@
package mysql
import (
"context"
"testing"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestScim(t *testing.T) {
ds := CreateMySQLDS(t)
cases := []struct {
name string
fn func(t *testing.T, ds *Datastore)
}{
{"ScimUserCreate", testScimUserCreate},
{"ScimUserByID", testScimUserByID},
{"ScimUserByUserName", testScimUserByUserName},
{"ReplaceScimUser", testReplaceScimUser},
{"DeleteScimUser", testDeleteScimUser},
{"ListScimUsers", testListScimUsers},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
defer TruncateTables(t, ds, "scim_users", "scim_user_emails")
c.fn(t, ds)
})
}
}
func testScimUserCreate(t *testing.T, ds *Datastore) {
usersToCreate := []fleet.ScimUser{
{
UserName: "user1",
ExternalID: nil,
GivenName: nil,
FamilyName: nil,
Active: nil,
Emails: []fleet.ScimUserEmail{},
},
{
UserName: "user2",
ExternalID: ptr.String("ext-123"),
GivenName: ptr.String("John"),
FamilyName: ptr.String("Doe"),
Active: ptr.Bool(true),
Emails: []fleet.ScimUserEmail{
{
Email: "john.doe@example.com",
Primary: ptr.Bool(true),
Type: ptr.String("work"),
},
},
},
{
UserName: "user3",
ExternalID: ptr.String("ext-456"),
GivenName: ptr.String("Jane"),
FamilyName: ptr.String("Smith"),
Active: ptr.Bool(true),
Emails: []fleet.ScimUserEmail{
{
Email: "jane.personal@example.com",
Primary: ptr.Bool(false),
Type: ptr.String("home"),
},
{
Email: "jane.smith@example.com",
Primary: ptr.Bool(true),
Type: ptr.String("work"),
},
},
},
}
for _, u := range usersToCreate {
var err error
userCopy := u
userCopy.ID, err = ds.CreateScimUser(context.Background(), &u)
assert.Nil(t, err)
verify, err := ds.ScimUserByUserName(context.Background(), u.UserName)
assert.Nil(t, err)
assert.Equal(t, userCopy.ID, verify.ID)
assert.Equal(t, userCopy.UserName, verify.UserName)
assert.Equal(t, userCopy.ExternalID, verify.ExternalID)
assert.Equal(t, userCopy.GivenName, verify.GivenName)
assert.Equal(t, userCopy.FamilyName, verify.FamilyName)
assert.Equal(t, userCopy.Active, verify.Active)
// Verify emails
assert.Equal(t, len(userCopy.Emails), len(verify.Emails))
for i, email := range userCopy.Emails {
assert.Equal(t, email.Email, verify.Emails[i].Email)
assert.Equal(t, email.Primary, verify.Emails[i].Primary)
assert.Equal(t, email.Type, verify.Emails[i].Type)
assert.Equal(t, u.ID, verify.Emails[i].ScimUserID)
}
}
}
func testScimUserByID(t *testing.T, ds *Datastore) {
users := createTestScimUsers(t, ds)
for _, tt := range users {
returned, err := ds.ScimUserByID(context.Background(), tt.ID)
assert.Nil(t, err)
assert.Equal(t, tt.ID, returned.ID)
assert.Equal(t, tt.UserName, returned.UserName)
assert.Equal(t, tt.ExternalID, returned.ExternalID)
assert.Equal(t, tt.GivenName, returned.GivenName)
assert.Equal(t, tt.FamilyName, returned.FamilyName)
assert.Equal(t, tt.Active, returned.Active)
// Verify emails
assert.Equal(t, len(tt.Emails), len(returned.Emails))
for i, email := range tt.Emails {
assert.Equal(t, email.Email, returned.Emails[i].Email)
assert.Equal(t, email.Primary, returned.Emails[i].Primary)
assert.Equal(t, email.Type, returned.Emails[i].Type)
assert.Equal(t, tt.ID, returned.Emails[i].ScimUserID)
}
}
// test missing user
_, err := ds.ScimUserByID(context.Background(), 10000000000)
assert.True(t, fleet.IsNotFound(err))
}
func testScimUserByUserName(t *testing.T, ds *Datastore) {
users := createTestScimUsers(t, ds)
for _, tt := range users {
returned, err := ds.ScimUserByUserName(context.Background(), tt.UserName)
assert.Nil(t, err)
assert.Equal(t, tt.ID, returned.ID)
assert.Equal(t, tt.UserName, returned.UserName)
assert.Equal(t, tt.ExternalID, returned.ExternalID)
assert.Equal(t, tt.GivenName, returned.GivenName)
assert.Equal(t, tt.FamilyName, returned.FamilyName)
assert.Equal(t, tt.Active, returned.Active)
// Verify emails
assert.Equal(t, len(tt.Emails), len(returned.Emails))
for i, email := range tt.Emails {
assert.Equal(t, email.Email, returned.Emails[i].Email)
assert.Equal(t, email.Primary, returned.Emails[i].Primary)
assert.Equal(t, email.Type, returned.Emails[i].Type)
assert.Equal(t, tt.ID, returned.Emails[i].ScimUserID)
}
}
// test missing user
_, err := ds.ScimUserByUserName(context.Background(), "nonexistent-user")
assert.NotNil(t, err)
}
func createTestScimUsers(t *testing.T, ds *Datastore) []*fleet.ScimUser {
createUsers := []fleet.ScimUser{
{
UserName: "test-user1",
ExternalID: ptr.String("ext-test-123"),
GivenName: ptr.String("Test"),
FamilyName: ptr.String("User"),
Active: ptr.Bool(true),
Emails: []fleet.ScimUserEmail{
{
Email: "test.user@example.com",
Primary: ptr.Bool(true),
Type: ptr.String("work"),
},
},
},
{
UserName: "test-user2",
ExternalID: ptr.String("ext-test-456"),
GivenName: ptr.String("Another"),
FamilyName: ptr.String("User"),
Active: ptr.Bool(true),
Emails: []fleet.ScimUserEmail{
{
Email: "another.personal@example.com",
Primary: ptr.Bool(false),
Type: ptr.String("home"),
},
{
Email: "another.user@example.com",
Primary: ptr.Bool(true),
Type: ptr.String("work"),
},
},
},
}
var users []*fleet.ScimUser
for _, u := range createUsers {
var err error
u.ID, err = ds.CreateScimUser(context.Background(), &u)
require.Nil(t, err)
users = append(users, &u)
}
return users
}
func testReplaceScimUser(t *testing.T, ds *Datastore) {
// Create a test user
user := fleet.ScimUser{
UserName: "replace-test-user",
ExternalID: ptr.String("ext-replace-123"),
GivenName: ptr.String("Original"),
FamilyName: ptr.String("User"),
Active: ptr.Bool(true),
Emails: []fleet.ScimUserEmail{
{
Email: "original.user@example.com",
Primary: ptr.Bool(true),
Type: ptr.String("work"),
},
},
}
var err error
user.ID, err = ds.CreateScimUser(context.Background(), &user)
require.Nil(t, err)
// Verify the user was created correctly
createdUser, err := ds.ScimUserByID(context.Background(), user.ID)
require.Nil(t, err)
assert.Equal(t, user.UserName, createdUser.UserName)
assert.Equal(t, user.ExternalID, createdUser.ExternalID)
assert.Equal(t, user.GivenName, createdUser.GivenName)
assert.Equal(t, user.FamilyName, createdUser.FamilyName)
assert.Equal(t, user.Active, createdUser.Active)
assert.Equal(t, 1, len(createdUser.Emails))
assert.Equal(t, "original.user@example.com", createdUser.Emails[0].Email)
// Modify the user
updatedUser := fleet.ScimUser{
ID: user.ID,
UserName: "replace-test-user", // Same username
ExternalID: ptr.String("ext-replace-456"), // Changed external ID
GivenName: ptr.String("Updated"), // Changed given name
FamilyName: ptr.String("User"), // Same family name
Active: ptr.Bool(false), // Changed active status
Emails: []fleet.ScimUserEmail{ // Changed emails
{
Email: "updated.user@example.com",
Primary: ptr.Bool(true),
Type: ptr.String("work"),
},
{
Email: "personal.user@example.com",
Primary: ptr.Bool(false),
Type: ptr.String("home"),
},
},
}
// Replace the user
err = ds.ReplaceScimUser(context.Background(), &updatedUser)
require.Nil(t, err)
// Verify the user was updated correctly
replacedUser, err := ds.ScimUserByID(context.Background(), user.ID)
require.Nil(t, err)
assert.Equal(t, updatedUser.UserName, replacedUser.UserName)
assert.Equal(t, updatedUser.ExternalID, replacedUser.ExternalID)
assert.Equal(t, updatedUser.GivenName, replacedUser.GivenName)
assert.Equal(t, updatedUser.FamilyName, replacedUser.FamilyName)
assert.Equal(t, updatedUser.Active, replacedUser.Active)
// Verify emails were replaced
assert.Equal(t, 2, len(replacedUser.Emails))
assert.Equal(t, "personal.user@example.com", replacedUser.Emails[0].Email) // Alphabetical order
assert.Equal(t, "updated.user@example.com", replacedUser.Emails[1].Email)
// Test replacing a non-existent user
nonExistentUser := fleet.ScimUser{
ID: 99999, // Non-existent ID
UserName: "non-existent",
ExternalID: ptr.String("ext-non-existent"),
GivenName: ptr.String("Non"),
FamilyName: ptr.String("Existent"),
Active: ptr.Bool(true),
}
err = ds.ReplaceScimUser(context.Background(), &nonExistentUser)
assert.True(t, fleet.IsNotFound(err))
}
func testDeleteScimUser(t *testing.T, ds *Datastore) {
// Create a test user
user := fleet.ScimUser{
UserName: "delete-test-user",
ExternalID: ptr.String("ext-delete-123"),
GivenName: ptr.String("Delete"),
FamilyName: ptr.String("User"),
Active: ptr.Bool(true),
Emails: []fleet.ScimUserEmail{
{
Email: "delete.user@example.com",
Primary: ptr.Bool(true),
Type: ptr.String("work"),
},
},
}
var err error
user.ID, err = ds.CreateScimUser(context.Background(), &user)
require.Nil(t, err)
// Verify the user was created correctly
createdUser, err := ds.ScimUserByID(context.Background(), user.ID)
require.Nil(t, err)
assert.Equal(t, user.UserName, createdUser.UserName)
// Delete the user
err = ds.DeleteScimUser(context.Background(), user.ID)
require.Nil(t, err)
// Verify the user was deleted
_, err = ds.ScimUserByID(context.Background(), user.ID)
assert.True(t, fleet.IsNotFound(err))
// Test deleting a non-existent user
err = ds.DeleteScimUser(context.Background(), 99999) // Non-existent ID
assert.True(t, fleet.IsNotFound(err))
}
func testListScimUsers(t *testing.T, ds *Datastore) {
// Create test users with different attributes and emails
users := []fleet.ScimUser{
{
UserName: "list-test-user1",
ExternalID: ptr.String("ext-list-123"),
GivenName: ptr.String("List"),
FamilyName: ptr.String("User1"),
Active: ptr.Bool(true),
Emails: []fleet.ScimUserEmail{
{
Email: "list.user1@example.com",
Primary: ptr.Bool(true),
Type: ptr.String("work"),
},
},
},
{
UserName: "list-test-user2",
ExternalID: ptr.String("ext-list-456"),
GivenName: ptr.String("List"),
FamilyName: ptr.String("User2"),
Active: ptr.Bool(true),
Emails: []fleet.ScimUserEmail{
{
Email: "list.user2@example.com",
Primary: ptr.Bool(true),
Type: ptr.String("work"),
},
{
Email: "personal.user2@example.com",
Primary: ptr.Bool(false),
Type: ptr.String("home"),
},
},
},
{
UserName: "different-user3",
ExternalID: ptr.String("ext-list-789"),
GivenName: ptr.String("Different"),
FamilyName: ptr.String("User3"),
Active: ptr.Bool(false),
Emails: []fleet.ScimUserEmail{
{
Email: "different.user3@example.com",
Primary: ptr.Bool(true),
Type: ptr.String("work"),
},
},
},
}
// Create the users
for i := range users {
var err error
users[i].ID, err = ds.CreateScimUser(context.Background(), &users[i])
require.Nil(t, err)
}
// Test 1: List all users without filters
allUsers, totalResults, err := ds.ListScimUsers(context.Background(), fleet.ScimUsersListOptions{
Page: 1,
PerPage: 10,
})
require.Nil(t, err)
assert.Equal(t, 3, len(allUsers))
assert.Equal(t, uint(3), totalResults)
// Verify that our test users are in the results
foundUsers := 0
for _, u := range allUsers {
for _, testUser := range users {
if u.ID == testUser.ID {
foundUsers++
break
}
}
}
assert.Equal(t, 3, foundUsers)
// Test 2: Pagination - first page with 2 items
page1Users, totalPage1, err := ds.ListScimUsers(context.Background(), fleet.ScimUsersListOptions{
Page: 1,
PerPage: 2,
})
require.Nil(t, err)
assert.Equal(t, 2, len(page1Users))
assert.Equal(t, uint(3), totalPage1) // Total should still be 3
// Test 3: Pagination - second page with 2 items
page2Users, totalPage2, err := ds.ListScimUsers(context.Background(), fleet.ScimUsersListOptions{
Page: 2,
PerPage: 2,
})
require.Nil(t, err)
assert.Equal(t, 1, len(page2Users))
assert.Equal(t, uint(3), totalPage2) // Total should still be 3
// Verify that page1 and page2 contain different users
for _, p1User := range page1Users {
for _, p2User := range page2Users {
assert.NotEqual(t, p1User.ID, p2User.ID, "Users should not appear on multiple pages")
}
}
// Test 4: Filter by username
listUsers, totalListUsers, err := ds.ListScimUsers(context.Background(), fleet.ScimUsersListOptions{
Page: 1,
PerPage: 10,
UserNameFilter: ptr.String("list-test-user2"),
})
require.Nil(t, err)
require.Len(t, listUsers, 1)
assert.Equal(t, uint(1), totalListUsers)
assert.Equal(t, "list-test-user2", listUsers[0].UserName)
// Test 5: Filter by email type and value
homeEmailUsers, totalHomeEmailUsers, err := ds.ListScimUsers(context.Background(), fleet.ScimUsersListOptions{
Page: 1,
PerPage: 10,
EmailTypeFilter: ptr.String("home"),
EmailValueFilter: ptr.String("personal.user2@example.com"),
})
require.Nil(t, err)
require.Len(t, homeEmailUsers, 1)
assert.Equal(t, uint(1), totalHomeEmailUsers)
assert.Equal(t, users[1].ID, homeEmailUsers[0].ID)
assert.Equal(t, 2, len(homeEmailUsers[0].Emails))
// Test 6: Filter by email type and value - work emails
workEmailUsers, totalWorkEmailUsers, err := ds.ListScimUsers(context.Background(), fleet.ScimUsersListOptions{
Page: 1,
PerPage: 10,
EmailTypeFilter: ptr.String("work"),
EmailValueFilter: ptr.String("different.user3@example.com"),
})
require.Nil(t, err)
assert.Len(t, workEmailUsers, 1)
assert.Equal(t, uint(1), totalWorkEmailUsers)
// Test 7: No results for non-matching filters
noUsers, totalNoUsers1, err := ds.ListScimUsers(context.Background(), fleet.ScimUsersListOptions{
Page: 1,
PerPage: 10,
UserNameFilter: ptr.String("nonexistent"),
})
require.Nil(t, err)
assert.Empty(t, noUsers)
assert.Equal(t, uint(0), totalNoUsers1)
noUsers, totalNoUsers2, err := ds.ListScimUsers(context.Background(), fleet.ScimUsersListOptions{
Page: 1,
PerPage: 10,
EmailTypeFilter: ptr.String("nonexistent"),
EmailValueFilter: ptr.String("nonexistent"),
})
require.Nil(t, err)
assert.Empty(t, noUsers)
assert.Equal(t, uint(0), totalNoUsers2)
}

View file

@ -2012,6 +2012,22 @@ type Datastore interface {
// Android
AndroidDatastore
// /////////////////////////////////////////////////////////////////////////////
// SCIM
// CreateScimUser creates a new SCIM user in the database
CreateScimUser(ctx context.Context, user *ScimUser) (uint, error)
// ScimUserByID retrieves a SCIM user by ID
ScimUserByID(ctx context.Context, id uint) (*ScimUser, error)
// ScimUserByUserName retrieves a SCIM user by username
ScimUserByUserName(ctx context.Context, userName string) (*ScimUser, error)
// ReplaceScimUser replaces an existing SCIM user in the database
ReplaceScimUser(ctx context.Context, user *ScimUser) error
// DeleteScimUser deletes a SCIM user from the database
DeleteScimUser(ctx context.Context, id uint) error
// ListScimUsers retrieves a list of SCIM users with optional filtering
ListScimUsers(ctx context.Context, opts ScimUsersListOptions) (users []ScimUser, totalResults uint, err error)
}
type AndroidDatastore interface {

41
server/fleet/scim.go Normal file
View file

@ -0,0 +1,41 @@
package fleet
// ScimUser represents a SCIM user in the database
type ScimUser struct {
ID uint `db:"id"`
ExternalID *string `db:"external_id"`
UserName string `db:"user_name"`
GivenName *string `db:"given_name"`
FamilyName *string `db:"family_name"`
Active *bool `db:"active"`
Emails []ScimUserEmail
}
func (su *ScimUser) AuthzType() string {
return "scim_user"
}
// ScimUserEmail represents an email address associated with a SCIM user
type ScimUserEmail struct {
ScimUserID uint `db:"scim_user_id"`
Email string `db:"email"`
Primary *bool `db:"primary"`
Type *string `db:"type"`
}
type ScimUsersListOptions struct {
// Which page to return (must be positive integer)
Page uint
// How many results per page (must be positive integer)
PerPage uint
// UserNameFilter filters by userName -- max of 1 response is expected
// Cannot be used with other filters.
UserNameFilter *string
// EmailTypeFilter and EmailValueFilter are needed to support Entra ID filter: emails[type eq "work"].value eq "user@contoso.com"
// https://learn.microsoft.com/en-us/entra/identity/app-provisioning/use-scim-to-provision-users-and-groups#users
// Cannot be used with other filters.
EmailTypeFilter *string
EmailValueFilter *string
}

View file

@ -1284,6 +1284,18 @@ type SetAndroidEnabledAndConfiguredFunc func(ctx context.Context, configured boo
type UpdateAndroidHostFunc func(ctx context.Context, host *fleet.AndroidHost, fromEnroll bool) error
type CreateScimUserFunc func(ctx context.Context, user *fleet.ScimUser) (uint, error)
type ScimUserByIDFunc func(ctx context.Context, id uint) (*fleet.ScimUser, error)
type ScimUserByUserNameFunc func(ctx context.Context, userName string) (*fleet.ScimUser, error)
type ReplaceScimUserFunc func(ctx context.Context, user *fleet.ScimUser) error
type DeleteScimUserFunc func(ctx context.Context, id uint) error
type ListScimUsersFunc func(ctx context.Context, opts fleet.ScimUsersListOptions) (users []fleet.ScimUser, totalResults uint, err error)
type DataStore struct {
HealthCheckFunc HealthCheckFunc
HealthCheckFuncInvoked bool
@ -3178,6 +3190,24 @@ type DataStore struct {
UpdateAndroidHostFunc UpdateAndroidHostFunc
UpdateAndroidHostFuncInvoked bool
CreateScimUserFunc CreateScimUserFunc
CreateScimUserFuncInvoked bool
ScimUserByIDFunc ScimUserByIDFunc
ScimUserByIDFuncInvoked bool
ScimUserByUserNameFunc ScimUserByUserNameFunc
ScimUserByUserNameFuncInvoked bool
ReplaceScimUserFunc ReplaceScimUserFunc
ReplaceScimUserFuncInvoked bool
DeleteScimUserFunc DeleteScimUserFunc
DeleteScimUserFuncInvoked bool
ListScimUsersFunc ListScimUsersFunc
ListScimUsersFuncInvoked bool
mu sync.Mutex
}
@ -7597,3 +7627,45 @@ func (s *DataStore) UpdateAndroidHost(ctx context.Context, host *fleet.AndroidHo
s.mu.Unlock()
return s.UpdateAndroidHostFunc(ctx, host, fromEnroll)
}
func (s *DataStore) CreateScimUser(ctx context.Context, user *fleet.ScimUser) (uint, error) {
s.mu.Lock()
s.CreateScimUserFuncInvoked = true
s.mu.Unlock()
return s.CreateScimUserFunc(ctx, user)
}
func (s *DataStore) ScimUserByID(ctx context.Context, id uint) (*fleet.ScimUser, error) {
s.mu.Lock()
s.ScimUserByIDFuncInvoked = true
s.mu.Unlock()
return s.ScimUserByIDFunc(ctx, id)
}
func (s *DataStore) ScimUserByUserName(ctx context.Context, userName string) (*fleet.ScimUser, error) {
s.mu.Lock()
s.ScimUserByUserNameFuncInvoked = true
s.mu.Unlock()
return s.ScimUserByUserNameFunc(ctx, userName)
}
func (s *DataStore) ReplaceScimUser(ctx context.Context, user *fleet.ScimUser) error {
s.mu.Lock()
s.ReplaceScimUserFuncInvoked = true
s.mu.Unlock()
return s.ReplaceScimUserFunc(ctx, user)
}
func (s *DataStore) DeleteScimUser(ctx context.Context, id uint) error {
s.mu.Lock()
s.DeleteScimUserFuncInvoked = true
s.mu.Unlock()
return s.DeleteScimUserFunc(ctx, id)
}
func (s *DataStore) ListScimUsers(ctx context.Context, opts fleet.ScimUsersListOptions) (users []fleet.ScimUser, totalResults uint, err error) {
s.mu.Lock()
s.ListScimUsersFuncInvoked = true
s.mu.Unlock()
return s.ListScimUsersFunc(ctx, opts)
}

View file

@ -2,6 +2,7 @@ package auth
import (
"context"
"net/http"
"github.com/fleetdm/fleet/v4/server/contexts/authz"
"github.com/fleetdm/fleet/v4/server/contexts/token"
@ -39,7 +40,7 @@ func AuthenticatedUser(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpo
return next(ctx, request)
}
// if not succesful, try again this time with errors
// if not successful, try again this time with errors
sessionKey, ok := token.FromContext(ctx)
if !ok {
return nil, fleet.NewAuthHeaderRequiredError("no auth token")
@ -67,3 +68,44 @@ func AuthenticatedUser(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpo
func UnauthenticatedRequest(_ fleet.Service, next endpoint.Endpoint) endpoint.Endpoint {
return log.Logged(next)
}
// errorHandler has the same signature as http.Error
type errorHandler func(w http.ResponseWriter, detail string, status int)
func AuthenticatedUserMiddleware(svc fleet.Service, errHandler errorHandler, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// first check if already successfully set
if v, ok := viewer.FromContext(r.Context()); ok {
if v.User.IsAdminForcedPasswordReset() {
errHandler(w, fleet.ErrPasswordResetRequired.Error(), http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
return
}
// if not successful, try again this time with errors
sessionKey, ok := token.FromContext(r.Context())
if !ok {
errHandler(w, fleet.NewAuthHeaderRequiredError("no auth token").Error(), http.StatusUnauthorized)
return
}
v, err := AuthViewer(r.Context(), string(sessionKey), svc)
if err != nil {
errHandler(w, err.Error(), http.StatusUnauthorized)
return
}
if v.User.IsAdminForcedPasswordReset() {
errHandler(w, fleet.ErrPasswordResetRequired.Error(), http.StatusUnauthorized)
return
}
ctx := viewer.NewContext(r.Context(), *v)
if ac, ok := authz.FromContext(r.Context()); ok {
ac.SetAuthnMethod(authz.AuthnUserToken)
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}

View file

@ -28,3 +28,11 @@ func SetRequestsContexts(svc fleet.Service) kithttp.RequestFunc {
return ctx
}
}
func SetRequestsContextMiddleware(svc fleet.Service, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := kithttp.PopulateRequestContext(r.Context(), r)
ctx = SetRequestsContexts(svc)(ctx, r)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

View file

@ -37,3 +37,10 @@ func LogRequestEnd(logger kitlog.Logger) func(context.Context, http.ResponseWrit
return ctx
}
}
func LogResponseEndMiddleware(logger kitlog.Logger, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)
LogRequestEnd(logger)(r.Context(), w)
})
}