mirror of
https://github.com/fleetdm/fleet
synced 2026-04-21 13:37:30 +00:00
Add SCIM Users (#27551)
For #27287 Video explaining the PR: https://www.youtube.com/watch?v=ZHgFUAvrPEI This PR adds SCIM Users support for Okta. The goal is to first add Users/Groups support so that the remaining backend SCIM work can be done in parallel. This PR does not include the following, which will be added in later PRs - Changes file - Groups support for Okta - Full support for Entra ID - Integration tests # Checklist for submitter - [x] If database migrations are included, checked table schema to confirm autoupdate - For database migrations: - [x] Checked schema for all modified table for columns that will auto-update timestamps during migration. - [x] Confirmed that updating the timestamps is acceptable, and will not cause unwanted side effects. - [x] Ensured the correct collation is explicitly set for character columns (`COLLATE utf8mb4_unicode_ci`). - [x] Added/updated automated tests - [x] A detailed QA plan exists on the associated ticket (if it isn't there, work with the product group's QA engineer to add it) - [x] Manual QA for all new/changed functionality
This commit is contained in:
parent
94037e5e56
commit
2198fd8d65
17 changed files with 1923 additions and 4 deletions
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/WatchBeam/clock"
|
||||
"github.com/e-dard/netbug"
|
||||
"github.com/fleetdm/fleet/v4/ee/server/licensing"
|
||||
"github.com/fleetdm/fleet/v4/ee/server/scim"
|
||||
eeservice "github.com/fleetdm/fleet/v4/ee/server/service"
|
||||
"github.com/fleetdm/fleet/v4/ee/server/service/digicert"
|
||||
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
|
||||
|
|
@ -1171,11 +1172,14 @@ the way that the Fleet server works.
|
|||
}
|
||||
}
|
||||
|
||||
// SCEP proxy (for NDES, etc.)
|
||||
if license.IsPremium() {
|
||||
// SCEP proxy (for NDES, etc.)
|
||||
if err = service.RegisterSCEPProxy(rootMux, ds, logger, nil); err != nil {
|
||||
initFatal(err, "setup SCEP proxy")
|
||||
}
|
||||
if err = scim.RegisterSCIM(rootMux, ds, svc, logger); err != nil {
|
||||
initFatal(err, "setup SCIM")
|
||||
}
|
||||
}
|
||||
|
||||
if config.Prometheus.BasicAuth.Username != "" && config.Prometheus.BasicAuth.Password != "" {
|
||||
|
|
|
|||
102
docs/Contributing/MDM-SCIM-integration.md
Normal file
102
docs/Contributing/MDM-SCIM-integration.md
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
# SCIM (System for Cross-domain Identity Management) integration
|
||||
|
||||
## Reference docs
|
||||
|
||||
- [scim.cloud](https://scim.cloud/)
|
||||
- [SCIM: Core Schema (RFC7643)](https://datatracker.ietf.org/doc/html/rfc7643)
|
||||
- [SCIM: Protocol (RFC7644)](https://datatracker.ietf.org/doc/html/rfc7644)
|
||||
- [scim Go library](https://github.com/elimity-com/scim)
|
||||
|
||||
## Okta integration
|
||||
|
||||
- https://developer.okta.com/docs/guides/scim-provisioning-integration-prepare/main/
|
||||
|
||||
### Testing Okta integration
|
||||
|
||||
First, create at least one SCIM user:
|
||||
|
||||
```
|
||||
POST https://localhost:8080/api/latest/fleet/scim/Users
|
||||
|
||||
{
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
|
||||
"userName": "test.user@okta.local",
|
||||
"name": {
|
||||
"givenName": "Test",
|
||||
"familyName": "User"
|
||||
},
|
||||
"emails": [{
|
||||
"primary": true,
|
||||
"value": "test.user@okta.local",
|
||||
"type": "work"
|
||||
}],
|
||||
"active": true
|
||||
}
|
||||
```
|
||||
|
||||
Run test using [Runscope](https://www.runscope.com/). See [instructions](https://developer.okta.com/docs/guides/scim-provisioning-integration-prepare/main/#test-your-scim-api).
|
||||
|
||||
## Entra ID integration
|
||||
- [SCIM guide](https://learn.microsoft.com/en-us/entra/identity/app-provisioning/use-scim-to-provision-users-and-groups)
|
||||
- [SCIM validator](https://scimvalidator.microsoft.com/)
|
||||
- Only test attributes that we implemented
|
||||
|
||||
### Testing Entra ID integration
|
||||
|
||||
Use [scimvalidator.microsoft.com](https://scimvalidator.microsoft.com/). Only test the attributes that we have implemented. To see our supported attributes, check the schema:
|
||||
|
||||
```
|
||||
GET https://localhost:8080/api/latest/fleet/scim/Schemas
|
||||
```
|
||||
|
||||
## Authentication
|
||||
|
||||
We use same authentication as API. HTTP header: `Authorization: Bearer xyz`
|
||||
|
||||
## Diagrams
|
||||
|
||||
```mermaid
|
||||
---
|
||||
title: Initial DB schema (not kept up to date)
|
||||
---
|
||||
erDiagram
|
||||
HOST_SCIM_USER {
|
||||
host_id uint PK
|
||||
scim_user_id uint PK "FK"
|
||||
}
|
||||
SCIM_USERS {
|
||||
id uint PK
|
||||
external_id *string "Index"
|
||||
user_name string "Unique"
|
||||
given_name *string
|
||||
family_name *string
|
||||
active *bool
|
||||
}
|
||||
SCIM_USER_EMAILS {
|
||||
id uint PK
|
||||
scim_user_id uint FK
|
||||
type *string "Index"
|
||||
email string "Index"
|
||||
primary *bool
|
||||
}
|
||||
SCIM_USER_GROUP {
|
||||
scim_user_id string PK "FK"
|
||||
group_id uint PK "FK"
|
||||
}
|
||||
SCIM_GROUPS {
|
||||
id uint PK
|
||||
external_id *string "Index"
|
||||
display_name string "Index"
|
||||
}
|
||||
HOST_SCIM_USER }o--|| SCIM_USERS : "multiple hosts can have the same SCIM user"
|
||||
SCIM_USERS ||--o{ SCIM_USER_GROUP: "zero-to-many"
|
||||
SCIM_USER_GROUP }o--|| SCIM_GROUPS: "zero-to-many"
|
||||
SCIM_USERS ||--o{ SCIM_USER_EMAILS: "zero-to-many"
|
||||
COMMENT {
|
||||
_ _ "created_at and updated_at columns not shown"
|
||||
}
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- Okta and Entra ID do not support nested groups
|
||||
180
ee/server/scim/scim.go
Normal file
180
ee/server/scim/scim.go
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
package scim
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/elimity-com/scim"
|
||||
"github.com/elimity-com/scim/errors"
|
||||
"github.com/elimity-com/scim/optional"
|
||||
"github.com/elimity-com/scim/schema"
|
||||
"github.com/fleetdm/fleet/v4/server/authz"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/auth"
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/log"
|
||||
kitlog "github.com/go-kit/log"
|
||||
"github.com/go-kit/log/level"
|
||||
)
|
||||
|
||||
const (
|
||||
maxResults = 1000
|
||||
)
|
||||
|
||||
func RegisterSCIM(
|
||||
mux *http.ServeMux,
|
||||
ds fleet.Datastore,
|
||||
svc fleet.Service,
|
||||
logger kitlog.Logger,
|
||||
) error {
|
||||
config := scim.ServiceProviderConfig{
|
||||
// TODO: DocumentationURI and Authentication scheme
|
||||
DocumentationURI: optional.NewString("https://fleetdm.com/docs/get-started/why-fleet"),
|
||||
SupportFiltering: true,
|
||||
SupportPatch: true,
|
||||
MaxResults: maxResults,
|
||||
}
|
||||
|
||||
// The common attributes are id, externalId, and meta.
|
||||
// In practice only meta.resourceType is required, while the other four (created, lastModified, location, and version) are not strictly required.
|
||||
userSchema := schema.Schema{
|
||||
ID: "urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
Name: optional.NewString("User"),
|
||||
Description: optional.NewString("SCIM User"),
|
||||
Attributes: []schema.CoreAttribute{
|
||||
schema.SimpleCoreAttribute(schema.SimpleStringParams(schema.StringParams{
|
||||
Name: "userName",
|
||||
Required: true,
|
||||
Uniqueness: schema.AttributeUniquenessServer(),
|
||||
})),
|
||||
schema.ComplexCoreAttribute(schema.ComplexParams{
|
||||
Description: optional.NewString("The components of the user's real name. Providers MAY return just the full name as a single string in the formatted sub-attribute, or they MAY return just the individual component attributes using the other sub-attributes, or they MAY return both. If both variants are returned, they SHOULD be describing the same name, with the formatted name indicating how the component attributes should be combined."),
|
||||
Name: "name",
|
||||
SubAttributes: []schema.SimpleParams{
|
||||
schema.SimpleStringParams(schema.StringParams{
|
||||
Description: optional.NewString("The family name of the User, or last name in most Western languages (e.g., 'Jensen' given the full name 'Ms. Barbara J Jensen, III')."),
|
||||
Name: "familyName",
|
||||
}),
|
||||
schema.SimpleStringParams(schema.StringParams{
|
||||
Description: optional.NewString("The given name of the User, or first name in most Western languages (e.g., 'Barbara' given the full name 'Ms. Barbara J Jensen, III')."),
|
||||
Name: "givenName",
|
||||
}),
|
||||
},
|
||||
}),
|
||||
schema.ComplexCoreAttribute(schema.ComplexParams{
|
||||
Description: optional.NewString("Email addresses for the user. The value SHOULD be canonicalized by the service provider, e.g., 'bjensen@example.com' instead of 'bjensen@EXAMPLE.COM'. Canonical type values of 'work', 'home', and 'other'."),
|
||||
MultiValued: true,
|
||||
Name: "emails",
|
||||
SubAttributes: []schema.SimpleParams{
|
||||
schema.SimpleStringParams(schema.StringParams{
|
||||
Description: optional.NewString("Email addresses for the user. The value SHOULD be canonicalized by the service provider, e.g., 'bjensen@example.com' instead of 'bjensen@EXAMPLE.COM'. Canonical type values of 'work', 'home', and 'other'."),
|
||||
Name: "value",
|
||||
}),
|
||||
schema.SimpleStringParams(schema.StringParams{
|
||||
CanonicalValues: []string{"work", "home", "other"},
|
||||
Description: optional.NewString("A label indicating the attribute's function, e.g., 'work' or 'home'."),
|
||||
Name: "type",
|
||||
}),
|
||||
schema.SimpleBooleanParams(schema.BooleanParams{
|
||||
Description: optional.NewString("A Boolean value indicating the 'primary' or preferred attribute value for this attribute, e.g., the preferred mailing address or primary email address. The primary attribute value 'true' MUST appear no more than once."),
|
||||
Name: "primary",
|
||||
}),
|
||||
},
|
||||
}),
|
||||
schema.SimpleCoreAttribute(schema.SimpleBooleanParams(schema.BooleanParams{
|
||||
Description: optional.NewString("A Boolean value indicating the User's administrative status."),
|
||||
Name: "active",
|
||||
})),
|
||||
},
|
||||
}
|
||||
|
||||
scimLogger := kitlog.With(logger, "component", "SCIM")
|
||||
resourceTypes := []scim.ResourceType{
|
||||
{
|
||||
ID: optional.NewString("User"),
|
||||
Name: "User",
|
||||
Endpoint: "/Users",
|
||||
Description: optional.NewString("User Account"),
|
||||
Schema: userSchema,
|
||||
Handler: NewUserHandler(ds, scimLogger),
|
||||
},
|
||||
}
|
||||
|
||||
serverArgs := &scim.ServerArgs{
|
||||
ServiceProviderConfig: &config,
|
||||
ResourceTypes: resourceTypes,
|
||||
}
|
||||
|
||||
serverOpts := []scim.ServerOption{
|
||||
scim.WithLogger(&scimErrorLogger{Logger: scimLogger}),
|
||||
}
|
||||
|
||||
server, err := scim.NewServer(serverArgs, serverOpts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
scimErrorHandler := func(w http.ResponseWriter, detail string, status int) {
|
||||
errorHandler(w, scimLogger, detail, status)
|
||||
}
|
||||
authorizer, err := authz.NewAuthorizer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: Add APM/OpenTelemetry tracing and Prometheus middleware
|
||||
applyMiddleware := func(prefix string, server http.Handler) http.Handler {
|
||||
handler := http.StripPrefix(prefix, server)
|
||||
handler = AuthorizationMiddleware(authorizer, scimLogger, handler)
|
||||
handler = auth.AuthenticatedUserMiddleware(svc, scimErrorHandler, handler)
|
||||
handler = log.LogResponseEndMiddleware(scimLogger, handler)
|
||||
handler = auth.SetRequestsContextMiddleware(svc, handler)
|
||||
return handler
|
||||
}
|
||||
|
||||
mux.Handle("/api/v1/fleet/scim/", applyMiddleware("/api/v1/fleet/scim", server))
|
||||
mux.Handle("/api/latest/fleet/scim/", applyMiddleware("/api/latest/fleet/scim", server))
|
||||
return nil
|
||||
}
|
||||
|
||||
func AuthorizationMiddleware(authorizer *authz.Authorizer, logger kitlog.Logger, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
err := authorizer.Authorize(r.Context(), &fleet.ScimUser{}, fleet.ActionWrite)
|
||||
if err != nil {
|
||||
errorHandler(w, logger, err.Error(), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func errorHandler(w http.ResponseWriter, logger kitlog.Logger, detail string, status int) {
|
||||
scimErr := errors.ScimError{
|
||||
Status: status,
|
||||
Detail: detail,
|
||||
}
|
||||
raw, err := json.Marshal(scimErr)
|
||||
if err != nil {
|
||||
level.Error(logger).Log("msg", "failed marshaling scim error", "scimError", scimErr, "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/scim+json")
|
||||
w.WriteHeader(scimErr.Status)
|
||||
_, err = w.Write(raw)
|
||||
if err != nil {
|
||||
level.Error(logger).Log("msg", "failed writing response", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
type scimErrorLogger struct {
|
||||
kitlog.Logger
|
||||
}
|
||||
|
||||
var _ scim.Logger = &scimErrorLogger{}
|
||||
|
||||
func (l *scimErrorLogger) Error(args ...interface{}) {
|
||||
level.Error(l.Logger).Log(
|
||||
"error", fmt.Sprint(args...),
|
||||
)
|
||||
}
|
||||
455
ee/server/scim/users.go
Normal file
455
ee/server/scim/users.go
Normal file
|
|
@ -0,0 +1,455 @@
|
|||
package scim
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/elimity-com/scim"
|
||||
"github.com/elimity-com/scim/errors"
|
||||
"github.com/elimity-com/scim/optional"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
kitlog "github.com/go-kit/log"
|
||||
"github.com/go-kit/log/level"
|
||||
"github.com/scim2/filter-parser/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
// Common attributes: https://datatracker.ietf.org/doc/html/rfc7643#section-3.1
|
||||
externalIdAttr = "externalId"
|
||||
|
||||
// User attributes: https://datatracker.ietf.org/doc/html/rfc7643#section-4.1
|
||||
userNameAttr = "userName"
|
||||
nameAttr = "name"
|
||||
givenNameAttr = "givenName"
|
||||
familyNameAttr = "familyName"
|
||||
activeAttr = "active"
|
||||
emailsAttr = "emails"
|
||||
)
|
||||
|
||||
type UserHandler struct {
|
||||
ds fleet.Datastore
|
||||
logger kitlog.Logger
|
||||
}
|
||||
|
||||
// Compile-time check
|
||||
var _ scim.ResourceHandler = &UserHandler{}
|
||||
|
||||
func NewUserHandler(ds fleet.Datastore, logger kitlog.Logger) scim.ResourceHandler {
|
||||
return &UserHandler{ds: ds, logger: logger}
|
||||
}
|
||||
|
||||
func (u *UserHandler) Create(r *http.Request, attributes scim.ResourceAttributes) (scim.Resource, error) {
|
||||
// Check for userName uniqueness
|
||||
userName, err := getRequiredResource[string](attributes, userNameAttr)
|
||||
if err != nil {
|
||||
level.Error(u.logger).Log("msg", "failed to get userName", "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
_, err = u.ds.ScimUserByUserName(r.Context(), userName)
|
||||
if !fleet.IsNotFound(err) {
|
||||
level.Info(u.logger).Log("msg", "user already exists", userNameAttr, userName)
|
||||
return scim.Resource{}, errors.ScimErrorUniqueness
|
||||
}
|
||||
|
||||
user, err := createUserFromAttributes(attributes)
|
||||
if err != nil {
|
||||
level.Error(u.logger).Log("msg", "failed to create user from attributes", userNameAttr, userName, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
user.ID, err = u.ds.CreateScimUser(r.Context(), user)
|
||||
if err != nil {
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
|
||||
return createUserResource(user), nil
|
||||
}
|
||||
|
||||
func createUserFromAttributes(attributes scim.ResourceAttributes) (*fleet.ScimUser, error) {
|
||||
user := fleet.ScimUser{}
|
||||
var err error
|
||||
user.UserName, err = getRequiredResource[string](attributes, userNameAttr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user.ExternalID, err = getOptionalResource[string](attributes, externalIdAttr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user.Active, err = getOptionalResource[bool](attributes, activeAttr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
name, err := getComplexResource(attributes, nameAttr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user.FamilyName, err = getOptionalResource[string](name, familyNameAttr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user.GivenName, err = getOptionalResource[string](name, givenNameAttr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
emails, err := getComplexResourceSlice(attributes, emailsAttr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userEmails := make([]fleet.ScimUserEmail, 0, len(emails))
|
||||
for _, email := range emails {
|
||||
userEmail := fleet.ScimUserEmail{}
|
||||
userEmail.Email, err = getRequiredResource[string](email, "value")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Service providers SHOULD canonicalize the value according to [RFC5321]
|
||||
// https://datatracker.ietf.org/doc/html/rfc7643#section-4.1.2
|
||||
userEmail.Email, err = normalizeEmail(userEmail.Email)
|
||||
if err != nil {
|
||||
return nil, errors.ScimErrorBadParams([]string{"value"})
|
||||
}
|
||||
userEmail.Type, err = getOptionalResource[string](email, "type")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userEmail.Primary, err = getOptionalResource[bool](email, "primary")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userEmails = append(userEmails, userEmail)
|
||||
}
|
||||
user.Emails = userEmails
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func getRequiredResource[T string | bool](attributes scim.ResourceAttributes, key string) (T, error) {
|
||||
var val T
|
||||
valIntf, ok := attributes[key]
|
||||
if !ok || valIntf == nil {
|
||||
return val, errors.ScimErrorBadParams([]string{key})
|
||||
}
|
||||
val, ok = valIntf.(T)
|
||||
if !ok {
|
||||
return val, errors.ScimErrorBadParams([]string{key})
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func getOptionalResource[T string | bool](attributes scim.ResourceAttributes, key string) (*T, error) {
|
||||
var valPtr *T
|
||||
valIntf, ok := attributes[key]
|
||||
if ok && valIntf != nil {
|
||||
val, ok := valIntf.(T)
|
||||
if !ok {
|
||||
return nil, errors.ScimErrorBadParams([]string{key})
|
||||
}
|
||||
valPtr = &val
|
||||
}
|
||||
return valPtr, nil
|
||||
}
|
||||
|
||||
func getComplexResource(attributes scim.ResourceAttributes, key string) (map[string]interface{}, error) {
|
||||
valIntf, ok := attributes[key]
|
||||
if ok && valIntf != nil {
|
||||
val, ok := valIntf.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, errors.ScimErrorBadParams([]string{key})
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func getComplexResourceSlice(attributes scim.ResourceAttributes, key string) ([]map[string]interface{}, error) {
|
||||
valIntf, ok := attributes[key]
|
||||
if ok && valIntf != nil {
|
||||
valSliceIntf, ok := valIntf.([]interface{})
|
||||
if !ok {
|
||||
return nil, errors.ScimErrorBadParams([]string{key})
|
||||
}
|
||||
val := make([]map[string]interface{}, 0, len(valSliceIntf))
|
||||
for _, v := range valSliceIntf {
|
||||
valMap, ok := v.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, errors.ScimErrorBadParams([]string{key})
|
||||
}
|
||||
if len(valMap) > 0 {
|
||||
val = append(val, valMap)
|
||||
}
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) Get(r *http.Request, id string) (scim.Resource, error) {
|
||||
idUint, err := strconv.ParseUint(id, 10, 64)
|
||||
if err != nil {
|
||||
level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
}
|
||||
|
||||
user, err := u.ds.ScimUserByID(r.Context(), uint(idUint))
|
||||
switch {
|
||||
case fleet.IsNotFound(err):
|
||||
level.Info(u.logger).Log("msg", "failed to find user", "id", id)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
case err != nil:
|
||||
level.Error(u.logger).Log("msg", "failed to get user", "id", id, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
|
||||
return createUserResource(user), nil
|
||||
}
|
||||
|
||||
func createUserResource(user *fleet.ScimUser) scim.Resource {
|
||||
userResource := scim.Resource{}
|
||||
userResource.ID = fmt.Sprintf("%d", user.ID)
|
||||
if user.ExternalID != nil {
|
||||
userResource.ExternalID = optional.NewString(*user.ExternalID)
|
||||
}
|
||||
userResource.Attributes = scim.ResourceAttributes{}
|
||||
userResource.Attributes[userNameAttr] = user.UserName
|
||||
if user.Active != nil {
|
||||
userResource.Attributes[activeAttr] = *user.Active
|
||||
}
|
||||
if user.FamilyName != nil || user.GivenName != nil {
|
||||
userResource.Attributes[nameAttr] = make(scim.ResourceAttributes)
|
||||
if user.FamilyName != nil {
|
||||
userResource.Attributes[nameAttr].(scim.ResourceAttributes)[familyNameAttr] = *user.FamilyName
|
||||
}
|
||||
if user.GivenName != nil {
|
||||
userResource.Attributes[nameAttr].(scim.ResourceAttributes)[givenNameAttr] = *user.GivenName
|
||||
}
|
||||
}
|
||||
if len(user.Emails) > 0 {
|
||||
emails := make([]scim.ResourceAttributes, 0, len(user.Emails))
|
||||
for _, email := range user.Emails {
|
||||
emailResource := make(scim.ResourceAttributes)
|
||||
emailResource["value"] = email.Email
|
||||
if email.Type != nil {
|
||||
emailResource["type"] = *email.Type
|
||||
}
|
||||
if email.Primary != nil {
|
||||
emailResource["primary"] = *email.Primary
|
||||
}
|
||||
emails = append(emails, emailResource)
|
||||
}
|
||||
userResource.Attributes[emailsAttr] = emails
|
||||
}
|
||||
return userResource
|
||||
}
|
||||
|
||||
// GetAll
|
||||
// Per RFC7644 3.4.2, SHOULD ignore any query parameters they do not recognize instead of rejecting the query for versioning compatibility reasons
|
||||
// https://datatracker.ietf.org/doc/html/rfc7644#section-3.4.2
|
||||
//
|
||||
// Providers MUST decline to filter results if the specified filter operation is not recognized and return an HTTP 400 error with a
|
||||
// "scimType" error of "invalidFilter" and an appropriate human-readable response as per Section 3.12. For example, if a client specified an
|
||||
// unsupported operator named 'regex', the service provider should specify an error response description identifying the client error,
|
||||
// e.g., 'The operator 'regex' is not supported.'
|
||||
//
|
||||
// If a SCIM service provider determines that too many results would be returned the server base URI, the server SHALL reject the request by
|
||||
// returning an HTTP response with HTTP status code 400 (Bad Request) and JSON attribute "scimType" set to "tooMany" (see Table 9).
|
||||
//
|
||||
// totalResults: The total number of results returned by the list or query operation. The value may be larger than the number of
|
||||
// resources returned, such as when returning a single page (see Section 3.4.2.4) of results where multiple pages are available.
|
||||
func (u *UserHandler) GetAll(r *http.Request, params scim.ListRequestParams) (scim.Page, error) {
|
||||
page := params.StartIndex
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
count := params.Count
|
||||
if count > maxResults {
|
||||
return scim.Page{}, errors.ScimErrorTooMany
|
||||
}
|
||||
if count < 1 {
|
||||
count = maxResults
|
||||
}
|
||||
|
||||
opts := fleet.ScimUsersListOptions{
|
||||
Page: uint(page), // nolint:gosec // ignore G115
|
||||
PerPage: uint(count), // nolint:gosec // ignore G115
|
||||
}
|
||||
resourceFilter := r.URL.Query().Get("filter")
|
||||
if resourceFilter != "" {
|
||||
expr, err := filter.ParseAttrExp([]byte(resourceFilter))
|
||||
if err != nil {
|
||||
level.Error(u.logger).Log("msg", "failed to parse filter", "filter", resourceFilter, "err", err)
|
||||
return scim.Page{}, errors.ScimErrorInvalidFilter
|
||||
}
|
||||
if !strings.EqualFold(expr.AttributePath.String(), "userName") || expr.Operator != "eq" {
|
||||
level.Info(u.logger).Log("msg", "unsupported filter", "filter", resourceFilter)
|
||||
return scim.Page{}, nil
|
||||
}
|
||||
userName, ok := expr.CompareValue.(string)
|
||||
if !ok {
|
||||
level.Error(u.logger).Log("msg", "unsupported value", "value", expr.CompareValue)
|
||||
return scim.Page{}, nil
|
||||
}
|
||||
|
||||
// Decode URL-encoded characters in userName, which is required to pass Microsoft Entra ID SCIM Validator
|
||||
userName, err = url.QueryUnescape(userName)
|
||||
if err != nil {
|
||||
level.Error(u.logger).Log("msg", "failed to decode userName", "userName", userName, "err", err)
|
||||
return scim.Page{}, nil
|
||||
}
|
||||
opts.UserNameFilter = &userName
|
||||
}
|
||||
users, totalResults, err := u.ds.ListScimUsers(r.Context(), opts)
|
||||
if err != nil {
|
||||
level.Error(u.logger).Log("msg", "failed to list users", "err", err)
|
||||
return scim.Page{}, err
|
||||
}
|
||||
|
||||
result := scim.Page{
|
||||
TotalResults: int(totalResults), // nolint:gosec // ignore G115
|
||||
Resources: make([]scim.Resource, 0, len(users)),
|
||||
}
|
||||
for _, user := range users {
|
||||
result.Resources = append(result.Resources, createUserResource(&user))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (u *UserHandler) Replace(r *http.Request, id string, attributes scim.ResourceAttributes) (scim.Resource, error) {
|
||||
idUint, err := strconv.ParseUint(id, 10, 64)
|
||||
if err != nil {
|
||||
level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
}
|
||||
|
||||
user, err := createUserFromAttributes(attributes)
|
||||
if err != nil {
|
||||
level.Error(u.logger).Log("msg", "failed to create user from attributes", "id", id, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
user.ID = uint(idUint)
|
||||
err = u.ds.ReplaceScimUser(r.Context(), user)
|
||||
switch {
|
||||
case fleet.IsNotFound(err):
|
||||
level.Info(u.logger).Log("msg", "failed to find user to replace", "id", id)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
case err != nil:
|
||||
level.Error(u.logger).Log("msg", "failed to replace user", "id", id, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
|
||||
return createUserResource(user), nil
|
||||
}
|
||||
|
||||
// Delete
|
||||
// https://datatracker.ietf.org/doc/html/rfc7644#section-3.6
|
||||
// MUST return a 404 (Not Found) error code for all operations associated with the previously deleted resource
|
||||
func (u *UserHandler) Delete(r *http.Request, id string) error {
|
||||
idUint, err := strconv.ParseUint(id, 10, 64)
|
||||
if err != nil {
|
||||
level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
|
||||
return errors.ScimErrorResourceNotFound(id)
|
||||
}
|
||||
err = u.ds.DeleteScimUser(r.Context(), uint(idUint))
|
||||
switch {
|
||||
case fleet.IsNotFound(err):
|
||||
level.Info(u.logger).Log("msg", "failed to find user to delete", "id", id)
|
||||
return errors.ScimErrorResourceNotFound(id)
|
||||
case err != nil:
|
||||
level.Error(u.logger).Log("msg", "failed to delete user", "id", id, "err", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Patch
|
||||
// Okta only requires patching the "active" attribute:
|
||||
// https://developer.okta.com/docs/api/openapi/okta-scim/guides/scim-20/#update-a-specific-user-patch
|
||||
func (u *UserHandler) Patch(r *http.Request, id string, operations []scim.PatchOperation) (scim.Resource, error) {
|
||||
idUint, err := strconv.ParseUint(id, 10, 64)
|
||||
if err != nil {
|
||||
level.Info(u.logger).Log("msg", "failed to parse id", "id", id, "err", err)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
}
|
||||
user, err := u.ds.ScimUserByID(r.Context(), uint(idUint))
|
||||
switch {
|
||||
case fleet.IsNotFound(err):
|
||||
level.Info(u.logger).Log("msg", "failed to find user to patch", "id", id)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
case err != nil:
|
||||
level.Error(u.logger).Log("msg", "failed to get user to patch", "id", id, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
|
||||
for _, op := range operations {
|
||||
if op.Op != "replace" {
|
||||
level.Info(u.logger).Log("msg", "unsupported patch operation", "op", op.Op)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
switch {
|
||||
case op.Path == nil:
|
||||
newValues, ok := op.Value.(map[string]interface{})
|
||||
if !ok {
|
||||
level.Info(u.logger).Log("msg", "unsupported patch value", "value", op.Value)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
if len(newValues) != 1 {
|
||||
level.Info(u.logger).Log("msg", "too many patch values", "value", op.Value)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
active, err := getRequiredResource[bool](newValues, activeAttr)
|
||||
if err != nil {
|
||||
level.Info(u.logger).Log("msg", "failed to get active value", "value", op.Value)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
user.Active = &active
|
||||
case op.Path.String() == activeAttr:
|
||||
active, ok := op.Value.(bool)
|
||||
if !ok {
|
||||
level.Error(u.logger).Log("msg", "unsupported 'active' patch value", "value", op.Value)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
user.Active = &active
|
||||
default:
|
||||
level.Info(u.logger).Log("msg", "unsupported patch path", "path", op.Path)
|
||||
return scim.Resource{}, errors.ScimErrorBadParams([]string{fmt.Sprintf("%v", op)})
|
||||
}
|
||||
}
|
||||
|
||||
err = u.ds.ReplaceScimUser(r.Context(), user)
|
||||
switch {
|
||||
case fleet.IsNotFound(err):
|
||||
level.Info(u.logger).Log("msg", "failed to find user to patch", "id", id)
|
||||
return scim.Resource{}, errors.ScimErrorResourceNotFound(id)
|
||||
case err != nil:
|
||||
level.Error(u.logger).Log("msg", "failed to patch user", "id", id, "err", err)
|
||||
return scim.Resource{}, err
|
||||
}
|
||||
|
||||
return createUserResource(user), nil
|
||||
}
|
||||
|
||||
// normalizeEmail
|
||||
// The local-part of a mailbox MUST BE treated as case sensitive.
|
||||
// Mailbox domains follow normal DNS rules and are hence not case sensitive.
|
||||
// https://datatracker.ietf.org/doc/html/rfc5321#section-2.4
|
||||
func normalizeEmail(email string) (string, error) {
|
||||
email = removeWhitespace(email)
|
||||
emailParts := strings.SplitN(email, "@", 2)
|
||||
if len(emailParts) != 2 {
|
||||
return "", fmt.Errorf("invalid email %s", email)
|
||||
}
|
||||
emailParts[1] = strings.ToLower(emailParts[1])
|
||||
return strings.Join(emailParts, "@"), nil
|
||||
}
|
||||
|
||||
func removeWhitespace(str string) string {
|
||||
return strings.Map(func(r rune) rune {
|
||||
if unicode.IsSpace(r) {
|
||||
return -1
|
||||
}
|
||||
return r
|
||||
}, str)
|
||||
}
|
||||
4
go.mod
4
go.mod
|
|
@ -181,12 +181,15 @@ require (
|
|||
github.com/cyphar/filepath-securejoin v0.2.5 // indirect
|
||||
github.com/dgraph-io/ristretto v0.1.0 // indirect
|
||||
github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 // indirect
|
||||
github.com/di-wu/parser v0.2.2 // indirect
|
||||
github.com/di-wu/xsd-datetime v1.0.0 // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/docker/go-connections v0.4.0 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/edsrzf/mmap-go v1.1.0 // indirect
|
||||
github.com/elastic/go-sysinfo v1.11.2 // indirect
|
||||
github.com/elastic/go-windows v1.0.1 // indirect
|
||||
github.com/elimity-com/scim v0.0.0-20240320110924-172bf2aee9c8 // indirect
|
||||
github.com/emirpasic/gods v1.18.1 // indirect
|
||||
github.com/fatih/structs v1.1.0 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
|
|
@ -249,6 +252,7 @@ require (
|
|||
github.com/prometheus/procfs v0.12.0 // indirect
|
||||
github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 // indirect
|
||||
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
||||
github.com/scim2/filter-parser/v2 v2.2.0 // indirect
|
||||
github.com/secDre4mer/pkcs7 v0.0.0-20240322103146-665324a4461d // indirect
|
||||
github.com/secure-systems-lab/go-securesystemslib v0.5.0 // indirect
|
||||
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect
|
||||
|
|
|
|||
8
go.sum
8
go.sum
|
|
@ -219,6 +219,10 @@ github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13/go.mod h1:SqUrOPUn
|
|||
github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no=
|
||||
github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+UbP35JkH8yB7MYb4q/qhBarqZE6g=
|
||||
github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA=
|
||||
github.com/di-wu/parser v0.2.2 h1:I9oHJ8spBXOeL7Wps0ffkFFFiXJf/pk7NX9lcAMqRMU=
|
||||
github.com/di-wu/parser v0.2.2/go.mod h1:SLp58pW6WamdmznrVRrw2NTyn4wAvT9rrEFynKX7nYo=
|
||||
github.com/di-wu/xsd-datetime v1.0.0 h1:vZoGNkbzpBNoc+JyfVLEbutNDNydYV8XwHeV7eUJoxI=
|
||||
github.com/di-wu/xsd-datetime v1.0.0/go.mod h1:i3iEhrP3WchwseOBeIdW/zxeoleXTOzx1WyDXgdmOww=
|
||||
github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e h1:vUmf0yezR0y7jJ5pceLHthLaYf4bA5T14B6q39S4q2Q=
|
||||
github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e/go.mod h1:YTIHhz/QFSYnu/EhlF2SpU2Uk+32abacUYA5ZPljz1A=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
|
|
@ -246,6 +250,8 @@ github.com/elazarl/go-bindata-assetfs v1.0.1 h1:m0kkaHRKEu7tUIUFVwhGGGYClXvyl4RE
|
|||
github.com/elazarl/go-bindata-assetfs v1.0.1/go.mod h1:v+YaWX3bdea5J/mo8dSETolEo7R71Vk1u8bnjau5yw4=
|
||||
github.com/elazarl/goproxy v1.2.1 h1:njjgvO6cRG9rIqN2ebkqy6cQz2Njkx7Fsfv/zIZqgug=
|
||||
github.com/elazarl/goproxy v1.2.1/go.mod h1:YfEbZtqP4AetfO6d40vWchF3znWX7C7Vd6ZMfdL8z64=
|
||||
github.com/elimity-com/scim v0.0.0-20240320110924-172bf2aee9c8 h1:0+BTyxIYgiVAry/P5s8R4dYuLkhB9Nhso8ogFWNr4IQ=
|
||||
github.com/elimity-com/scim v0.0.0-20240320110924-172bf2aee9c8/go.mod h1:JkjcmqbLW+khwt2fmBPJFBhx2zGZ8XobRZ+O0VhlwWo=
|
||||
github.com/emirpasic/gods v1.12.0/go.mod h1:YfzfFFoVP/catgzJb4IKIqXjX78Ha8FMSDh3ymbK86o=
|
||||
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||
|
|
@ -745,6 +751,8 @@ github.com/saferwall/pe v1.5.5 h1:GGbzKjXDm7i+1K6riOgtgblyTdRmTbr3r11IzjovAK8=
|
|||
github.com/saferwall/pe v1.5.5/go.mod h1:mJx+PuptmNpoPFBNhWs/uDMFL/kTHVZIkg0d4OUJFbQ=
|
||||
github.com/sassoftware/relic/v8 v8.0.1 h1:uYUoaoTQMs67up8/46NgrSxSftgfY4VWBusDVg56k7I=
|
||||
github.com/sassoftware/relic/v8 v8.0.1/go.mod h1:s/MwugRcovgYcNJNOyvLfqRHDX7iArHtFtUR9kEodz8=
|
||||
github.com/scim2/filter-parser/v2 v2.2.0 h1:QGadEcsmypxg8gYChRSM2j1edLyE/2j72j+hdmI4BJM=
|
||||
github.com/scim2/filter-parser/v2 v2.2.0/go.mod h1:jWnkDToqX/Y0ugz0P5VvpVEUKcWcyHHj+X+je9ce5JA=
|
||||
github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9 h1:rc/CcqLH3lh8n+csdOuDfP+NuykE0U6AeYSJJHKDgSg=
|
||||
github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9/go.mod h1:a/83NAfUXvEuLpmxDssAXxgUgrEy12MId3Wd7OTs76s=
|
||||
github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc=
|
||||
|
|
|
|||
|
|
@ -1046,3 +1046,13 @@ allow {
|
|||
subject.global_role == admin
|
||||
action == [read, write][_]
|
||||
}
|
||||
|
||||
##
|
||||
# SCIM (System for Cross-domain Identity Management)
|
||||
##
|
||||
# Global admins and maintainers can access SCIM.
|
||||
allow {
|
||||
object.type == "scim_user"
|
||||
subject.global_role == [admin, maintainer][_]
|
||||
action == write
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,78 @@
|
|||
package tables
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func init() {
|
||||
MigrationClient.AddMigration(Up_20250331042354, Down_20250331042354)
|
||||
}
|
||||
|
||||
func Up_20250331042354(tx *sql.Tx) error {
|
||||
_, err := tx.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS scim_users (
|
||||
id int UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
|
||||
external_id VARCHAR(255) NULL,
|
||||
user_name VARCHAR(255) NOT NULL,
|
||||
given_name VARCHAR(255) NULL,
|
||||
family_name VARCHAR(255) NULL,
|
||||
active TINYINT(1) NULL,
|
||||
created_at DATETIME(6) NOT NULL DEFAULT NOW(6),
|
||||
updated_at DATETIME(6) NOT NULL DEFAULT NOW(6) ON UPDATE NOW(6),
|
||||
UNIQUE KEY idx_scim_users_user_name (user_name),
|
||||
KEY idx_scim_users_external_id (external_id)
|
||||
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COLLATE = utf8mb4_unicode_ci;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS host_scim_user (
|
||||
host_id INT UNSIGNED NOT NULL,
|
||||
scim_user_id INT UNSIGNED NOT NULL,
|
||||
created_at DATETIME(6) NOT NULL DEFAULT NOW(6),
|
||||
PRIMARY KEY (host_id, scim_user_id),
|
||||
CONSTRAINT fk_host_scim_scim_user_id FOREIGN KEY (scim_user_id) REFERENCES scim_users (id) ON DELETE CASCADE
|
||||
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COLLATE = utf8mb4_unicode_ci;
|
||||
|
||||
CREATE TABLE if NOT EXISTS scim_user_emails (
|
||||
-- Using BIGINT because we clear and repopulate the emails frequently (during user update)
|
||||
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
|
||||
scim_user_id INT UNSIGNED NOT NULL,
|
||||
email VARCHAR(255) NOT NULL,
|
||||
` + "`primary`" + ` TINYINT(1) NULL,
|
||||
type VARCHAR(31) NULL,
|
||||
created_at DATETIME(6) NOT NULL DEFAULT NOW(6),
|
||||
updated_at DATETIME(6) NOT NULL DEFAULT NOW(6) ON UPDATE NOW(6),
|
||||
KEY idx_scim_user_emails_email_type(type, email),
|
||||
CONSTRAINT fk_scim_user_emails_scim_user_id FOREIGN KEY (scim_user_id) REFERENCES scim_users (id) ON DELETE CASCADE
|
||||
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COLLATE = utf8mb4_unicode_ci;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS scim_groups (
|
||||
id int UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
|
||||
external_id VARCHAR(255) NULL,
|
||||
display_name VARCHAR(255) NOT NULL,
|
||||
created_at DATETIME(6) NOT NULL DEFAULT NOW(6),
|
||||
updated_at DATETIME(6) NOT NULL DEFAULT NOW(6) ON UPDATE NOW(6),
|
||||
KEY idx_scim_groups_external_id (external_id),
|
||||
-- Entra ID requires a unique display name
|
||||
UNIQUE KEY idx_scim_groups_display_name (display_name)
|
||||
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COLLATE = utf8mb4_unicode_ci;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS scim_user_group (
|
||||
scim_user_id INT UNSIGNED NOT NULL,
|
||||
group_id INT UNSIGNED NOT NULL,
|
||||
created_at DATETIME(6) NOT NULL DEFAULT NOW(6),
|
||||
PRIMARY KEY (scim_user_id, group_id),
|
||||
CONSTRAINT fk_scim_user_group_scim_user_id FOREIGN KEY (scim_user_id) REFERENCES scim_users (id) ON DELETE CASCADE,
|
||||
CONSTRAINT fk_scim_user_group_group_id FOREIGN KEY (group_id) REFERENCES scim_groups (id) ON DELETE CASCADE
|
||||
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COLLATE = utf8mb4_unicode_ci;
|
||||
`)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create scim tables: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Down_20250331042354(tx *sql.Tx) error {
|
||||
return nil
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
331
server/datastore/mysql/scim.go
Normal file
331
server/datastore/mysql/scim.go
Normal file
|
|
@ -0,0 +1,331 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// CreateScimUser creates a new SCIM user in the database
|
||||
func (ds *Datastore) CreateScimUser(ctx context.Context, user *fleet.ScimUser) (uint, error) {
|
||||
var userID uint
|
||||
err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
||||
const insertUserQuery = `
|
||||
INSERT INTO scim_users (
|
||||
external_id, user_name, given_name, family_name, active
|
||||
) VALUES (?, ?, ?, ?, ?)`
|
||||
result, err := tx.ExecContext(
|
||||
ctx,
|
||||
insertUserQuery,
|
||||
user.ExternalID,
|
||||
user.UserName,
|
||||
user.GivenName,
|
||||
user.FamilyName,
|
||||
user.Active,
|
||||
)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "insert scim user")
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "insert scim user last insert id")
|
||||
}
|
||||
user.ID = uint(id) // nolint:gosec // dismiss G115
|
||||
userID = user.ID
|
||||
|
||||
return insertEmails(ctx, tx, user)
|
||||
})
|
||||
return userID, err
|
||||
}
|
||||
|
||||
// ScimUserByID retrieves a SCIM user by ID
|
||||
func (ds *Datastore) ScimUserByID(ctx context.Context, id uint) (*fleet.ScimUser, error) {
|
||||
const query = `
|
||||
SELECT
|
||||
id, external_id, user_name, given_name, family_name, active
|
||||
FROM scim_users
|
||||
WHERE id = ?
|
||||
`
|
||||
user := &fleet.ScimUser{}
|
||||
err := sqlx.GetContext(ctx, ds.reader(ctx), user, query, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, notFound("scim user").WithID(id)
|
||||
}
|
||||
return nil, ctxerr.Wrap(ctx, err, "select scim user")
|
||||
}
|
||||
|
||||
// Get the user's emails
|
||||
emails, err := ds.getScimUserEmails(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user.Emails = emails
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// ScimUserByUserName retrieves a SCIM user by username
|
||||
func (ds *Datastore) ScimUserByUserName(ctx context.Context, userName string) (*fleet.ScimUser, error) {
|
||||
const query = `
|
||||
SELECT
|
||||
id, external_id, user_name, given_name, family_name, active
|
||||
FROM scim_users
|
||||
WHERE user_name = ?
|
||||
`
|
||||
user := &fleet.ScimUser{}
|
||||
err := sqlx.GetContext(ctx, ds.reader(ctx), user, query, userName)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, notFound("scim user")
|
||||
}
|
||||
return nil, ctxerr.Wrap(ctx, err, "select scim user by userName")
|
||||
}
|
||||
|
||||
// Get the user's emails
|
||||
emails, err := ds.getScimUserEmails(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user.Emails = emails
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// ReplaceScimUser replaces an existing SCIM user in the database
|
||||
func (ds *Datastore) ReplaceScimUser(ctx context.Context, user *fleet.ScimUser) error {
|
||||
return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
||||
// Update the SCIM user
|
||||
const updateUserQuery = `
|
||||
UPDATE scim_users SET
|
||||
external_id = ?,
|
||||
user_name = ?,
|
||||
given_name = ?,
|
||||
family_name = ?,
|
||||
active = ?
|
||||
WHERE id = ?`
|
||||
result, err := tx.ExecContext(
|
||||
ctx,
|
||||
updateUserQuery,
|
||||
user.ExternalID,
|
||||
user.UserName,
|
||||
user.GivenName,
|
||||
user.FamilyName,
|
||||
user.Active,
|
||||
user.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "update scim user")
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "get rows affected for update scim user")
|
||||
}
|
||||
if rowsAffected == 0 {
|
||||
return notFound("scim user").WithID(user.ID)
|
||||
}
|
||||
|
||||
// We assume that email is not blank/null.
|
||||
// However, we do not assume that email/type are unique for a user. To keep the code simple, we:
|
||||
// 1. Delete all existing emails
|
||||
// 2. Insert all new emails
|
||||
// This is less efficient and can be optimized if we notice a load on these tables in production.
|
||||
|
||||
const deleteEmailsQuery = `DELETE FROM scim_user_emails WHERE scim_user_id = ?`
|
||||
_, err = tx.ExecContext(ctx, deleteEmailsQuery, user.ID)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "delete scim user emails")
|
||||
}
|
||||
|
||||
return insertEmails(ctx, tx, user)
|
||||
})
|
||||
}
|
||||
|
||||
func insertEmails(ctx context.Context, tx sqlx.ExtContext, user *fleet.ScimUser) error {
|
||||
// Insert the user's emails in a single batch if any
|
||||
if len(user.Emails) > 0 {
|
||||
// Build the batch insert query
|
||||
valueStrings := make([]string, 0, len(user.Emails))
|
||||
valueArgs := make([]interface{}, 0, len(user.Emails)*4)
|
||||
|
||||
for i := range user.Emails {
|
||||
user.Emails[i].ScimUserID = user.ID
|
||||
valueStrings = append(valueStrings, "(?, ?, ?, ?)")
|
||||
valueArgs = append(valueArgs,
|
||||
user.Emails[i].ScimUserID,
|
||||
user.Emails[i].Email,
|
||||
user.Emails[i].Primary,
|
||||
user.Emails[i].Type,
|
||||
)
|
||||
}
|
||||
|
||||
// Construct the batch insert query
|
||||
insertEmailQuery := `
|
||||
INSERT INTO scim_user_emails (
|
||||
scim_user_id, email, ` + "`primary`" + `, type
|
||||
) VALUES ` + strings.Join(valueStrings, ",")
|
||||
|
||||
// Execute the batch insert
|
||||
_, err := tx.ExecContext(ctx, insertEmailQuery, valueArgs...)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "batch insert scim user emails")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteScimUser deletes a SCIM user from the database
|
||||
func (ds *Datastore) DeleteScimUser(ctx context.Context, id uint) error {
|
||||
return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
||||
// Delete all email entries for the user
|
||||
const deleteEmailsQuery = `DELETE FROM scim_user_emails WHERE scim_user_id = ?`
|
||||
_, err := tx.ExecContext(ctx, deleteEmailsQuery, id)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "delete scim user emails")
|
||||
}
|
||||
|
||||
// Delete the user
|
||||
const deleteUserQuery = `DELETE FROM scim_users WHERE id = ?`
|
||||
result, err := tx.ExecContext(ctx, deleteUserQuery, id)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "delete scim user")
|
||||
}
|
||||
|
||||
// Check if the user existed
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "get rows affected for delete scim user")
|
||||
}
|
||||
if rowsAffected == 0 {
|
||||
return notFound("scim user").WithID(id)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ListScimUsers retrieves a list of SCIM users with optional filtering
|
||||
func (ds *Datastore) ListScimUsers(ctx context.Context, opts fleet.ScimUsersListOptions) (users []fleet.ScimUser, totalResults uint, err error) {
|
||||
// Default pagination values if not provided
|
||||
if opts.Page == 0 {
|
||||
opts.Page = 1
|
||||
}
|
||||
if opts.PerPage == 0 {
|
||||
opts.PerPage = 100
|
||||
}
|
||||
|
||||
// Calculate offset for pagination
|
||||
offset := (opts.Page - 1) * opts.PerPage
|
||||
|
||||
// Build the base query
|
||||
baseQuery := `
|
||||
SELECT DISTINCT
|
||||
scim_users.id, external_id, user_name, given_name, family_name, active
|
||||
FROM scim_users
|
||||
`
|
||||
|
||||
// Add joins and where clauses based on filters
|
||||
var whereClause string
|
||||
var params []interface{}
|
||||
|
||||
if opts.UserNameFilter != nil {
|
||||
// Filter by username
|
||||
whereClause = " WHERE scim_users.user_name = ?"
|
||||
params = append(params, *opts.UserNameFilter)
|
||||
} else if opts.EmailTypeFilter != nil && opts.EmailValueFilter != nil {
|
||||
// Filter by email type and value
|
||||
baseQuery += " LEFT JOIN scim_user_emails ON scim_users.id = scim_user_emails.scim_user_id"
|
||||
whereClause = " WHERE scim_user_emails.type = ? AND scim_user_emails.email = ?"
|
||||
params = append(params, *opts.EmailTypeFilter, *opts.EmailValueFilter)
|
||||
}
|
||||
|
||||
// First, get the total count without pagination
|
||||
countQuery := "SELECT COUNT(DISTINCT id) FROM (" + baseQuery + whereClause + ") AS filtered_users"
|
||||
err = sqlx.GetContext(ctx, ds.reader(ctx), &totalResults, countQuery, params...)
|
||||
if err != nil {
|
||||
return nil, 0, ctxerr.Wrap(ctx, err, "count total scim users")
|
||||
}
|
||||
|
||||
// Add pagination to the main query
|
||||
query := baseQuery + whereClause + " ORDER BY scim_users.id LIMIT ? OFFSET ?"
|
||||
params = append(params, opts.PerPage, offset)
|
||||
|
||||
// Execute the query
|
||||
err = sqlx.SelectContext(ctx, ds.reader(ctx), &users, query, params...)
|
||||
if err != nil {
|
||||
return nil, 0, ctxerr.Wrap(ctx, err, "list scim users")
|
||||
}
|
||||
|
||||
// Process the results
|
||||
userIDs := make([]uint, 0, len(users))
|
||||
userMap := make(map[uint]*fleet.ScimUser, len(users))
|
||||
|
||||
for i, user := range users {
|
||||
userIDs = append(userIDs, user.ID)
|
||||
userMap[user.ID] = &users[i]
|
||||
}
|
||||
|
||||
// If no users found, return empty slice
|
||||
if len(users) == 0 {
|
||||
return users, totalResults, nil
|
||||
}
|
||||
|
||||
// Fetch emails for all users in a single query
|
||||
emailQuery, args, err := sqlx.In(`
|
||||
SELECT
|
||||
scim_user_id, email, `+"`primary`"+`, type
|
||||
FROM scim_user_emails
|
||||
WHERE scim_user_id IN (?)
|
||||
ORDER BY email ASC
|
||||
`, userIDs)
|
||||
if err != nil {
|
||||
return nil, 0, ctxerr.Wrap(ctx, err, "prepare emails query")
|
||||
}
|
||||
|
||||
// Convert query for the specific DB dialect
|
||||
emailQuery = ds.reader(ctx).Rebind(emailQuery)
|
||||
|
||||
// Execute the email query
|
||||
var allEmails []fleet.ScimUserEmail
|
||||
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &allEmails, emailQuery, args...); err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, 0, ctxerr.Wrap(ctx, err, "select scim user emails")
|
||||
}
|
||||
}
|
||||
|
||||
// Associate emails with their users
|
||||
for i := range allEmails {
|
||||
email := allEmails[i]
|
||||
if user, ok := userMap[email.ScimUserID]; ok {
|
||||
user.Emails = append(user.Emails, email)
|
||||
}
|
||||
}
|
||||
|
||||
return users, totalResults, nil
|
||||
}
|
||||
|
||||
// getScimUserEmails retrieves all emails for a SCIM user
|
||||
func (ds *Datastore) getScimUserEmails(ctx context.Context, userID uint) ([]fleet.ScimUserEmail, error) {
|
||||
const query = `
|
||||
SELECT
|
||||
scim_user_id, email, ` + "`primary`" + `, type
|
||||
FROM scim_user_emails
|
||||
WHERE scim_user_id = ? ORDER BY email ASC
|
||||
`
|
||||
var emails []fleet.ScimUserEmail
|
||||
err := sqlx.SelectContext(ctx, ds.reader(ctx), &emails, query, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, ctxerr.Wrap(ctx, err, "select scim user emails")
|
||||
}
|
||||
return emails, nil
|
||||
}
|
||||
493
server/datastore/mysql/scim_test.go
Normal file
493
server/datastore/mysql/scim_test.go
Normal file
|
|
@ -0,0 +1,493 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestScim(t *testing.T) {
|
||||
ds := CreateMySQLDS(t)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
fn func(t *testing.T, ds *Datastore)
|
||||
}{
|
||||
{"ScimUserCreate", testScimUserCreate},
|
||||
{"ScimUserByID", testScimUserByID},
|
||||
{"ScimUserByUserName", testScimUserByUserName},
|
||||
{"ReplaceScimUser", testReplaceScimUser},
|
||||
{"DeleteScimUser", testDeleteScimUser},
|
||||
{"ListScimUsers", testListScimUsers},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
defer TruncateTables(t, ds, "scim_users", "scim_user_emails")
|
||||
c.fn(t, ds)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testScimUserCreate(t *testing.T, ds *Datastore) {
|
||||
usersToCreate := []fleet.ScimUser{
|
||||
{
|
||||
UserName: "user1",
|
||||
ExternalID: nil,
|
||||
GivenName: nil,
|
||||
FamilyName: nil,
|
||||
Active: nil,
|
||||
Emails: []fleet.ScimUserEmail{},
|
||||
},
|
||||
{
|
||||
UserName: "user2",
|
||||
ExternalID: ptr.String("ext-123"),
|
||||
GivenName: ptr.String("John"),
|
||||
FamilyName: ptr.String("Doe"),
|
||||
Active: ptr.Bool(true),
|
||||
Emails: []fleet.ScimUserEmail{
|
||||
{
|
||||
Email: "john.doe@example.com",
|
||||
Primary: ptr.Bool(true),
|
||||
Type: ptr.String("work"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
UserName: "user3",
|
||||
ExternalID: ptr.String("ext-456"),
|
||||
GivenName: ptr.String("Jane"),
|
||||
FamilyName: ptr.String("Smith"),
|
||||
Active: ptr.Bool(true),
|
||||
Emails: []fleet.ScimUserEmail{
|
||||
{
|
||||
Email: "jane.personal@example.com",
|
||||
Primary: ptr.Bool(false),
|
||||
Type: ptr.String("home"),
|
||||
},
|
||||
{
|
||||
Email: "jane.smith@example.com",
|
||||
Primary: ptr.Bool(true),
|
||||
Type: ptr.String("work"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, u := range usersToCreate {
|
||||
var err error
|
||||
userCopy := u
|
||||
userCopy.ID, err = ds.CreateScimUser(context.Background(), &u)
|
||||
assert.Nil(t, err)
|
||||
|
||||
verify, err := ds.ScimUserByUserName(context.Background(), u.UserName)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, userCopy.ID, verify.ID)
|
||||
assert.Equal(t, userCopy.UserName, verify.UserName)
|
||||
assert.Equal(t, userCopy.ExternalID, verify.ExternalID)
|
||||
assert.Equal(t, userCopy.GivenName, verify.GivenName)
|
||||
assert.Equal(t, userCopy.FamilyName, verify.FamilyName)
|
||||
assert.Equal(t, userCopy.Active, verify.Active)
|
||||
|
||||
// Verify emails
|
||||
assert.Equal(t, len(userCopy.Emails), len(verify.Emails))
|
||||
for i, email := range userCopy.Emails {
|
||||
assert.Equal(t, email.Email, verify.Emails[i].Email)
|
||||
assert.Equal(t, email.Primary, verify.Emails[i].Primary)
|
||||
assert.Equal(t, email.Type, verify.Emails[i].Type)
|
||||
assert.Equal(t, u.ID, verify.Emails[i].ScimUserID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testScimUserByID(t *testing.T, ds *Datastore) {
|
||||
users := createTestScimUsers(t, ds)
|
||||
for _, tt := range users {
|
||||
returned, err := ds.ScimUserByID(context.Background(), tt.ID)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, tt.ID, returned.ID)
|
||||
assert.Equal(t, tt.UserName, returned.UserName)
|
||||
assert.Equal(t, tt.ExternalID, returned.ExternalID)
|
||||
assert.Equal(t, tt.GivenName, returned.GivenName)
|
||||
assert.Equal(t, tt.FamilyName, returned.FamilyName)
|
||||
assert.Equal(t, tt.Active, returned.Active)
|
||||
|
||||
// Verify emails
|
||||
assert.Equal(t, len(tt.Emails), len(returned.Emails))
|
||||
for i, email := range tt.Emails {
|
||||
assert.Equal(t, email.Email, returned.Emails[i].Email)
|
||||
assert.Equal(t, email.Primary, returned.Emails[i].Primary)
|
||||
assert.Equal(t, email.Type, returned.Emails[i].Type)
|
||||
assert.Equal(t, tt.ID, returned.Emails[i].ScimUserID)
|
||||
}
|
||||
}
|
||||
|
||||
// test missing user
|
||||
_, err := ds.ScimUserByID(context.Background(), 10000000000)
|
||||
assert.True(t, fleet.IsNotFound(err))
|
||||
}
|
||||
|
||||
func testScimUserByUserName(t *testing.T, ds *Datastore) {
|
||||
users := createTestScimUsers(t, ds)
|
||||
for _, tt := range users {
|
||||
returned, err := ds.ScimUserByUserName(context.Background(), tt.UserName)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, tt.ID, returned.ID)
|
||||
assert.Equal(t, tt.UserName, returned.UserName)
|
||||
assert.Equal(t, tt.ExternalID, returned.ExternalID)
|
||||
assert.Equal(t, tt.GivenName, returned.GivenName)
|
||||
assert.Equal(t, tt.FamilyName, returned.FamilyName)
|
||||
assert.Equal(t, tt.Active, returned.Active)
|
||||
|
||||
// Verify emails
|
||||
assert.Equal(t, len(tt.Emails), len(returned.Emails))
|
||||
for i, email := range tt.Emails {
|
||||
assert.Equal(t, email.Email, returned.Emails[i].Email)
|
||||
assert.Equal(t, email.Primary, returned.Emails[i].Primary)
|
||||
assert.Equal(t, email.Type, returned.Emails[i].Type)
|
||||
assert.Equal(t, tt.ID, returned.Emails[i].ScimUserID)
|
||||
}
|
||||
}
|
||||
|
||||
// test missing user
|
||||
_, err := ds.ScimUserByUserName(context.Background(), "nonexistent-user")
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func createTestScimUsers(t *testing.T, ds *Datastore) []*fleet.ScimUser {
|
||||
createUsers := []fleet.ScimUser{
|
||||
{
|
||||
UserName: "test-user1",
|
||||
ExternalID: ptr.String("ext-test-123"),
|
||||
GivenName: ptr.String("Test"),
|
||||
FamilyName: ptr.String("User"),
|
||||
Active: ptr.Bool(true),
|
||||
Emails: []fleet.ScimUserEmail{
|
||||
{
|
||||
Email: "test.user@example.com",
|
||||
Primary: ptr.Bool(true),
|
||||
Type: ptr.String("work"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
UserName: "test-user2",
|
||||
ExternalID: ptr.String("ext-test-456"),
|
||||
GivenName: ptr.String("Another"),
|
||||
FamilyName: ptr.String("User"),
|
||||
Active: ptr.Bool(true),
|
||||
Emails: []fleet.ScimUserEmail{
|
||||
{
|
||||
Email: "another.personal@example.com",
|
||||
Primary: ptr.Bool(false),
|
||||
Type: ptr.String("home"),
|
||||
},
|
||||
{
|
||||
Email: "another.user@example.com",
|
||||
Primary: ptr.Bool(true),
|
||||
Type: ptr.String("work"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var users []*fleet.ScimUser
|
||||
for _, u := range createUsers {
|
||||
var err error
|
||||
u.ID, err = ds.CreateScimUser(context.Background(), &u)
|
||||
require.Nil(t, err)
|
||||
users = append(users, &u)
|
||||
}
|
||||
return users
|
||||
}
|
||||
|
||||
func testReplaceScimUser(t *testing.T, ds *Datastore) {
|
||||
// Create a test user
|
||||
user := fleet.ScimUser{
|
||||
UserName: "replace-test-user",
|
||||
ExternalID: ptr.String("ext-replace-123"),
|
||||
GivenName: ptr.String("Original"),
|
||||
FamilyName: ptr.String("User"),
|
||||
Active: ptr.Bool(true),
|
||||
Emails: []fleet.ScimUserEmail{
|
||||
{
|
||||
Email: "original.user@example.com",
|
||||
Primary: ptr.Bool(true),
|
||||
Type: ptr.String("work"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var err error
|
||||
user.ID, err = ds.CreateScimUser(context.Background(), &user)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Verify the user was created correctly
|
||||
createdUser, err := ds.ScimUserByID(context.Background(), user.ID)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, user.UserName, createdUser.UserName)
|
||||
assert.Equal(t, user.ExternalID, createdUser.ExternalID)
|
||||
assert.Equal(t, user.GivenName, createdUser.GivenName)
|
||||
assert.Equal(t, user.FamilyName, createdUser.FamilyName)
|
||||
assert.Equal(t, user.Active, createdUser.Active)
|
||||
assert.Equal(t, 1, len(createdUser.Emails))
|
||||
assert.Equal(t, "original.user@example.com", createdUser.Emails[0].Email)
|
||||
|
||||
// Modify the user
|
||||
updatedUser := fleet.ScimUser{
|
||||
ID: user.ID,
|
||||
UserName: "replace-test-user", // Same username
|
||||
ExternalID: ptr.String("ext-replace-456"), // Changed external ID
|
||||
GivenName: ptr.String("Updated"), // Changed given name
|
||||
FamilyName: ptr.String("User"), // Same family name
|
||||
Active: ptr.Bool(false), // Changed active status
|
||||
Emails: []fleet.ScimUserEmail{ // Changed emails
|
||||
{
|
||||
Email: "updated.user@example.com",
|
||||
Primary: ptr.Bool(true),
|
||||
Type: ptr.String("work"),
|
||||
},
|
||||
{
|
||||
Email: "personal.user@example.com",
|
||||
Primary: ptr.Bool(false),
|
||||
Type: ptr.String("home"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Replace the user
|
||||
err = ds.ReplaceScimUser(context.Background(), &updatedUser)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Verify the user was updated correctly
|
||||
replacedUser, err := ds.ScimUserByID(context.Background(), user.ID)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, updatedUser.UserName, replacedUser.UserName)
|
||||
assert.Equal(t, updatedUser.ExternalID, replacedUser.ExternalID)
|
||||
assert.Equal(t, updatedUser.GivenName, replacedUser.GivenName)
|
||||
assert.Equal(t, updatedUser.FamilyName, replacedUser.FamilyName)
|
||||
assert.Equal(t, updatedUser.Active, replacedUser.Active)
|
||||
|
||||
// Verify emails were replaced
|
||||
assert.Equal(t, 2, len(replacedUser.Emails))
|
||||
assert.Equal(t, "personal.user@example.com", replacedUser.Emails[0].Email) // Alphabetical order
|
||||
assert.Equal(t, "updated.user@example.com", replacedUser.Emails[1].Email)
|
||||
|
||||
// Test replacing a non-existent user
|
||||
nonExistentUser := fleet.ScimUser{
|
||||
ID: 99999, // Non-existent ID
|
||||
UserName: "non-existent",
|
||||
ExternalID: ptr.String("ext-non-existent"),
|
||||
GivenName: ptr.String("Non"),
|
||||
FamilyName: ptr.String("Existent"),
|
||||
Active: ptr.Bool(true),
|
||||
}
|
||||
|
||||
err = ds.ReplaceScimUser(context.Background(), &nonExistentUser)
|
||||
assert.True(t, fleet.IsNotFound(err))
|
||||
}
|
||||
|
||||
func testDeleteScimUser(t *testing.T, ds *Datastore) {
|
||||
// Create a test user
|
||||
user := fleet.ScimUser{
|
||||
UserName: "delete-test-user",
|
||||
ExternalID: ptr.String("ext-delete-123"),
|
||||
GivenName: ptr.String("Delete"),
|
||||
FamilyName: ptr.String("User"),
|
||||
Active: ptr.Bool(true),
|
||||
Emails: []fleet.ScimUserEmail{
|
||||
{
|
||||
Email: "delete.user@example.com",
|
||||
Primary: ptr.Bool(true),
|
||||
Type: ptr.String("work"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var err error
|
||||
user.ID, err = ds.CreateScimUser(context.Background(), &user)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Verify the user was created correctly
|
||||
createdUser, err := ds.ScimUserByID(context.Background(), user.ID)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, user.UserName, createdUser.UserName)
|
||||
|
||||
// Delete the user
|
||||
err = ds.DeleteScimUser(context.Background(), user.ID)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Verify the user was deleted
|
||||
_, err = ds.ScimUserByID(context.Background(), user.ID)
|
||||
assert.True(t, fleet.IsNotFound(err))
|
||||
|
||||
// Test deleting a non-existent user
|
||||
err = ds.DeleteScimUser(context.Background(), 99999) // Non-existent ID
|
||||
assert.True(t, fleet.IsNotFound(err))
|
||||
}
|
||||
|
||||
func testListScimUsers(t *testing.T, ds *Datastore) {
|
||||
// Create test users with different attributes and emails
|
||||
users := []fleet.ScimUser{
|
||||
{
|
||||
UserName: "list-test-user1",
|
||||
ExternalID: ptr.String("ext-list-123"),
|
||||
GivenName: ptr.String("List"),
|
||||
FamilyName: ptr.String("User1"),
|
||||
Active: ptr.Bool(true),
|
||||
Emails: []fleet.ScimUserEmail{
|
||||
{
|
||||
Email: "list.user1@example.com",
|
||||
Primary: ptr.Bool(true),
|
||||
Type: ptr.String("work"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
UserName: "list-test-user2",
|
||||
ExternalID: ptr.String("ext-list-456"),
|
||||
GivenName: ptr.String("List"),
|
||||
FamilyName: ptr.String("User2"),
|
||||
Active: ptr.Bool(true),
|
||||
Emails: []fleet.ScimUserEmail{
|
||||
{
|
||||
Email: "list.user2@example.com",
|
||||
Primary: ptr.Bool(true),
|
||||
Type: ptr.String("work"),
|
||||
},
|
||||
{
|
||||
Email: "personal.user2@example.com",
|
||||
Primary: ptr.Bool(false),
|
||||
Type: ptr.String("home"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
UserName: "different-user3",
|
||||
ExternalID: ptr.String("ext-list-789"),
|
||||
GivenName: ptr.String("Different"),
|
||||
FamilyName: ptr.String("User3"),
|
||||
Active: ptr.Bool(false),
|
||||
Emails: []fleet.ScimUserEmail{
|
||||
{
|
||||
Email: "different.user3@example.com",
|
||||
Primary: ptr.Bool(true),
|
||||
Type: ptr.String("work"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create the users
|
||||
for i := range users {
|
||||
var err error
|
||||
users[i].ID, err = ds.CreateScimUser(context.Background(), &users[i])
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
// Test 1: List all users without filters
|
||||
allUsers, totalResults, err := ds.ListScimUsers(context.Background(), fleet.ScimUsersListOptions{
|
||||
Page: 1,
|
||||
PerPage: 10,
|
||||
})
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, 3, len(allUsers))
|
||||
assert.Equal(t, uint(3), totalResults)
|
||||
|
||||
// Verify that our test users are in the results
|
||||
foundUsers := 0
|
||||
for _, u := range allUsers {
|
||||
for _, testUser := range users {
|
||||
if u.ID == testUser.ID {
|
||||
foundUsers++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 3, foundUsers)
|
||||
|
||||
// Test 2: Pagination - first page with 2 items
|
||||
page1Users, totalPage1, err := ds.ListScimUsers(context.Background(), fleet.ScimUsersListOptions{
|
||||
Page: 1,
|
||||
PerPage: 2,
|
||||
})
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, 2, len(page1Users))
|
||||
assert.Equal(t, uint(3), totalPage1) // Total should still be 3
|
||||
|
||||
// Test 3: Pagination - second page with 2 items
|
||||
page2Users, totalPage2, err := ds.ListScimUsers(context.Background(), fleet.ScimUsersListOptions{
|
||||
Page: 2,
|
||||
PerPage: 2,
|
||||
})
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, 1, len(page2Users))
|
||||
assert.Equal(t, uint(3), totalPage2) // Total should still be 3
|
||||
|
||||
// Verify that page1 and page2 contain different users
|
||||
for _, p1User := range page1Users {
|
||||
for _, p2User := range page2Users {
|
||||
assert.NotEqual(t, p1User.ID, p2User.ID, "Users should not appear on multiple pages")
|
||||
}
|
||||
}
|
||||
|
||||
// Test 4: Filter by username
|
||||
listUsers, totalListUsers, err := ds.ListScimUsers(context.Background(), fleet.ScimUsersListOptions{
|
||||
Page: 1,
|
||||
PerPage: 10,
|
||||
UserNameFilter: ptr.String("list-test-user2"),
|
||||
})
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Len(t, listUsers, 1)
|
||||
assert.Equal(t, uint(1), totalListUsers)
|
||||
assert.Equal(t, "list-test-user2", listUsers[0].UserName)
|
||||
|
||||
// Test 5: Filter by email type and value
|
||||
homeEmailUsers, totalHomeEmailUsers, err := ds.ListScimUsers(context.Background(), fleet.ScimUsersListOptions{
|
||||
Page: 1,
|
||||
PerPage: 10,
|
||||
EmailTypeFilter: ptr.String("home"),
|
||||
EmailValueFilter: ptr.String("personal.user2@example.com"),
|
||||
})
|
||||
require.Nil(t, err)
|
||||
require.Len(t, homeEmailUsers, 1)
|
||||
assert.Equal(t, uint(1), totalHomeEmailUsers)
|
||||
assert.Equal(t, users[1].ID, homeEmailUsers[0].ID)
|
||||
assert.Equal(t, 2, len(homeEmailUsers[0].Emails))
|
||||
|
||||
// Test 6: Filter by email type and value - work emails
|
||||
workEmailUsers, totalWorkEmailUsers, err := ds.ListScimUsers(context.Background(), fleet.ScimUsersListOptions{
|
||||
Page: 1,
|
||||
PerPage: 10,
|
||||
EmailTypeFilter: ptr.String("work"),
|
||||
EmailValueFilter: ptr.String("different.user3@example.com"),
|
||||
})
|
||||
require.Nil(t, err)
|
||||
assert.Len(t, workEmailUsers, 1)
|
||||
assert.Equal(t, uint(1), totalWorkEmailUsers)
|
||||
|
||||
// Test 7: No results for non-matching filters
|
||||
noUsers, totalNoUsers1, err := ds.ListScimUsers(context.Background(), fleet.ScimUsersListOptions{
|
||||
Page: 1,
|
||||
PerPage: 10,
|
||||
UserNameFilter: ptr.String("nonexistent"),
|
||||
})
|
||||
require.Nil(t, err)
|
||||
assert.Empty(t, noUsers)
|
||||
assert.Equal(t, uint(0), totalNoUsers1)
|
||||
|
||||
noUsers, totalNoUsers2, err := ds.ListScimUsers(context.Background(), fleet.ScimUsersListOptions{
|
||||
Page: 1,
|
||||
PerPage: 10,
|
||||
EmailTypeFilter: ptr.String("nonexistent"),
|
||||
EmailValueFilter: ptr.String("nonexistent"),
|
||||
})
|
||||
require.Nil(t, err)
|
||||
assert.Empty(t, noUsers)
|
||||
assert.Equal(t, uint(0), totalNoUsers2)
|
||||
}
|
||||
|
|
@ -2012,6 +2012,22 @@ type Datastore interface {
|
|||
// Android
|
||||
|
||||
AndroidDatastore
|
||||
|
||||
// /////////////////////////////////////////////////////////////////////////////
|
||||
// SCIM
|
||||
|
||||
// CreateScimUser creates a new SCIM user in the database
|
||||
CreateScimUser(ctx context.Context, user *ScimUser) (uint, error)
|
||||
// ScimUserByID retrieves a SCIM user by ID
|
||||
ScimUserByID(ctx context.Context, id uint) (*ScimUser, error)
|
||||
// ScimUserByUserName retrieves a SCIM user by username
|
||||
ScimUserByUserName(ctx context.Context, userName string) (*ScimUser, error)
|
||||
// ReplaceScimUser replaces an existing SCIM user in the database
|
||||
ReplaceScimUser(ctx context.Context, user *ScimUser) error
|
||||
// DeleteScimUser deletes a SCIM user from the database
|
||||
DeleteScimUser(ctx context.Context, id uint) error
|
||||
// ListScimUsers retrieves a list of SCIM users with optional filtering
|
||||
ListScimUsers(ctx context.Context, opts ScimUsersListOptions) (users []ScimUser, totalResults uint, err error)
|
||||
}
|
||||
|
||||
type AndroidDatastore interface {
|
||||
|
|
|
|||
41
server/fleet/scim.go
Normal file
41
server/fleet/scim.go
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
package fleet
|
||||
|
||||
// ScimUser represents a SCIM user in the database
|
||||
type ScimUser struct {
|
||||
ID uint `db:"id"`
|
||||
ExternalID *string `db:"external_id"`
|
||||
UserName string `db:"user_name"`
|
||||
GivenName *string `db:"given_name"`
|
||||
FamilyName *string `db:"family_name"`
|
||||
Active *bool `db:"active"`
|
||||
Emails []ScimUserEmail
|
||||
}
|
||||
|
||||
func (su *ScimUser) AuthzType() string {
|
||||
return "scim_user"
|
||||
}
|
||||
|
||||
// ScimUserEmail represents an email address associated with a SCIM user
|
||||
type ScimUserEmail struct {
|
||||
ScimUserID uint `db:"scim_user_id"`
|
||||
Email string `db:"email"`
|
||||
Primary *bool `db:"primary"`
|
||||
Type *string `db:"type"`
|
||||
}
|
||||
|
||||
type ScimUsersListOptions struct {
|
||||
// Which page to return (must be positive integer)
|
||||
Page uint
|
||||
// How many results per page (must be positive integer)
|
||||
PerPage uint
|
||||
|
||||
// UserNameFilter filters by userName -- max of 1 response is expected
|
||||
// Cannot be used with other filters.
|
||||
UserNameFilter *string
|
||||
|
||||
// EmailTypeFilter and EmailValueFilter are needed to support Entra ID filter: emails[type eq "work"].value eq "user@contoso.com"
|
||||
// https://learn.microsoft.com/en-us/entra/identity/app-provisioning/use-scim-to-provision-users-and-groups#users
|
||||
// Cannot be used with other filters.
|
||||
EmailTypeFilter *string
|
||||
EmailValueFilter *string
|
||||
}
|
||||
|
|
@ -1284,6 +1284,18 @@ type SetAndroidEnabledAndConfiguredFunc func(ctx context.Context, configured boo
|
|||
|
||||
type UpdateAndroidHostFunc func(ctx context.Context, host *fleet.AndroidHost, fromEnroll bool) error
|
||||
|
||||
type CreateScimUserFunc func(ctx context.Context, user *fleet.ScimUser) (uint, error)
|
||||
|
||||
type ScimUserByIDFunc func(ctx context.Context, id uint) (*fleet.ScimUser, error)
|
||||
|
||||
type ScimUserByUserNameFunc func(ctx context.Context, userName string) (*fleet.ScimUser, error)
|
||||
|
||||
type ReplaceScimUserFunc func(ctx context.Context, user *fleet.ScimUser) error
|
||||
|
||||
type DeleteScimUserFunc func(ctx context.Context, id uint) error
|
||||
|
||||
type ListScimUsersFunc func(ctx context.Context, opts fleet.ScimUsersListOptions) (users []fleet.ScimUser, totalResults uint, err error)
|
||||
|
||||
type DataStore struct {
|
||||
HealthCheckFunc HealthCheckFunc
|
||||
HealthCheckFuncInvoked bool
|
||||
|
|
@ -3178,6 +3190,24 @@ type DataStore struct {
|
|||
UpdateAndroidHostFunc UpdateAndroidHostFunc
|
||||
UpdateAndroidHostFuncInvoked bool
|
||||
|
||||
CreateScimUserFunc CreateScimUserFunc
|
||||
CreateScimUserFuncInvoked bool
|
||||
|
||||
ScimUserByIDFunc ScimUserByIDFunc
|
||||
ScimUserByIDFuncInvoked bool
|
||||
|
||||
ScimUserByUserNameFunc ScimUserByUserNameFunc
|
||||
ScimUserByUserNameFuncInvoked bool
|
||||
|
||||
ReplaceScimUserFunc ReplaceScimUserFunc
|
||||
ReplaceScimUserFuncInvoked bool
|
||||
|
||||
DeleteScimUserFunc DeleteScimUserFunc
|
||||
DeleteScimUserFuncInvoked bool
|
||||
|
||||
ListScimUsersFunc ListScimUsersFunc
|
||||
ListScimUsersFuncInvoked bool
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
|
|
@ -7597,3 +7627,45 @@ func (s *DataStore) UpdateAndroidHost(ctx context.Context, host *fleet.AndroidHo
|
|||
s.mu.Unlock()
|
||||
return s.UpdateAndroidHostFunc(ctx, host, fromEnroll)
|
||||
}
|
||||
|
||||
func (s *DataStore) CreateScimUser(ctx context.Context, user *fleet.ScimUser) (uint, error) {
|
||||
s.mu.Lock()
|
||||
s.CreateScimUserFuncInvoked = true
|
||||
s.mu.Unlock()
|
||||
return s.CreateScimUserFunc(ctx, user)
|
||||
}
|
||||
|
||||
func (s *DataStore) ScimUserByID(ctx context.Context, id uint) (*fleet.ScimUser, error) {
|
||||
s.mu.Lock()
|
||||
s.ScimUserByIDFuncInvoked = true
|
||||
s.mu.Unlock()
|
||||
return s.ScimUserByIDFunc(ctx, id)
|
||||
}
|
||||
|
||||
func (s *DataStore) ScimUserByUserName(ctx context.Context, userName string) (*fleet.ScimUser, error) {
|
||||
s.mu.Lock()
|
||||
s.ScimUserByUserNameFuncInvoked = true
|
||||
s.mu.Unlock()
|
||||
return s.ScimUserByUserNameFunc(ctx, userName)
|
||||
}
|
||||
|
||||
func (s *DataStore) ReplaceScimUser(ctx context.Context, user *fleet.ScimUser) error {
|
||||
s.mu.Lock()
|
||||
s.ReplaceScimUserFuncInvoked = true
|
||||
s.mu.Unlock()
|
||||
return s.ReplaceScimUserFunc(ctx, user)
|
||||
}
|
||||
|
||||
func (s *DataStore) DeleteScimUser(ctx context.Context, id uint) error {
|
||||
s.mu.Lock()
|
||||
s.DeleteScimUserFuncInvoked = true
|
||||
s.mu.Unlock()
|
||||
return s.DeleteScimUserFunc(ctx, id)
|
||||
}
|
||||
|
||||
func (s *DataStore) ListScimUsers(ctx context.Context, opts fleet.ScimUsersListOptions) (users []fleet.ScimUser, totalResults uint, err error) {
|
||||
s.mu.Lock()
|
||||
s.ListScimUsersFuncInvoked = true
|
||||
s.mu.Unlock()
|
||||
return s.ListScimUsersFunc(ctx, opts)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package auth
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/authz"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/token"
|
||||
|
|
@ -39,7 +40,7 @@ func AuthenticatedUser(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpo
|
|||
return next(ctx, request)
|
||||
}
|
||||
|
||||
// if not succesful, try again this time with errors
|
||||
// if not successful, try again this time with errors
|
||||
sessionKey, ok := token.FromContext(ctx)
|
||||
if !ok {
|
||||
return nil, fleet.NewAuthHeaderRequiredError("no auth token")
|
||||
|
|
@ -67,3 +68,44 @@ func AuthenticatedUser(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpo
|
|||
func UnauthenticatedRequest(_ fleet.Service, next endpoint.Endpoint) endpoint.Endpoint {
|
||||
return log.Logged(next)
|
||||
}
|
||||
|
||||
// errorHandler has the same signature as http.Error
|
||||
type errorHandler func(w http.ResponseWriter, detail string, status int)
|
||||
|
||||
func AuthenticatedUserMiddleware(svc fleet.Service, errHandler errorHandler, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// first check if already successfully set
|
||||
if v, ok := viewer.FromContext(r.Context()); ok {
|
||||
if v.User.IsAdminForcedPasswordReset() {
|
||||
errHandler(w, fleet.ErrPasswordResetRequired.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// if not successful, try again this time with errors
|
||||
sessionKey, ok := token.FromContext(r.Context())
|
||||
if !ok {
|
||||
errHandler(w, fleet.NewAuthHeaderRequiredError("no auth token").Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
v, err := AuthViewer(r.Context(), string(sessionKey), svc)
|
||||
if err != nil {
|
||||
errHandler(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if v.User.IsAdminForcedPasswordReset() {
|
||||
errHandler(w, fleet.ErrPasswordResetRequired.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := viewer.NewContext(r.Context(), *v)
|
||||
if ac, ok := authz.FromContext(r.Context()); ok {
|
||||
ac.SetAuthnMethod(authz.AuthnUserToken)
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,3 +28,11 @@ func SetRequestsContexts(svc fleet.Service) kithttp.RequestFunc {
|
|||
return ctx
|
||||
}
|
||||
}
|
||||
|
||||
func SetRequestsContextMiddleware(svc fleet.Service, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := kithttp.PopulateRequestContext(r.Context(), r)
|
||||
ctx = SetRequestsContexts(svc)(ctx, r)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,3 +37,10 @@ func LogRequestEnd(logger kitlog.Logger) func(context.Context, http.ResponseWrit
|
|||
return ctx
|
||||
}
|
||||
}
|
||||
|
||||
func LogResponseEndMiddleware(logger kitlog.Logger, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next.ServeHTTP(w, r)
|
||||
LogRequestEnd(logger)(r.Context(), w)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue