diff --git a/changes/37192-refactor-endpoint_utils b/changes/37192-refactor-endpoint_utils new file mode 100644 index 0000000000..73b1f54822 --- /dev/null +++ b/changes/37192-refactor-endpoint_utils @@ -0,0 +1 @@ +Refactored common endpoint_utils package to support bounded contexts inside Fleet codebase. \ No newline at end of file diff --git a/server/archtest/archtest.go b/server/archtest/archtest.go index da73b919d9..fd17e5b9f6 100644 --- a/server/archtest/archtest.go +++ b/server/archtest/archtest.go @@ -30,8 +30,11 @@ type PackageTest struct { forbiddenPkgs []string } +// ModuleName is the module name for the fleet project. +const ModuleName = "github.com/fleetdm/fleet/v4" + // PackageTest will ignore dependency on this package. -const thisPackage = "github.com/fleetdm/fleet/v4/server/archtest" +const thisPackage = ModuleName + "/server/archtest" type TestingT interface { Errorf(format string, args ...any) diff --git a/server/authz/errors.go b/server/authz/errors.go index 477b99976d..b755fa7acb 100644 --- a/server/authz/errors.go +++ b/server/authz/errors.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/fleetdm/fleet/v4/server/fleet" + platform_http "github.com/fleetdm/fleet/v4/server/platform/http" ) const ( @@ -66,27 +67,15 @@ func (e *Forbidden) LogFields() []interface{} { // CheckMissing is the error to return when no authorization check was performed // by the service. -type CheckMissing struct { - response interface{} - - fleet.ErrorWithUUID -} +// +// Deprecated: Use platform_http.CheckMissing instead. This alias is kept for +// backward compatibility. +type CheckMissing = platform_http.CheckMissing // CheckMissingWithResponse creates a new error indicating the authorization // check was missed, and including the response for further analysis by the error // encoder. -func CheckMissingWithResponse(response interface{}) *CheckMissing { - return &CheckMissing{response: response} -} - -func (e *CheckMissing) Error() string { - return ForbiddenErrorMessage -} - -func (e *CheckMissing) Internal() string { - return "Missing authorization check" -} - -func (e *CheckMissing) Response() interface{} { - return e.response -} +// +// Deprecated: Use platform_http.CheckMissingWithResponse instead. This alias is +// kept for backward compatibility. +var CheckMissingWithResponse = platform_http.CheckMissingWithResponse diff --git a/server/contexts/ctxerr/ctxerr.go b/server/contexts/ctxerr/ctxerr.go index a267687984..e3526fdcfd 100644 --- a/server/contexts/ctxerr/ctxerr.go +++ b/server/contexts/ctxerr/ctxerr.go @@ -16,13 +16,12 @@ import ( "encoding/json" "errors" "fmt" + "maps" "runtime" "strings" "time" - "github.com/fleetdm/fleet/v4/server/contexts/host" - "github.com/fleetdm/fleet/v4/server/contexts/viewer" - "github.com/fleetdm/fleet/v4/server/fleet" + platform_http "github.com/fleetdm/fleet/v4/server/platform/http" "github.com/getsentry/sentry-go" "go.elastic.co/apm/v2" "go.opentelemetry.io/otel/attribute" @@ -121,21 +120,9 @@ func setMetadata(ctx context.Context, data map[string]interface{}) map[string]in data["timestamp"] = nowFn().Format(time.RFC3339) - if h, ok := host.FromContext(ctx); ok { - data["host"] = map[string]interface{}{ - "platform": h.Platform, - "osquery_version": h.OsqueryVersion, - } - } - - if v, ok := viewer.FromContext(ctx); ok { - vdata := map[string]interface{}{} - data["viewer"] = vdata - vdata["is_logged_in"] = v.IsLoggedIn() - - if v.User != nil { - vdata["sso_enabled"] = v.User.SSOEnabled - } + // Get diagnostic context from all registered providers + for _, provider := range getErrorContextProviders(ctx) { + maps.Copy(data, provider.GetDiagnosticContext()) } return data @@ -215,7 +202,7 @@ func Wrapf(ctx context.Context, cause error, format string, args ...interface{}) // Cause returns the root error in err's chain. func Cause(err error) error { - return fleet.Cause(err) + return platform_http.Cause(err) } // FleetCause is similar to Cause, but returns the root-most @@ -319,6 +306,9 @@ func Handle(ctx context.Context, err error) { cause = rootCause } + // Collect telemetry context from registered providers + telemetryAttrs := collectTelemetryContext(ctx) + // send to OpenTelemetry if there's an active span if span := trace.SpanFromContext(ctx); span != nil && span.IsRecording() { // Mark the current span as failed by setting the error status. @@ -333,20 +323,25 @@ func Handle(ctx context.Context, err error) { attribute.String("exception.stacktrace", strings.Join(cause.Stack(), "\n")), } - // Add contextual information if available (same as Sentry) - v, _ := viewer.FromContext(ctx) - h, _ := host.FromContext(ctx) - - if v.User != nil { - attrs = append(attrs, - // Not sending the email here as it may contain sensitive information (PII). - attribute.Int64("user.id", int64(v.User.ID)), //nolint:gosec - ) - } else if h != nil { - attrs = append(attrs, - attribute.String("host.hostname", h.Hostname), - attribute.Int64("host.id", int64(h.ID)), //nolint:gosec - ) + // Add contextual information from telemetry providers. + // OpenTelemetry requires typed attributes, so we convert the values to the appropriate type. + for k, v := range telemetryAttrs { + switch val := v.(type) { + case string: + attrs = append(attrs, attribute.String(k, val)) + case int: + attrs = append(attrs, attribute.Int64(k, int64(val))) + case int64: + attrs = append(attrs, attribute.Int64(k, val)) + case uint: + attrs = append(attrs, attribute.Int64(k, int64(val))) //nolint:gosec + case uint64: + attrs = append(attrs, attribute.Int64(k, int64(val))) //nolint:gosec + case bool: + attrs = append(attrs, attribute.Bool(k, val)) + default: + attrs = append(attrs, attribute.String(k, fmt.Sprint(val))) + } } span.AddEvent("exception", trace.WithAttributes(attrs...)) @@ -357,25 +352,14 @@ func Handle(ctx context.Context, err error) { // if Sentry is configured, capture the error there if sentryClient := sentry.CurrentHub().Client(); sentryClient != nil { - // sentry is configured, add contextual information if available - v, _ := viewer.FromContext(ctx) - h, _ := host.FromContext(ctx) - - if v.User != nil || h != nil { - // we have a viewer (user) or a host in the context, use this to - // enrich the error with more context + if len(telemetryAttrs) > 0 { + // we have contextual information, use it to enrich the error ctxHub := sentry.CurrentHub().Clone() - if v.User != nil { - ctxHub.ConfigureScope(func(scope *sentry.Scope) { - scope.SetTag("email", v.User.Email) - scope.SetTag("user_id", fmt.Sprint(v.User.ID)) - }) - } else if h != nil { - ctxHub.ConfigureScope(func(scope *sentry.Scope) { - scope.SetTag("hostname", h.Hostname) - scope.SetTag("host_id", fmt.Sprint(h.ID)) - }) - } + ctxHub.ConfigureScope(func(scope *sentry.Scope) { + for k, v := range telemetryAttrs { + scope.SetTag(k, fmt.Sprint(v)) + } + }) ctxHub.CaptureException(cause) } else { sentry.CaptureException(cause) @@ -387,6 +371,17 @@ func Handle(ctx context.Context, err error) { } } +// collectTelemetryContext gathers telemetry context from all registered providers. +func collectTelemetryContext(ctx context.Context) map[string]any { + attrs := make(map[string]any) + for _, provider := range getErrorContextProviders(ctx) { + if telemetry := provider.GetTelemetryContext(); telemetry != nil { + maps.Copy(attrs, telemetry) + } + } + return attrs +} + // Retrieve retrieves an error from the registered error handler func Retrieve(ctx context.Context) ([]*StoredError, error) { eh := FromContext(ctx) diff --git a/server/contexts/ctxerr/ctxerr_otel_test.go b/server/contexts/ctxerr/ctxerr_otel_test.go index 85d7c803c6..d9722c33ef 100644 --- a/server/contexts/ctxerr/ctxerr_otel_test.go +++ b/server/contexts/ctxerr/ctxerr_otel_test.go @@ -29,7 +29,11 @@ func TestHandleSendsContextToOTEL(t *testing.T) { ID: 123, Email: "test@example.com", } - return viewer.NewContext(ctx, viewer.Viewer{User: testUser}) + v := viewer.Viewer{User: testUser} + ctx = viewer.NewContext(ctx, v) + // Register the viewer as an error context provider + ctx = AddErrorContextProvider(ctx, &v) + return ctx }, errorMessage: "test error with user context", expectedAttrs: map[string]any{ @@ -43,7 +47,10 @@ func TestHandleSendsContextToOTEL(t *testing.T) { ID: 456, Hostname: "test-host.example.com", } - return host.NewContext(ctx, testHost) + ctx = host.NewContext(ctx, testHost) + // Register the host as an error context provider + ctx = AddErrorContextProvider(ctx, &host.HostAttributeProvider{Host: testHost}) + return ctx }, errorMessage: "test error with host context", expectedAttrs: map[string]any{ diff --git a/server/contexts/ctxerr/ctxerr_test.go b/server/contexts/ctxerr/ctxerr_test.go index 096e3ce160..97438d523a 100644 --- a/server/contexts/ctxerr/ctxerr_test.go +++ b/server/contexts/ctxerr/ctxerr_test.go @@ -322,7 +322,10 @@ func TestAdditionalMetadata(t *testing.T) { t.Run("saves additional data about the host if present", func(t *testing.T) { ctx, cleanup := setup() defer cleanup() - hctx := host.NewContext(ctx, &fleet.Host{Platform: "test_platform", OsqueryVersion: "5.0"}) + h := &fleet.Host{Platform: "test_platform", OsqueryVersion: "5.0"} + hctx := host.NewContext(ctx, h) + // Register the host as an error context provider + hctx = AddErrorContextProvider(hctx, &host.HostAttributeProvider{Host: h}) err := New(hctx, "with host context").(*FleetError) require.JSONEq(t, string(err.data), `{"host":{"osquery_version":"5.0","platform":"test_platform"},"timestamp":"1969-06-19T21:44:05Z"}`) @@ -331,7 +334,10 @@ func TestAdditionalMetadata(t *testing.T) { t.Run("saves additional data about the viewer if present", func(t *testing.T) { ctx, cleanup := setup() defer cleanup() - vctx := viewer.NewContext(ctx, viewer.Viewer{Session: &fleet.Session{ID: 1}, User: &fleet.User{SSOEnabled: true}}) + v := viewer.Viewer{Session: &fleet.Session{ID: 1}, User: &fleet.User{SSOEnabled: true}} + vctx := viewer.NewContext(ctx, v) + // Register the viewer as an error context provider + vctx = AddErrorContextProvider(vctx, &v) err := New(vctx, "with host context").(*FleetError) require.JSONEq(t, string(err.data), `{"viewer":{"is_logged_in":true,"sso_enabled":true},"timestamp":"1969-06-19T21:44:05Z"}`) diff --git a/server/contexts/ctxerr/metadata.go b/server/contexts/ctxerr/metadata.go new file mode 100644 index 0000000000..5df6b32be7 --- /dev/null +++ b/server/contexts/ctxerr/metadata.go @@ -0,0 +1,35 @@ +package ctxerr + +import "context" + +// ErrorContextProvider provides contextual information for error handling. +// Implementations can provide data for both error storage and telemetry systems. +type ErrorContextProvider interface { + // GetDiagnosticContext returns attributes stored with errors for troubleshooting. + // Data is persisted to Redis and included in logs. Should contain diagnostic + // information like platform, versions, and status flags. Avoid including PII. + GetDiagnosticContext() map[string]any + + // GetTelemetryContext returns attributes sent to observability systems + // (OpenTelemetry, Sentry). May include identifiers not stored with errors. + // Return nil if no telemetry context is available. + GetTelemetryContext() map[string]any +} + +type errorContextProvidersKey struct{} + +// AddErrorContextProvider returns a new context with the given provider added to +// the existing providers. This is useful when you want to add a provider +// without replacing existing ones. +func AddErrorContextProvider(ctx context.Context, provider ErrorContextProvider) context.Context { + existing := getErrorContextProviders(ctx) + providers := make([]ErrorContextProvider, len(existing)+1) + copy(providers, existing) + providers[len(existing)] = provider + return context.WithValue(ctx, errorContextProvidersKey{}, providers) +} + +func getErrorContextProviders(ctx context.Context) []ErrorContextProvider { + providers, _ := ctx.Value(errorContextProvidersKey{}).([]ErrorContextProvider) + return providers +} diff --git a/server/contexts/ctxerr/statistics.go b/server/contexts/ctxerr/statistics.go index e429b42f5c..95056556c1 100644 --- a/server/contexts/ctxerr/statistics.go +++ b/server/contexts/ctxerr/statistics.go @@ -3,8 +3,6 @@ package ctxerr import ( "context" "encoding/json" - - "github.com/fleetdm/fleet/v4/server/fleet" ) type ErrorAgg struct { @@ -63,12 +61,21 @@ out: return stack[:stackIdx] } +// vitalErrorData represents the structure of vital fleetd error data. +type vitalErrorData struct { + ErrorSource string `json:"error_source"` + ErrorSourceVersion string `json:"error_source_version"` + ErrorMessage string `json:"error_message"` + ErrorAdditionalInfo map[string]any `json:"error_additional_info"` + Vital bool `json:"vital"` +} + func getVitalMetadata(chain []fleetErrorJSON) json.RawMessage { for _, e := range chain { if len(e.Data) > 0 { // Currently, only vital fleetd errors contain metadata. // Note: vital errors should not contain any sensitive info - var fleetdErr fleet.FleetdError + var fleetdErr vitalErrorData var err error if err = json.Unmarshal(e.Data, &fleetdErr); err != nil || !fleetdErr.Vital { continue diff --git a/server/contexts/host/host.go b/server/contexts/host/host.go index 46b1110648..4bdcc08d9b 100644 --- a/server/contexts/host/host.go +++ b/server/contexts/host/host.go @@ -22,3 +22,33 @@ func FromContext(ctx context.Context) (*fleet.Host, bool) { host, ok := ctx.Value(hostKey).(*fleet.Host) return host, ok } + +// HostAttributeProvider wraps a fleet.Host to provide error context. +// It implements ctxerr.ErrorContextProvider. +type HostAttributeProvider struct { + Host *fleet.Host +} + +// GetDiagnosticContext implements ctxerr.ErrorContextProvider +func (p *HostAttributeProvider) GetDiagnosticContext() map[string]any { + if p.Host == nil { + return nil + } + return map[string]any{ + "host": map[string]any{ + "platform": p.Host.Platform, + "osquery_version": p.Host.OsqueryVersion, + }, + } +} + +// GetTelemetryContext implements ctxerr.ErrorContextProvider +func (p *HostAttributeProvider) GetTelemetryContext() map[string]any { + if p.Host == nil { + return nil + } + return map[string]any{ + "host.hostname": p.Host.Hostname, + "host.id": p.Host.ID, + } +} diff --git a/server/contexts/license/license.go b/server/contexts/license/license.go index a8aa638bbc..4d9a3b1d2a 100644 --- a/server/contexts/license/license.go +++ b/server/contexts/license/license.go @@ -4,23 +4,34 @@ package license import ( "context" - - "github.com/fleetdm/fleet/v4/server/fleet" ) +// LicenseChecker is the interface for checking license properties. +// This interface is implemented by fleet.LicenseInfo +type LicenseChecker interface { + IsPremium() bool + IsAllowDisableTelemetry() bool + // GetTier returns the license tier (e.g., "free", "premium", "trial"). + GetTier() string + // GetOrganization returns the name of the licensed organization. + GetOrganization() string + // GetDeviceCount returns the number of licensed devices. + GetDeviceCount() int +} + type key int const licenseKey key = 0 // NewContext creates a new context.Context with the license. -func NewContext(ctx context.Context, lic *fleet.LicenseInfo) context.Context { +func NewContext(ctx context.Context, lic LicenseChecker) context.Context { return context.WithValue(ctx, licenseKey, lic) } -// FromContext returns the license from the context and true, or nil and false -// if there is no license. -func FromContext(ctx context.Context) (*fleet.LicenseInfo, bool) { - v, ok := ctx.Value(licenseKey).(*fleet.LicenseInfo) +// FromContext returns the license from the context as a LicenseChecker interface. +// Use this when you only need to check license properties via the interface methods. +func FromContext(ctx context.Context) (LicenseChecker, bool) { + v, ok := ctx.Value(licenseKey).(LicenseChecker) return v, ok } @@ -34,6 +45,8 @@ func IsPremium(ctx context.Context) bool { return false } +// IsAllowDisableTelemetry returns true if telemetry can be disabled based on +// the license in the context. func IsAllowDisableTelemetry(ctx context.Context) bool { if lic, ok := FromContext(ctx); ok { return lic.IsAllowDisableTelemetry() diff --git a/server/contexts/logging/logging.go b/server/contexts/logging/logging.go index bd79ff8b6b..c71ff78a49 100644 --- a/server/contexts/logging/logging.go +++ b/server/contexts/logging/logging.go @@ -7,13 +7,31 @@ import ( "sync" "time" - "github.com/fleetdm/fleet/v4/server/contexts/viewer" - "github.com/fleetdm/fleet/v4/server/fleet" + platform_http "github.com/fleetdm/fleet/v4/server/platform/http" kithttp "github.com/go-kit/kit/transport/http" kitlog "github.com/go-kit/log" "github.com/go-kit/log/level" ) +// UserEmailer provides the user's email for logging purposes. +type UserEmailer interface { + Email() string +} + +type userEmailerKey struct{} + +// WithUserEmailer returns a context with the UserEmailer stored for logging. +// This should be called by authentication middleware after the user is identified. +func WithUserEmailer(ctx context.Context, emailer UserEmailer) context.Context { + return context.WithValue(ctx, userEmailerKey{}, emailer) +} + +// UserEmailerFromContext retrieves the UserEmailer from the context. +func UserEmailerFromContext(ctx context.Context) (UserEmailer, bool) { + v, ok := ctx.Value(userEmailerKey{}).(UserEmailer) + return v, ok +} + type key int const loggingKey key = 0 @@ -138,9 +156,8 @@ func (l *LoggingContext) Log(ctx context.Context, logger kitlog.Logger) { if !l.SkipUser { loggedInUser := "unauthenticated" - vc, ok := viewer.FromContext(ctx) - if ok { - loggedInUser = vc.Email() + if emailer, ok := UserEmailerFromContext(ctx); ok { + loggedInUser = emailer.Email() } keyvals = append(keyvals, "user", loggedInUser) } @@ -168,7 +185,7 @@ func (l *LoggingContext) Log(ctx context.Context, logger kitlog.Logger) { ) separator := " || " for _, err := range l.Errs { - var ewi fleet.ErrWithInternal + var ewi platform_http.ErrWithInternal if errors.As(err, &ewi) { if internalErrs == "" { internalErrs = ewi.Internal() @@ -182,7 +199,7 @@ func (l *LoggingContext) Log(ctx context.Context, logger kitlog.Logger) { errs += separator + err.Error() } } - var ewuuid fleet.ErrorUUIDer + var ewuuid platform_http.ErrorUUIDer if errors.As(err, &ewuuid) { if uuid := ewuuid.UUID(); uuid != "" { uuids = append(uuids, uuid) @@ -209,7 +226,7 @@ func (l *LoggingContext) setLevelError() bool { } if len(l.Errs) == 1 { - var ew fleet.ErrWithIsClientError + var ew platform_http.ErrWithIsClientError if errors.As(l.Errs[0], &ew) && ew.IsClientError() { return false } diff --git a/server/contexts/viewer/viewer.go b/server/contexts/viewer/viewer.go index 8294c31298..aa642118d7 100644 --- a/server/contexts/viewer/viewer.go +++ b/server/contexts/viewer/viewer.go @@ -4,6 +4,7 @@ package viewer import ( "context" + "strings" "github.com/fleetdm/fleet/v4/server/fleet" ) @@ -102,3 +103,38 @@ func (v Viewer) CanPerformPasswordReset() bool { } return false } + +// GetDiagnosticContext implements ctxerr.ErrorContextProvider +func (v *Viewer) GetDiagnosticContext() map[string]any { + vdata := map[string]any{ + "is_logged_in": v.IsLoggedIn(), + } + if v.User != nil { + vdata["sso_enabled"] = v.User.SSOEnabled + } + return map[string]any{ + "viewer": vdata, + } +} + +// GetTelemetryContext implements ctxerr.ErrorContextProvider +func (v *Viewer) GetTelemetryContext() map[string]any { + if v.User == nil { + return nil + } + return map[string]any{ + "user.id": v.User.ID, + "user.email": maskEmail(v.User.Email), + } +} + +// maskEmail anonymizes an email address for telemetry by showing only +// the first character of the local part and the full domain. +// Example: "john.doe@example.com" -> "j***@example.com" +func maskEmail(email string) string { + parts := strings.SplitN(email, "@", 2) + if len(parts) != 2 || len(parts[0]) == 0 { + return "***" + } + return string(parts[0][0]) + "***@" + parts[1] +} diff --git a/server/contexts/viewer/viewer_test.go b/server/contexts/viewer/viewer_test.go index c581351390..7f8ed8a0ef 100644 --- a/server/contexts/viewer/viewer_test.go +++ b/server/contexts/viewer/viewer_test.go @@ -153,3 +153,26 @@ func TestCanPerformActions(t *testing.T) { // assert.Equal(t, false, needsPasswordResetAdminViewer.CanPerformWriteActionOnUser(1)) // assert.Equal(t, true, needsPasswordResetAdminViewer.CanPerformWriteActionOnUser(needsPasswordResetAdminViewer.User.ID)) // } + +func TestMaskEmail(t *testing.T) { + cases := []struct { + name string + email string + expected string + }{ + {"standard email", "john.doe@example.com", "j***@example.com"}, + {"single char local", "j@example.com", "j***@example.com"}, + {"subdomain", "user@mail.example.com", "u***@mail.example.com"}, + {"empty string", "", "***"}, + {"no at sign", "invalid", "***"}, + {"empty local part", "@example.com", "***"}, + {"only at sign", "@", "***"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + result := maskEmail(tc.email) + assert.Equal(t, tc.expected, result) + }) + } +} diff --git a/server/datastore/mysql/statistics.go b/server/datastore/mysql/statistics.go index 80be1c2860..8567e64ee1 100644 --- a/server/datastore/mysql/statistics.go +++ b/server/datastore/mysql/statistics.go @@ -144,7 +144,7 @@ func (ds *Datastore) ShouldSendStatistics(ctx context.Context, frequency time.Du stats.NumHostsNotResponding = amountHostsNotResponding stats.Organization = "unknown" if lic != nil && lic.IsPremium() { - stats.Organization = lic.Organization + stats.Organization = lic.GetOrganization() } stats.AIFeaturesDisabled = appConfig.ServerSettings.AIFeaturesDisabled stats.MaintenanceWindowsConfigured = len(appConfig.Integrations.GoogleCalendar) > 0 && appConfig.Integrations.GoogleCalendar[0].Domain != "" && len(appConfig.Integrations.GoogleCalendar[0].ApiKey) > 0 @@ -187,7 +187,7 @@ func (ds *Datastore) ShouldSendStatistics(ctx context.Context, frequency time.Du LicenseTier: fleet.TierFree, } if lic != nil { - stats.LicenseTier = lic.Tier + stats.LicenseTier = lic.GetTier() } if err := computeStats(&stats, time.Now().Add(-frequency)); err != nil { return fleet.StatisticsPayload{}, false, ctxerr.Wrap(ctx, err, "compute statistics") @@ -212,7 +212,7 @@ func (ds *Datastore) ShouldSendStatistics(ctx context.Context, frequency time.Du LicenseTier: fleet.TierFree, } if lic != nil { - stats.LicenseTier = lic.Tier + stats.LicenseTier = lic.GetTier() } if err := computeStats(&stats, lastUpdated); err != nil { return fleet.StatisticsPayload{}, false, ctxerr.Wrap(ctx, err, "compute statistics") diff --git a/server/fleet/app.go b/server/fleet/app.go index a65ea62293..fb1375eb99 100644 --- a/server/fleet/app.go +++ b/server/fleet/app.go @@ -1476,6 +1476,24 @@ func (l *LicenseInfo) IsAllowDisableTelemetry() bool { return !l.IsPremium() || l.AllowDisableTelemetry } +// Tier returns the license tier. +// This method implements license.LicenseChecker. +func (l *LicenseInfo) GetTier() string { + return l.Tier +} + +// Organization returns the name of the licensed organization. +// This method implements license.LicenseChecker. +func (l *LicenseInfo) GetOrganization() string { + return l.Organization +} + +// DeviceCount returns the number of licensed devices. +// This method implements license.LicenseChecker. +func (l *LicenseInfo) GetDeviceCount() int { + return l.DeviceCount +} + const ( HeaderLicenseKey = "X-Fleet-License" HeaderLicenseValueExpired = "Expired" diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index b1a635e5b2..6b85c2c0c3 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -19,6 +19,7 @@ import ( "github.com/fleetdm/fleet/v4/server/mdm/nanodep/godep" "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm" "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/storage" + platform_http "github.com/fleetdm/fleet/v4/server/platform/http" "github.com/jmoiron/sqlx" ) @@ -2836,19 +2837,11 @@ type AlreadyExistsError interface { IsExists() bool } -// ForeignKeyError is returned when the operation fails due to foreign key constraints. -type ForeignKeyError interface { - error - IsForeignKey() bool -} +// ForeignKeyError is an alias for platform_http.ForeignKeyError. +type ForeignKeyError = platform_http.ForeignKeyError -func IsForeignKey(err error) bool { - var fke ForeignKeyError - if errors.As(err, &fke) { - return fke.IsForeignKey() - } - return false -} +// IsForeignKey is an alias for platform_http.IsForeignKey. +var IsForeignKey = platform_http.IsForeignKey type OptionalArg func() interface{} diff --git a/server/fleet/errors.go b/server/fleet/errors.go index 2544401efa..19f82c87b4 100644 --- a/server/fleet/errors.go +++ b/server/fleet/errors.go @@ -1,22 +1,18 @@ package fleet import ( - "encoding/json" "errors" "fmt" "net/http" - "reflect" - "regexp" - "strings" "time" - "github.com/google/uuid" + platform_http "github.com/fleetdm/fleet/v4/server/platform/http" "github.com/rs/zerolog" ) var ( ErrNoContext = errors.New("context key not set") - ErrPasswordResetRequired = &passwordResetRequiredError{} + ErrPasswordResetRequired = platform_http.ErrPasswordResetRequired ErrMissingLicense = &licenseError{} ErrMDMNotConfigured = &MDMNotConfiguredError{} ErrWindowsMDMNotConfigured = &WindowsMDMNotConfiguredError{} @@ -46,40 +42,17 @@ type ErrWithStatusCode interface { StatusCode() int } -// ErrWithInternal is an interface for errors that include extra "internal" -// information that should be logged in server logs but not sent to clients. -type ErrWithInternal interface { - error - // Internal returns the error string that must only be logged internally, - // not returned to the client. - Internal() string -} +// ErrWithInternal is an alias for platform_http.ErrWithInternal. +type ErrWithInternal = platform_http.ErrWithInternal -// ErrWithLogFields is an interface for errors that include additional logging -// fields that should be logged in server logs but not sent to clients. -type ErrWithLogFields interface { - error - // LogFields returns the additional log fields to add, which should come in - // key, value pairs (as used in go-kit log). - LogFields() []interface{} -} +// ErrWithLogFields is an alias for platform_http.ErrWithLogFields. +type ErrWithLogFields = platform_http.ErrWithLogFields -// ErrWithRetryAfter is an interface for errors that should set a specific HTTP -// Header Retry-After value (see -// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After) -type ErrWithRetryAfter interface { - error - // RetryAfter returns the number of seconds to wait before retry. - RetryAfter() int -} +// ErrWithRetryAfter is an alias for platform_http.ErrWithRetryAfter. +type ErrWithRetryAfter = platform_http.ErrWithRetryAfter -// ErrWithIsClientError is an interface for errors that explicitly specify -// whether they are client errors or not. By default, errors are treated as -// server errors. -type ErrWithIsClientError interface { - error - IsClientError() bool -} +// ErrWithIsClientError is an alias for platform_http.ErrWithIsClientError. +type ErrWithIsClientError = platform_http.ErrWithIsClientError type invalidArgWithStatusError struct { InvalidArgumentError @@ -94,30 +67,11 @@ func (e invalidArgWithStatusError) Status() int { return e.code } -// ErrorUUIDer is the interface for errors that contain a UUID. -type ErrorUUIDer interface { - // UUID returns the error's UUID. - UUID() string -} +// ErrorUUIDer is an alias for platform_http.ErrorUUIDer. +type ErrorUUIDer = platform_http.ErrorUUIDer -// ErrorWithUUID can be embedded to error types to implement ErrorUUIDer. -type ErrorWithUUID struct { - uuid string -} - -var _ ErrorUUIDer = (*ErrorWithUUID)(nil) - -// UUID implements the ErrorUUIDer interface. -func (e *ErrorWithUUID) UUID() string { - if e.uuid == "" { - uuid, err := uuid.NewRandom() - if err != nil { - panic(err) - } - e.uuid = uuid.String() - } - return e.uuid -} +// ErrorWithUUID is an alias for platform_http.ErrorWithUUID. +type ErrorWithUUID = platform_http.ErrorWithUUID // InvalidArgumentError is the error returned when invalid data is presented to // a service method. It is a client error. @@ -187,102 +141,26 @@ func (e InvalidArgumentError) Invalid() []map[string]string { return invalid } -// BadRequestError is an error type that generates a 400 status code. -type BadRequestError struct { - Message string - InternalErr error +// BadRequestError is an alias for platform_http.BadRequestError. +type BadRequestError = platform_http.BadRequestError - ErrorWithUUID -} +// AuthFailedError is an alias for platform_http.AuthFailedError. +type AuthFailedError = platform_http.AuthFailedError -// Error returns the error message. -func (e *BadRequestError) Error() string { - return e.Message -} +// NewAuthFailedError is an alias for platform_http.NewAuthFailedError. +var NewAuthFailedError = platform_http.NewAuthFailedError -// This implements the interface required by the server/service package logic -// to determine the status code to return to the client. -func (e *BadRequestError) BadRequestError() []map[string]string { - return nil -} +// AuthRequiredError is an alias for platform_http.AuthRequiredError. +type AuthRequiredError = platform_http.AuthRequiredError -func (e BadRequestError) Internal() string { - if e.InternalErr == nil { - return "" - } - return e.InternalErr.Error() -} +// NewAuthRequiredError is an alias for platform_http.NewAuthRequiredError. +var NewAuthRequiredError = platform_http.NewAuthRequiredError -type AuthFailedError struct { - // internal is the reason that should only be logged internally - internal string +// AuthHeaderRequiredError is an alias for platform_http.AuthHeaderRequiredError. +type AuthHeaderRequiredError = platform_http.AuthHeaderRequiredError - ErrorWithUUID -} - -func NewAuthFailedError(internal string) *AuthFailedError { - return &AuthFailedError{internal: internal} -} - -func (e AuthFailedError) Error() string { - return "Authentication failed" -} - -func (e AuthFailedError) Internal() string { - return e.internal -} - -func (e AuthFailedError) StatusCode() int { - return http.StatusUnauthorized -} - -type AuthRequiredError struct { - // internal is the reason that should only be logged internally - internal string - - ErrorWithUUID -} - -func NewAuthRequiredError(internal string) *AuthRequiredError { - return &AuthRequiredError{internal: internal} -} - -func (e AuthRequiredError) Error() string { - return "Authentication required" -} - -func (e AuthRequiredError) Internal() string { - return e.internal -} - -func (e AuthRequiredError) StatusCode() int { - return http.StatusUnauthorized -} - -type AuthHeaderRequiredError struct { - // internal is the reason that should only be logged internally - internal string - - ErrorWithUUID -} - -func NewAuthHeaderRequiredError(internal string) *AuthHeaderRequiredError { - return &AuthHeaderRequiredError{ - internal: internal, - } -} - -func (e AuthHeaderRequiredError) Error() string { - return "Authorization header required" -} - -func (e AuthHeaderRequiredError) Internal() string { - return e.internal -} - -func (e AuthHeaderRequiredError) StatusCode() int { - return http.StatusUnauthorized -} +// NewAuthHeaderRequiredError is an alias for platform_http.NewAuthHeaderRequiredError. +var NewAuthHeaderRequiredError = platform_http.NewAuthHeaderRequiredError // PermissionError, set when user is authenticated, but not allowed to perform action type PermissionError struct { @@ -347,18 +225,6 @@ func (e licenseError) StatusCode() int { return http.StatusPaymentRequired } -type passwordResetRequiredError struct { - ErrorWithUUID -} - -func (e passwordResetRequiredError) Error() string { - return "password reset required" -} - -func (e passwordResetRequiredError) StatusCode() int { - return http.StatusUnauthorized -} - // MDMNotConfiguredError is used when an MDM endpoint or resource is accessed // without having MDM correctly configured. type MDMNotConfiguredError struct{} @@ -453,15 +319,9 @@ func (e *GatewayError) Error() string { return msg } -// Error is a user facing error (API user). It's meant to be used for errors that are -// related to fleet logic specifically. Other errors, such as mysql errors, shouldn't -// be translated to this. -type Error struct { - Code int `json:"code,omitempty"` - Message string `json:"message,omitempty"` - - ErrorWithUUID -} +// Error is an alias for platform_http.Error. +// It's meant to be used for errors that are related to fleet logic specifically. +type Error = platform_http.Error const ( // ErrNoRoleNeeded is the error number for valid role needed @@ -491,101 +351,26 @@ func NewErrorf(code int, format string, args ...interface{}) error { } } -func (ge *Error) Error() string { - return ge.Message -} +// UserMessageError is an alias for platform_http.UserMessageError. +type UserMessageError = platform_http.UserMessageError -// UserMessageError is an error that adds the UserMessage interface -// implementation. -type UserMessageError struct { - error - statusCode int - - ErrorWithUUID -} - -// NewUserMessageError creates a UserMessageError that will translate the -// error message of err to a user-friendly form. If statusCode is > 0, it -// will be used as the HTTP status code for the error, otherwise it defaults -// to http.StatusUnprocessableEntity (422). -func NewUserMessageError(err error, statusCode int) *UserMessageError { - if err == nil { - return nil - } - return &UserMessageError{ - error: err, - statusCode: statusCode, - } -} - -var rxJSONUnknownField = regexp.MustCompile(`^json: unknown field "(.+)"$`) +// NewUserMessageError is an alias for platform_http.NewUserMessageError. +var NewUserMessageError = platform_http.NewUserMessageError // IsJSONUnknownFieldError returns true if err is a JSON unknown field error. // There is no exported type or value for this error, so we have to match the // error message. func IsJSONUnknownFieldError(err error) bool { - return rxJSONUnknownField.MatchString(err.Error()) + return platform_http.IsJSONUnknownFieldError(err) } +// GetJSONUnknownField returns the unknown field name from a JSON unknown field error. func GetJSONUnknownField(err error) *string { - errCause := Cause(err) - if IsJSONUnknownFieldError(errCause) { - substr := rxJSONUnknownField.FindStringSubmatch(errCause.Error()) - return &substr[1] - } - return nil -} - -// UserMessage implements the user-friendly translation of the error if its -// root cause is one of the supported types, otherwise it returns the error -// message. -func (e UserMessageError) UserMessage() string { - cause := Cause(e.error) - switch cause := cause.(type) { - case *json.UnmarshalTypeError: - var sb strings.Builder - curType := cause.Type - for curType.Kind() == reflect.Slice || curType.Kind() == reflect.Array { - sb.WriteString("array of ") - curType = curType.Elem() - } - sb.WriteString(curType.Name()) - if curType != cause.Type { - // it was an array - sb.WriteString("s") - } - - return fmt.Sprintf("invalid value type at '%s': expected %s but got %s", cause.Field, sb.String(), cause.Value) - - default: - // there's no specific error type for the strict json mode - // (DisallowUnknownFields), so resort to message-matching. - if matches := rxJSONUnknownField.FindStringSubmatch(cause.Error()); matches != nil { - return fmt.Sprintf("unsupported key provided: %q", matches[1]) - } - return e.Error() - } -} - -// StatusCode implements the kithttp.StatusCoder interface to return the status -// code to use in HTTP API responses. -func (e UserMessageError) StatusCode() int { - if e.statusCode > 0 { - return e.statusCode - } - return http.StatusUnprocessableEntity + return platform_http.GetJSONUnknownField(err) } // Cause returns the root error in err's chain. -func Cause(err error) error { - for { - uerr := errors.Unwrap(err) - if uerr == nil { - return err - } - err = uerr - } -} +var Cause = platform_http.Cause // FleetdError is an error that can be reported by any of the fleetd // components. diff --git a/server/mdm/android/arch_test.go b/server/mdm/android/arch_test.go index ab4991fde8..e69e398b89 100644 --- a/server/mdm/android/arch_test.go +++ b/server/mdm/android/arch_test.go @@ -7,32 +7,34 @@ import ( "github.com/fleetdm/fleet/v4/server/archtest" ) +const m = archtest.ModuleName + // TestAllAndroidPackageDependencies checks that android packages are not dependent on other Fleet packages // to maintain decoupling and modularity. // If coupling is necessary, it should be done in the main server/fleet, server/service, or other package. func TestAllAndroidPackageDependencies(t *testing.T) { t.Parallel() - archtest.NewPackageTest(t, "github.com/fleetdm/fleet/v4/server/mdm/android..."). + archtest.NewPackageTest(t, m+"/server/mdm/android..."). OnlyInclude(regexp.MustCompile(`^github\.com/fleetdm/`)). ShouldNotDependOn( - "github.com/fleetdm/fleet/v4/server/service...", - "github.com/fleetdm/fleet/v4/server/datastore/mysql...", + m+"/server/service...", + m+"/server/datastore/mysql...", ). IgnoreRecursively( - "github.com/fleetdm/fleet/v4/server/mdm/android/tests", + m+"/server/mdm/android/tests", ). IgnoreDeps( // Android packages - "github.com/fleetdm/fleet/v4/server/mdm/android...", + m+"/server/mdm/android...", // Other/infra packages - "github.com/fleetdm/fleet/v4/server/datastore/mysql/common_mysql", - "github.com/fleetdm/fleet/v4/server/service/externalsvc", // dependency on Jira and Zendesk - "github.com/fleetdm/fleet/v4/server/service/middleware/auth", - "github.com/fleetdm/fleet/v4/server/service/middleware/authzcheck", - "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils", - "github.com/fleetdm/fleet/v4/server/service/middleware/log", - "github.com/fleetdm/fleet/v4/server/service/middleware/ratelimit", - "github.com/fleetdm/fleet/v4/server/service/modules/activities", + m+"/server/datastore/mysql/common_mysql", + m+"/server/service/externalsvc", // dependency on Jira and Zendesk + m+"/server/service/middleware/auth", + m+"/server/service/middleware/authzcheck", + m+"/server/service/middleware/endpoint_utils", + m+"/server/service/middleware/log", + m+"/server/service/middleware/ratelimit", + m+"/server/service/modules/activities", ). Check() } @@ -42,7 +44,8 @@ func TestAllAndroidPackageDependencies(t *testing.T) { // If coupling is necessary, it should be done in the main server/fleet or another package. func TestAndroidPackageDependencies(t *testing.T) { t.Parallel() - archtest.NewPackageTest(t, "github.com/fleetdm/fleet/v4/server/mdm/android"). + archtest.NewPackageTest(t, m+"/server/mdm/android"). OnlyInclude(regexp.MustCompile(`^github\.com/fleetdm/`)). - ShouldNotDependOn("github.com/fleetdm/fleet/v4/...") + ShouldNotDependOn(m + "/..."). + Check() } diff --git a/server/mdm/android/service/endpoint_utils.go b/server/mdm/android/service/endpoint_utils.go index 2638dbecbe..4d04c92286 100644 --- a/server/mdm/android/service/endpoint_utils.go +++ b/server/mdm/android/service/endpoint_utils.go @@ -7,6 +7,7 @@ import ( "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/mdm/android" + platform_http "github.com/fleetdm/fleet/v4/server/platform/http" "github.com/fleetdm/fleet/v4/server/service/middleware/auth" eu "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" "github.com/go-json-experiment/json" @@ -21,13 +22,14 @@ func encodeResponse(ctx context.Context, w http.ResponseWriter, response interfa func(w http.ResponseWriter, response interface{}) error { return json.MarshalWrite(w, response, jsontext.WithIndent(" ")) }, + nil, // no domain-specific error encoder ) } func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc { return eu.MakeDecoder(iface, func(body io.Reader, req any) error { return json.UnmarshalRead(body, req) - }, nil, nil, nil) + }, nil, nil, nil, nil) } // handlerFunc is the handler function type for Android service endpoints. @@ -40,12 +42,12 @@ type endpointer struct { svc android.Service } -func (e *endpointer) CallHandlerFunc(f handlerFunc, ctx context.Context, request interface{}, - svc interface{}) (fleet.Errorer, error) { +func (e *endpointer) CallHandlerFunc(f handlerFunc, ctx context.Context, request any, + svc any) (platform_http.Errorer, error) { return f(ctx, request, svc.(android.Service)), nil } -func (e *endpointer) Service() interface{} { +func (e *endpointer) Service() any { return e.svc } diff --git a/server/mdm/android/tests/testing_utils.go b/server/mdm/android/tests/testing_utils.go index c8a3ebae86..bb8ab7bd82 100644 --- a/server/mdm/android/tests/testing_utils.go +++ b/server/mdm/android/tests/testing_utils.go @@ -207,13 +207,18 @@ func (m *mockService) NewActivity(ctx context.Context, user *fleet.User, details } func runServerForTests(t *testing.T, logger kitlog.Logger, fleetSvc fleet.Service, androidSvc android.Service) *httptest.Server { + // androidErrorEncoder wraps EncodeError with nil domain encoder for android tests + androidErrorEncoder := func(ctx context.Context, err error, w http.ResponseWriter) { + endpoint_utils.EncodeError(ctx, err, w, nil) + } + fleetAPIOptions := []kithttp.ServerOption{ kithttp.ServerBefore( kithttp.PopulateRequestContext, auth.SetRequestsContexts(fleetSvc), ), kithttp.ServerErrorHandler(&endpoint_utils.ErrorHandler{Logger: logger}), - kithttp.ServerErrorEncoder(endpoint_utils.EncodeError), + kithttp.ServerErrorEncoder(androidErrorEncoder), kithttp.ServerAfter( kithttp.SetContentType("application/json; charset=utf-8"), log.LogRequestEnd(logger), diff --git a/server/platform/arch_test.go b/server/platform/arch_test.go new file mode 100644 index 0000000000..99761368d0 --- /dev/null +++ b/server/platform/arch_test.go @@ -0,0 +1,44 @@ +package platform_test + +import ( + "regexp" + "testing" + + "github.com/fleetdm/fleet/v4/server/archtest" +) + +const m = archtest.ModuleName + +// TestEndpointerPackageDependencies checks that endpointer package is not dependent on other Fleet domain packages +// to maintain decoupling and modularity. +func TestEndpointerPackageDependencies(t *testing.T) { + t.Parallel() + archtest.NewPackageTest(t, m+"/server/service/middleware/endpoint_utils"). + OnlyInclude(regexp.MustCompile(`^github\.com/fleetdm/`)). + WithTests(). + ShouldNotDependOn(m+"/..."). + IgnoreDeps( + // Platform packages + m+"/server/platform...", + // Other infra packages + m+"/server/contexts/authz", + m+"/server/contexts/ctxerr", + m+"/server/contexts/license", + m+"/server/contexts/logging", + m+"/server/contexts/publicip", + m+"/server/service/middleware/authzcheck", + m+"/server/service/middleware/ratelimit", + ). + Check() +} + +// TestPlatformPackageDependencies checks that platform packages are NOT dependent on ANY other Fleet packages +// to maintain decoupling and modularity. +func TestPlatformPackageDependencies(t *testing.T) { + t.Parallel() + archtest.NewPackageTest(t, m+"/server/platform..."). + OnlyInclude(regexp.MustCompile(`^github\.com/fleetdm/`)). + WithTests(). + ShouldNotDependOn(m + "/..."). + Check() +} diff --git a/server/platform/http/errors.go b/server/platform/http/errors.go new file mode 100644 index 0000000000..4eb42de82b --- /dev/null +++ b/server/platform/http/errors.go @@ -0,0 +1,357 @@ +package http + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "reflect" + "regexp" + "strings" + + "github.com/google/uuid" +) + +// ErrWithInternal defines an interface for errors that have an internal message +// that should only be logged, not returned to the client. +type ErrWithInternal interface { + error + // Internal returns the error string that must only be logged internally, + // not returned to the client. + Internal() string +} + +// ErrWithLogFields defines an interface for errors that have additional log fields. +type ErrWithLogFields interface { + error + // LogFields returns the additional log fields to add, which should come in + // key, value pairs (as used in go-kit log). + LogFields() []any +} + +// ErrorUUIDer defines an interface for errors that have a UUID for tracking. +type ErrorUUIDer interface { + // UUID returns the error's UUID. + UUID() string +} + +// ErrorWithUUID can be embedded in error types to implement ErrorUUIDer. +type ErrorWithUUID struct { + uuid string +} + +var _ ErrorUUIDer = (*ErrorWithUUID)(nil) + +// UUID implements the ErrorUUIDer interface. +func (e *ErrorWithUUID) UUID() string { + if e.uuid == "" { + u, err := uuid.NewRandom() + if err != nil { + panic(err) + } + e.uuid = u.String() + } + return e.uuid +} + +// BadRequestError is the error returned when the request is invalid. +type BadRequestError struct { + Message string + InternalErr error + + ErrorWithUUID +} + +// Error returns the error message. +func (e *BadRequestError) Error() string { + return e.Message +} + +// BadRequestError implements the interface required by the server/service package logic +// to determine the status code to return to the client. +func (e *BadRequestError) BadRequestError() []map[string]string { + return nil +} + +// Internal implements the ErrWithInternal interface. +func (e BadRequestError) Internal() string { + if e.InternalErr != nil { + return e.InternalErr.Error() + } + return "" +} + +// UserMessageError is an error that wraps another error with a user-friendly message. +type UserMessageError struct { + error + statusCode int + + ErrorWithUUID +} + +// NewUserMessageError creates a UserMessageError that will translate the +// error message of err to a user-friendly form. If statusCode is > 0, it +// will be used as the HTTP status code for the error, otherwise it defaults +// to http.StatusUnprocessableEntity (422). +func NewUserMessageError(err error, statusCode int) *UserMessageError { + if err == nil { + return nil + } + return &UserMessageError{ + error: err, + statusCode: statusCode, + } +} + +// StatusCode returns the HTTP status code for this error. +func (e UserMessageError) StatusCode() int { + if e.statusCode > 0 { + return e.statusCode + } + return http.StatusUnprocessableEntity +} + +var rxJSONUnknownField = regexp.MustCompile(`^json: unknown field "(.+)"$`) + +// IsJSONUnknownFieldError returns true if err is a JSON unknown field error. +// There is no exported type or value for this error, so we have to match the +// error message. +func IsJSONUnknownFieldError(err error) bool { + return rxJSONUnknownField.MatchString(err.Error()) +} + +// GetJSONUnknownField returns the unknown field name from a JSON unknown field error. +func GetJSONUnknownField(err error) *string { + errCause := Cause(err) + if IsJSONUnknownFieldError(errCause) { + substr := rxJSONUnknownField.FindStringSubmatch(errCause.Error()) + return &substr[1] + } + return nil +} + +// UserMessage implements the user-friendly translation of the error if its +// root cause is one of the supported types, otherwise it returns the error +// message. +func (e UserMessageError) UserMessage() string { + cause := Cause(e.error) + switch cause := cause.(type) { + case *json.UnmarshalTypeError: + var sb strings.Builder + curType := cause.Type + for curType.Kind() == reflect.Slice || curType.Kind() == reflect.Array { + sb.WriteString("array of ") + curType = curType.Elem() + } + sb.WriteString(curType.Name()) + if curType != cause.Type { + // it was an array + sb.WriteString("s") + } + + return fmt.Sprintf("invalid value type at '%s': expected %s but got %s", cause.Field, sb.String(), cause.Value) + + default: + // there's no specific error type for the strict json mode + // (DisallowUnknownFields), so resort to message-matching. + if matches := rxJSONUnknownField.FindStringSubmatch(cause.Error()); matches != nil { + return fmt.Sprintf("unsupported key provided: %q", matches[1]) + } + return e.Error() + } +} + +// Cause returns the root error in err's chain. +func Cause(err error) error { + for { + uerr := errors.Unwrap(err) + if uerr == nil { + return err + } + err = uerr + } +} + +// ErrWithRetryAfter is an interface for errors that should set a specific HTTP +// Header Retry-After value (see +// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After) +type ErrWithRetryAfter interface { + error + // RetryAfter returns the number of seconds to wait before retry. + RetryAfter() int +} + +// ForeignKeyError is an interface for errors caused by foreign key constraint violations. +type ForeignKeyError interface { + error + IsForeignKey() bool +} + +// IsForeignKey returns true if err is a foreign key constraint violation. +func IsForeignKey(err error) bool { + var fke ForeignKeyError + if errors.As(err, &fke) { + return fke.IsForeignKey() + } + return false +} + +// Error is a generic error type with a code and message. +type Error struct { + Code int `json:"code,omitempty"` + Message string `json:"message,omitempty"` + + ErrorWithUUID +} + +// Error returns the error message. +func (e *Error) Error() string { + return e.Message +} + +// ErrWithIsClientError is an interface for errors that explicitly specify +// whether they are client errors or not. By default, errors are treated as +// server errors. +type ErrWithIsClientError interface { + error + IsClientError() bool +} + +// AuthFailedError is returned when authentication fails. +type AuthFailedError struct { + // internal is the reason that should only be logged internally + internal string + + ErrorWithUUID +} + +// NewAuthFailedError creates a new AuthFailedError. +func NewAuthFailedError(internal string) *AuthFailedError { + return &AuthFailedError{internal: internal} +} + +// Error implements the error interface. +func (e AuthFailedError) Error() string { + return "Authentication failed" +} + +// Internal implements ErrWithInternal. +func (e AuthFailedError) Internal() string { + return e.internal +} + +// StatusCode implements kithttp.StatusCoder. +func (e AuthFailedError) StatusCode() int { + return http.StatusUnauthorized +} + +// AuthRequiredError is returned when authentication is required. +type AuthRequiredError struct { + // internal is the reason that should only be logged internally + internal string + + ErrorWithUUID +} + +// NewAuthRequiredError creates a new AuthRequiredError. +func NewAuthRequiredError(internal string) *AuthRequiredError { + return &AuthRequiredError{internal: internal} +} + +// Error implements the error interface. +func (e AuthRequiredError) Error() string { + return "Authentication required" +} + +// Internal implements ErrWithInternal. +func (e AuthRequiredError) Internal() string { + return e.internal +} + +// StatusCode implements kithttp.StatusCoder. +func (e AuthRequiredError) StatusCode() int { + return http.StatusUnauthorized +} + +// AuthHeaderRequiredError is returned when an authorization header is required. +type AuthHeaderRequiredError struct { + // internal is the reason that should only be logged internally + internal string + + ErrorWithUUID +} + +// NewAuthHeaderRequiredError creates a new AuthHeaderRequiredError. +func NewAuthHeaderRequiredError(internal string) *AuthHeaderRequiredError { + return &AuthHeaderRequiredError{ + internal: internal, + } +} + +// Error implements the error interface. +func (e AuthHeaderRequiredError) Error() string { + return "Authorization header required" +} + +// Internal implements ErrWithInternal. +func (e AuthHeaderRequiredError) Internal() string { + return e.internal +} + +// StatusCode implements kithttp.StatusCoder. +func (e AuthHeaderRequiredError) StatusCode() int { + return http.StatusUnauthorized +} + +// ErrPasswordResetRequired is returned when a password reset is required. +var ErrPasswordResetRequired = &passwordResetRequiredError{} + +type passwordResetRequiredError struct { + ErrorWithUUID +} + +// Error implements the error interface. +func (e passwordResetRequiredError) Error() string { + return "password reset required" +} + +// StatusCode implements kithttp.StatusCoder. +func (e passwordResetRequiredError) StatusCode() int { + return http.StatusUnauthorized +} + +// ForbiddenErrorMessage is the error message that should be returned to +// clients when an action is forbidden. It is intentionally vague to prevent +// disclosing information that a client should not have access to. +const ForbiddenErrorMessage = "forbidden" + +// CheckMissing is the error to return when no authorization check was performed +// by the service. +type CheckMissing struct { + response any + + ErrorWithUUID +} + +// CheckMissingWithResponse creates a new error indicating the authorization +// check was missed, and including the response for further analysis by the error +// encoder. +func CheckMissingWithResponse(response any) *CheckMissing { + return &CheckMissing{response: response} +} + +// Error implements the error interface. +func (e *CheckMissing) Error() string { + return ForbiddenErrorMessage +} + +// Internal implements the ErrWithInternal interface. +func (e *CheckMissing) Internal() string { + return "Missing authorization check" +} + +// Response returns the response that was generated before the authorization +// check was found to be missing. +func (e *CheckMissing) Response() any { + return e.response +} diff --git a/server/platform/http/response.go b/server/platform/http/response.go new file mode 100644 index 0000000000..8e08f96973 --- /dev/null +++ b/server/platform/http/response.go @@ -0,0 +1,7 @@ +// Package http provides HTTP types for bounded contexts. +package http + +// Errorer is implemented by response types that may contain errors. +type Errorer interface { + Error() error +} diff --git a/server/service/appconfig.go b/server/service/appconfig.go index 7c73f153d7..d04e6331aa 100644 --- a/server/service/appconfig.go +++ b/server/service/appconfig.go @@ -108,7 +108,7 @@ func getAppConfigEndpoint(ctx context.Context, request interface{}, svc fleet.Se if err != nil { return nil, err } - license, err := svc.License(ctx) + lic, err := svc.License(ctx) if err != nil { return nil, err } @@ -183,7 +183,7 @@ func getAppConfigEndpoint(ctx context.Context, request interface{}, svc fleet.Se transparencyURL := fleet.DefaultTransparencyURL // Fleet Premium license is required for custom transparency url - if license.IsPremium() && appConfig.FleetDesktop.TransparencyURL != "" { + if lic.IsPremium() && appConfig.FleetDesktop.TransparencyURL != "" { transparencyURL = appConfig.FleetDesktop.TransparencyURL } fleetDesktop := fleet.FleetDesktopSettings{TransparencyURL: transparencyURL} @@ -218,7 +218,7 @@ func getAppConfigEndpoint(ctx context.Context, request interface{}, svc fleet.Se appConfigResponseFields: appConfigResponseFields{ UpdateInterval: updateIntervalConfig, Vulnerabilities: vulnConfig, - License: license, + License: lic, Logging: loggingConfig, Email: emailConfig, SandboxEnabled: svc.SandboxEnabled(), @@ -274,7 +274,8 @@ func modifyAppConfigEndpoint(ctx context.Context, request interface{}, svc fleet } // We do not use svc.License(ctx) to allow roles (like GitOps) write but not read access to AppConfig. - license, _ := license.FromContext(ctx) + licChecker, _ := license.FromContext(ctx) + lic, _ := licChecker.(*fleet.LicenseInfo) loggingConfig, err := svc.LoggingConfig(ctx) if err != nil { @@ -283,14 +284,14 @@ func modifyAppConfigEndpoint(ctx context.Context, request interface{}, svc fleet response := appConfigResponse{ AppConfig: *appConfig, appConfigResponseFields: appConfigResponseFields{ - License: license, + License: lic, Logging: loggingConfig, }, } response.Obfuscate() - if (!license.IsPremium()) || response.FleetDesktop.TransparencyURL == "" { + if lic == nil || (!lic.IsPremium()) || response.FleetDesktop.TransparencyURL == "" { response.FleetDesktop.TransparencyURL = fleet.DefaultTransparencyURL } @@ -319,7 +320,8 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle oldAppConfig := appConfig.Copy() // We do not use svc.License(ctx) to allow roles (like GitOps) write but not read access to AppConfig. - license, _ := license.FromContext(ctx) + licChecker, _ := license.FromContext(ctx) + lic, _ := licChecker.(*fleet.LicenseInfo) var oldSMTPSettings fleet.SMTPSettings if appConfig.SMTPSettings != nil { @@ -359,7 +361,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle // default transparency URL is https://fleetdm.com/transparency so you are allowed to apply as long as it's not changing if newAppConfig.FleetDesktop.TransparencyURL != "" && newAppConfig.FleetDesktop.TransparencyURL != fleet.DefaultTransparencyURL { - if !license.IsPremium() { + if !lic.IsPremium() { invalid.Append("transparency_url", ErrMissingLicense.Error()) return nil, ctxerr.Wrap(ctx, invalid) } @@ -370,7 +372,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle } if newAppConfig.SSOSettings != nil { - validateSSOSettings(newAppConfig, appConfig, invalid, license) + validateSSOSettings(newAppConfig, appConfig, invalid, lic) if invalid.HasErrors() { return nil, ctxerr.Wrap(ctx, invalid) } @@ -437,7 +439,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle appConfig.MDM.MacOSSetup.EnableReleaseDeviceManually = oldAppConfig.MDM.MacOSSetup.EnableReleaseDeviceManually } if appConfig.MDM.MacOSSetup.ManualAgentInstall.Valid && appConfig.MDM.MacOSSetup.ManualAgentInstall.Value { - if !license.IsPremium() { + if !lic.IsPremium() { invalid.Append("macos_setup.manual_agent_install", ErrMissingLicense.Error()) return nil, ctxerr.Wrap(ctx, invalid) } @@ -477,7 +479,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle if newAppConfig.AgentOptions != nil { // if there were Agent Options in the new app config, then it replaced the // agent options in the resulting app config, so validate those. - if err := fleet.ValidateJSONAgentOptions(ctx, svc.ds, *appConfig.AgentOptions, license.IsPremium(), 0); err != nil { + if err := fleet.ValidateJSONAgentOptions(ctx, svc.ds, *appConfig.AgentOptions, lic.IsPremium(), 0); err != nil { err = fleet.SuggestAgentOptionsCorrection(err) err = fleet.NewUserMessageError(err, http.StatusBadRequest) if applyOpts.Force && !applyOpts.DryRun { @@ -490,7 +492,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle } // If the license is Premium, we should always send usage statisics. - if !license.IsAllowDisableTelemetry() { + if !lic.IsAllowDisableTelemetry() { appConfig.ServerSettings.EnableAnalytics = true } @@ -532,7 +534,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle isNonEmpty(newAppConfig.ConditionalAccess.OktaAudienceURI) || isNonEmpty(newAppConfig.ConditionalAccess.OktaCertificate) - if oktaFieldsBeingSet && !license.IsPremium() { + if oktaFieldsBeingSet && !lic.IsPremium() { invalid.Append("conditional_access", ErrMissingLicense.Error()) return nil, ctxerr.Wrap(ctx, invalid) } @@ -626,11 +628,11 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle appConfig.Integrations.ConditionalAccessEnabled = newAppConfig.Integrations.ConditionalAccessEnabled } - if err := svc.validateMDM(ctx, license, &oldAppConfig.MDM, &appConfig.MDM, invalid); err != nil { + if err := svc.validateMDM(ctx, lic, &oldAppConfig.MDM, &appConfig.MDM, invalid); err != nil { return nil, ctxerr.Wrap(ctx, err, "validating MDM config") } - abmAssignments, err := svc.validateABMAssignments(ctx, &newAppConfig.MDM, &oldAppConfig.MDM, invalid, license) + abmAssignments, err := svc.validateABMAssignments(ctx, &newAppConfig.MDM, &oldAppConfig.MDM, invalid, lic) if err != nil { return nil, ctxerr.Wrap(ctx, err, "validating ABM token assignments") } @@ -638,7 +640,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle var vppAssignments map[uint][]uint vppAssignmentsDefined := newAppConfig.MDM.VolumePurchasingProgram.Set && newAppConfig.MDM.VolumePurchasingProgram.Valid if vppAssignmentsDefined { - vppAssignments, err = svc.validateVPPAssignments(ctx, newAppConfig.MDM.VolumePurchasingProgram.Value, invalid, license) + vppAssignments, err = svc.validateVPPAssignments(ctx, newAppConfig.MDM.VolumePurchasingProgram.Value, invalid, lic) if err != nil { return nil, ctxerr.Wrap(ctx, err, "validating VPP token assignments") } @@ -738,7 +740,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle gitopsModeEnabled, gitopsRepoURL := appConfig.UIGitOpsMode.GitopsModeEnabled, appConfig.UIGitOpsMode.RepositoryURL if gitopsModeEnabled { - if !license.IsPremium() { + if !lic.IsPremium() { return nil, fleet.NewInvalidArgumentError("UI GitOpsMode: ", ErrMissingLicense.Error()) } if gitopsRepoURL == "" { @@ -767,7 +769,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle } - if !license.IsPremium() { + if !lic.IsPremium() { // reset transparency url to empty for downgraded licenses appConfig.FleetDesktop.TransparencyURL = "" } @@ -909,19 +911,19 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle // // Process OS updates config changes for Apple devices. // - if err := svc.processAppleOSUpdateSettings(ctx, license, fleet.MacOS, + if err := svc.processAppleOSUpdateSettings(ctx, lic, fleet.MacOS, oldAppConfig.MDM.MacOSUpdates, appConfig.MDM.MacOSUpdates, ); err != nil { return nil, ctxerr.Wrap(ctx, err, "process macOS OS updates config change") } - if err := svc.processAppleOSUpdateSettings(ctx, license, fleet.IOS, + if err := svc.processAppleOSUpdateSettings(ctx, lic, fleet.IOS, oldAppConfig.MDM.IOSUpdates, appConfig.MDM.IOSUpdates, ); err != nil { return nil, ctxerr.Wrap(ctx, err, "process iOS OS updates config change") } - if err := svc.processAppleOSUpdateSettings(ctx, license, fleet.IPadOS, + if err := svc.processAppleOSUpdateSettings(ctx, lic, fleet.IPadOS, oldAppConfig.MDM.IPadOSUpdates, appConfig.MDM.IPadOSUpdates, ); err != nil { @@ -1002,7 +1004,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle appConfig.MDM.EndUserAuthentication.SSOProviderSettings serverURLChanged := oldAppConfig.ServerSettings.ServerURL != appConfig.ServerSettings.ServerURL appleMDMUrlChanged := oldAppConfig.MDMUrl() != appConfig.MDMUrl() - if (mdmEnableEndUserAuthChanged || mdmSSOSettingsChanged || serverURLChanged || appleMDMUrlChanged) && license.IsPremium() { + if (mdmEnableEndUserAuthChanged || mdmSSOSettingsChanged || serverURLChanged || appleMDMUrlChanged) && lic.IsPremium() { if err := svc.EnterpriseOverrides.MDMAppleSyncDEPProfiles(ctx); err != nil { return nil, ctxerr.Wrap(ctx, err, "sync DEP profiles") } @@ -1100,14 +1102,14 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle // processAppleOSUpdateSettings updates the OS updates configuration if the minimum version+deadline are updated. func (svc *Service) processAppleOSUpdateSettings( ctx context.Context, - license *fleet.LicenseInfo, + lic *fleet.LicenseInfo, appleDevice fleet.AppleDevice, oldOSUpdateSettings fleet.AppleOSUpdateSettings, newOSUpdateSettings fleet.AppleOSUpdateSettings, ) error { if oldOSUpdateSettings.MinimumVersion.Value != newOSUpdateSettings.MinimumVersion.Value || oldOSUpdateSettings.Deadline.Value != newOSUpdateSettings.Deadline.Value { - if license.IsPremium() { + if lic.IsPremium() { if err := svc.EnterpriseOverrides.MDMAppleEditedAppleOSUpdates(ctx, nil, appleDevice, newOSUpdateSettings); err != nil { return ctxerr.Wrap(ctx, err, "update DDM profile after Apple OS updates change") } @@ -1174,33 +1176,33 @@ func (svc *Service) HasCustomSetupAssistantConfigurationWebURL(ctx context.Conte func (svc *Service) validateMDM( ctx context.Context, - license *fleet.LicenseInfo, + lic *fleet.LicenseInfo, oldMdm *fleet.MDM, mdm *fleet.MDM, invalid *fleet.InvalidArgumentError, ) error { - if mdm.EnableDiskEncryption.Value && !license.IsPremium() { + if mdm.EnableDiskEncryption.Value && !lic.IsPremium() { invalid.Append("macos_settings.enable_disk_encryption", ErrMissingLicense.Error()) } - if mdm.MacOSSetup.MacOSSetupAssistant.Value != "" && oldMdm.MacOSSetup.MacOSSetupAssistant.Value != mdm.MacOSSetup.MacOSSetupAssistant.Value && !license.IsPremium() { + if mdm.MacOSSetup.MacOSSetupAssistant.Value != "" && oldMdm.MacOSSetup.MacOSSetupAssistant.Value != mdm.MacOSSetup.MacOSSetupAssistant.Value && !lic.IsPremium() { invalid.Append("macos_setup.macos_setup_assistant", ErrMissingLicense.Error()) } - if mdm.MacOSSetup.EnableReleaseDeviceManually.Value && oldMdm.MacOSSetup.EnableReleaseDeviceManually.Value != mdm.MacOSSetup.EnableReleaseDeviceManually.Value && !license.IsPremium() { + if mdm.MacOSSetup.EnableReleaseDeviceManually.Value && oldMdm.MacOSSetup.EnableReleaseDeviceManually.Value != mdm.MacOSSetup.EnableReleaseDeviceManually.Value && !lic.IsPremium() { invalid.Append("macos_setup.enable_release_device_manually", ErrMissingLicense.Error()) } - if mdm.MacOSSetup.BootstrapPackage.Value != "" && oldMdm.MacOSSetup.BootstrapPackage.Value != mdm.MacOSSetup.BootstrapPackage.Value && !license.IsPremium() { + if mdm.MacOSSetup.BootstrapPackage.Value != "" && oldMdm.MacOSSetup.BootstrapPackage.Value != mdm.MacOSSetup.BootstrapPackage.Value && !lic.IsPremium() { invalid.Append("macos_setup.bootstrap_package", ErrMissingLicense.Error()) } - if mdm.MacOSSetup.EnableEndUserAuthentication && oldMdm.MacOSSetup.EnableEndUserAuthentication != mdm.MacOSSetup.EnableEndUserAuthentication && !license.IsPremium() { + if mdm.MacOSSetup.EnableEndUserAuthentication && oldMdm.MacOSSetup.EnableEndUserAuthentication != mdm.MacOSSetup.EnableEndUserAuthentication && !lic.IsPremium() { invalid.Append("macos_setup.enable_end_user_authentication", ErrMissingLicense.Error()) } - if mdm.MacOSSetup.ManualAgentInstall.Valid && oldMdm.MacOSSetup.ManualAgentInstall.Value != mdm.MacOSSetup.ManualAgentInstall.Value && !license.IsPremium() { + if mdm.MacOSSetup.ManualAgentInstall.Valid && oldMdm.MacOSSetup.ManualAgentInstall.Value != mdm.MacOSSetup.ManualAgentInstall.Value && !lic.IsPremium() { invalid.Append("macos_setup.manual_agent_install", ErrMissingLicense.Error()) } - if mdm.WindowsMigrationEnabled && !license.IsPremium() { + if mdm.WindowsMigrationEnabled && !lic.IsPremium() { invalid.Append("windows_migration_enabled", ErrMissingLicense.Error()) } - if mdm.EnableTurnOnWindowsMDMManually && !license.IsPremium() { + if mdm.EnableTurnOnWindowsMDMManually && !lic.IsPremium() { invalid.Append("enable_turn_on_windows_mdm_manually", ErrMissingLicense.Error()) } @@ -1298,7 +1300,7 @@ func (svc *Service) validateMDM( updatingIPadOSVersion || updatingIPadOSDeadline { // TODO: Should we validate MDM configured on here too? - if !license.IsPremium() { + if !lic.IsPremium() { invalid.Append("macos_updates.minimum_version", ErrMissingLicense.Error()) return nil } @@ -1318,7 +1320,7 @@ func (svc *Service) validateMDM( if updatingWindowsUpdates { // TODO: Should we validate MDM configured on here too? - if !license.IsPremium() { + if !lic.IsPremium() { invalid.Append("windows_updates.deadline_days", ErrMissingLicense.Error()) return nil } @@ -1330,7 +1332,7 @@ func (svc *Service) validateMDM( // EndUserAuthentication // only validate SSO settings if they changed if mdm.EndUserAuthentication.SSOProviderSettings != oldMdm.EndUserAuthentication.SSOProviderSettings { - if !license.IsPremium() { + if !lic.IsPremium() { invalid.Append("end_user_authentication", ErrMissingLicense.Error()) return nil } @@ -1366,7 +1368,7 @@ func (svc *Service) validateMDM( // TODO: Should we validate MDM configured on here too? if mdm.MacOSMigration.Enable { - if !license.IsPremium() { + if !lic.IsPremium() { invalid.Append("macos_migration.enable", ErrMissingLicense.Error()) return nil } @@ -1423,7 +1425,7 @@ func (svc *Service) validateABMAssignments( ctx context.Context, mdm, oldMdm *fleet.MDM, invalid *fleet.InvalidArgumentError, - license *fleet.LicenseInfo, + lic *fleet.LicenseInfo, ) ([]*fleet.ABMToken, error) { if mdm.DeprecatedAppleBMDefaultTeam != "" && mdm.AppleBusinessManager.Set && mdm.AppleBusinessManager.Valid { invalid.Append("mdm.apple_bm_default_team", fleet.AppleABMDefaultTeamDeprecatedMessage) @@ -1431,7 +1433,7 @@ func (svc *Service) validateABMAssignments( } if name := mdm.DeprecatedAppleBMDefaultTeam; name != "" && name != oldMdm.DeprecatedAppleBMDefaultTeam { - if !license.IsPremium() { + if !lic.IsPremium() { invalid.Append("mdm.apple_bm_default_team", ErrMissingLicense.Error()) return nil, nil } @@ -1464,7 +1466,7 @@ func (svc *Service) validateABMAssignments( } if mdm.AppleBusinessManager.Set && len(mdm.AppleBusinessManager.Value) > 0 { - if !license.IsPremium() { + if !lic.IsPremium() { invalid.Append("mdm.apple_business_manager", ErrMissingLicense.Error()) return nil, nil } @@ -1525,14 +1527,14 @@ func (svc *Service) validateVPPAssignments( ctx context.Context, volumePurchasingProgramInfo []fleet.MDMAppleVolumePurchasingProgramInfo, invalid *fleet.InvalidArgumentError, - license *fleet.LicenseInfo, + lic *fleet.LicenseInfo, ) (map[uint][]uint, error) { // Allow clearing VPP assignments in free and premium. if len(volumePurchasingProgramInfo) == 0 { return nil, nil } - if !license.IsPremium() { + if !lic.IsPremium() { invalid.Append("mdm.volume_purchasing_program", ErrMissingLicense.Error()) return nil, nil } @@ -1622,7 +1624,7 @@ func validateSSOProviderSettings(incoming, existing fleet.SSOProviderSettings, i } } -func validateSSOSettings(p fleet.AppConfig, existing *fleet.AppConfig, invalid *fleet.InvalidArgumentError, license *fleet.LicenseInfo) { +func validateSSOSettings(p fleet.AppConfig, existing *fleet.AppConfig, invalid *fleet.InvalidArgumentError, lic *fleet.LicenseInfo) { if p.SSOSettings != nil && p.SSOSettings.EnableSSO { var existingSSOProviderSettings fleet.SSOProviderSettings @@ -1631,7 +1633,7 @@ func validateSSOSettings(p fleet.AppConfig, existing *fleet.AppConfig, invalid * } validateSSOProviderSettings(p.SSOSettings.SSOProviderSettings, existingSSOProviderSettings, invalid) - if !license.IsPremium() { + if !lic.IsPremium() { if p.SSOSettings.EnableJITProvisioning { invalid.Append("enable_jit_provisioning", ErrMissingLicense.Error()) } diff --git a/server/service/apple_mdm.go b/server/service/apple_mdm.go index a4752919f0..f7164bdb8e 100644 --- a/server/service/apple_mdm.go +++ b/server/service/apple_mdm.go @@ -1793,7 +1793,7 @@ func (r mdmAppleEnrollResponse) HijackRender(ctx context.Context, w http.Respons w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusForbidden) if err := json.NewEncoder(w).Encode(r.SoftwareUpdateRequired); err != nil { - endpoint_utils.EncodeError(ctx, ctxerr.New(ctx, "failed to encode software update required"), w) + encodeError(ctx, ctxerr.New(ctx, "failed to encode software update required"), w) } return } diff --git a/server/service/endpoint_middleware.go b/server/service/endpoint_middleware.go index 40e2a892b4..92fd645a6c 100644 --- a/server/service/endpoint_middleware.go +++ b/server/service/endpoint_middleware.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/fleetdm/fleet/v4/server/contexts/certserial" + "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/contexts/logging" "github.com/fleetdm/fleet/v4/server/fleet" middleware_log "github.com/fleetdm/fleet/v4/server/service/middleware/log" @@ -104,6 +105,10 @@ func authenticatedDevice(svc fleet.Service, logger log.Logger, next endpoint.End } ctx = hostctx.NewContext(ctx, host) + // Register host as error context provider for ctxerr enrichment + hostProvider := &hostctx.HostAttributeProvider{Host: host} + ctx = ctxerr.AddErrorContextProvider(ctx, hostProvider) + instrumentHostLogger(ctx, host.ID) if ac, ok := authz_ctx.FromContext(ctx); ok { ac.SetAuthnMethod(authnMethod) @@ -151,6 +156,10 @@ func authenticatedHost(svc fleet.Service, logger log.Logger, next endpoint.Endpo } ctx = hostctx.NewContext(ctx, host) + // Register host as error context provider for ctxerr enrichment + hostProvider := &hostctx.HostAttributeProvider{Host: host} + ctx = ctxerr.AddErrorContextProvider(ctx, hostProvider) + instrumentHostLogger(ctx, host.ID) if ac, ok := authz_ctx.FromContext(ctx); ok { ac.SetAuthnMethod(authz_ctx.AuthnHostToken) @@ -193,6 +202,10 @@ func authenticatedOrbitHost( } ctx = hostctx.NewContext(ctx, host) + // Register host as error context provider for ctxerr enrichment + hostProvider := &hostctx.HostAttributeProvider{Host: host} + ctx = ctxerr.AddErrorContextProvider(ctx, hostProvider) + instrumentHostLogger(ctx, host.ID) if ac, ok := authz_ctx.FromContext(ctx); ok { ac.SetAuthnMethod(authz_ctx.AuthnOrbitToken) diff --git a/server/service/endpoint_middleware_test.go b/server/service/endpoint_middleware_test.go index 8414069c4f..a039f1fefa 100644 --- a/server/service/endpoint_middleware_test.go +++ b/server/service/endpoint_middleware_test.go @@ -11,7 +11,6 @@ import ( "github.com/fleetdm/fleet/v4/server/mock" "github.com/fleetdm/fleet/v4/server/ptr" "github.com/fleetdm/fleet/v4/server/service/middleware/auth" - "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" kitlog "github.com/go-kit/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -188,7 +187,7 @@ func TestAuthenticatedHost(t *testing.T) { r := &testNodeKeyRequest{NodeKey: tt.nodeKey} _, err := endpoint(ctx, r) if tt.shouldErr { - assert.IsType(t, &endpoint_utils.OsqueryError{}, err) + assert.IsType(t, &OsqueryError{}, err) } else { assert.Nil(t, err) } diff --git a/server/service/endpoint_utils.go b/server/service/endpoint_utils.go index 5ed209d251..31bc6e57b2 100644 --- a/server/service/endpoint_utils.go +++ b/server/service/endpoint_utils.go @@ -12,6 +12,7 @@ import ( "github.com/fleetdm/fleet/v4/server/contexts/capabilities" "github.com/fleetdm/fleet/v4/server/fleet" + platform_http "github.com/fleetdm/fleet/v4/server/platform/http" "github.com/fleetdm/fleet/v4/server/service/middleware/auth" eu "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" "github.com/go-kit/kit/endpoint" @@ -21,7 +22,31 @@ import ( ) func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc { - return eu.MakeDecoder(iface, jsonDecode, parseCustomTags, isBodyDecoder, decodeBody) + return eu.MakeDecoder(iface, jsonDecode, parseCustomTags, isBodyDecoder, decodeBody, fleetQueryDecoder) +} + +// fleetQueryDecoder handles fleet-specific query parameter decoding, such as +// converting the order_direction string to the fleet.OrderDirection int type. +func fleetQueryDecoder(queryTagName, queryVal string, field reflect.Value) (bool, error) { + // Only handle int fields for order_direction + if field.Kind() != reflect.Int { + return false, nil + } + switch queryTagName { + case "order_direction", "inherited_order_direction": + var direction int + switch queryVal { + case "desc": + direction = int(fleet.OrderDescending) + case "asc": + direction = int(fleet.OrderAscending) + default: + return false, &fleet.BadRequestError{Message: "unknown order_direction: " + queryVal} + } + field.SetInt(int64(direction)) + return true, nil + } + return false, nil } // A value that implements bodyDecoder takes control of decoding the request body. @@ -100,7 +125,7 @@ type endpointer struct { func (e *endpointer) CallHandlerFunc(f handlerFunc, ctx context.Context, request interface{}, svc interface{}, -) (fleet.Errorer, error) { +) (platform_http.Errorer, error) { return f(ctx, request, svc.(fleet.Service)) } diff --git a/server/service/endpoint_utils_test.go b/server/service/endpoint_utils_test.go index ec0009b73a..12a402af35 100644 --- a/server/service/endpoint_utils_test.go +++ b/server/service/endpoint_utils_test.go @@ -290,7 +290,7 @@ func TestEndpointer(t *testing.T) { auth.SetRequestsContexts(svc), ), kithttp.ServerErrorHandler(&endpoint_utils.ErrorHandler{Logger: kitlog.NewNopLogger()}), - kithttp.ServerErrorEncoder(endpoint_utils.EncodeError), + kithttp.ServerErrorEncoder(fleetErrorEncoder), kithttp.ServerAfter( kithttp.SetContentType("application/json; charset=utf-8"), log.LogRequestEnd(kitlog.NewNopLogger()), @@ -410,7 +410,7 @@ func TestEndpointerCustomMiddleware(t *testing.T) { auth.SetRequestsContexts(svc), ), kithttp.ServerErrorHandler(&endpoint_utils.ErrorHandler{Logger: kitlog.NewNopLogger()}), - kithttp.ServerErrorEncoder(endpoint_utils.EncodeError), + kithttp.ServerErrorEncoder(fleetErrorEncoder), kithttp.ServerAfter( kithttp.SetContentType("application/json; charset=utf-8"), log.LogRequestEnd(kitlog.NewNopLogger()), diff --git a/server/service/handler.go b/server/service/handler.go index 7b264f36b5..1c336b1c4a 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -110,7 +110,7 @@ func MakeHandler( auth.SetRequestsContexts(svc), ), kithttp.ServerErrorHandler(&endpoint_utils.ErrorHandler{Logger: logger}), - kithttp.ServerErrorEncoder(endpoint_utils.EncodeError), + kithttp.ServerErrorEncoder(fleetErrorEncoder), kithttp.ServerAfter( kithttp.SetContentType("application/json; charset=utf-8"), log.LogRequestEnd(logger), diff --git a/server/service/hosts.go b/server/service/hosts.go index 8a8ea0b220..ca79e16fb1 100644 --- a/server/service/hosts.go +++ b/server/service/hosts.go @@ -30,7 +30,6 @@ import ( "github.com/fleetdm/fleet/v4/server/mdm/assets" mdmlifecycle "github.com/fleetdm/fleet/v4/server/mdm/lifecycle" "github.com/fleetdm/fleet/v4/server/ptr" - "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" "github.com/fleetdm/fleet/v4/server/worker" "github.com/go-kit/log/level" "github.com/gocarina/gocsv" @@ -2319,7 +2318,7 @@ func (r hostsReportResponse) HijackRender(ctx context.Context, w http.ResponseWr var buf bytes.Buffer if err := gocsv.Marshal(r.Hosts, &buf); err != nil { logging.WithErr(ctx, err) - endpoint_utils.EncodeError(ctx, ctxerr.New(ctx, "failed to generate CSV file"), w) + encodeError(ctx, ctxerr.New(ctx, "failed to generate CSV file"), w) return } @@ -2331,7 +2330,7 @@ func (r hostsReportResponse) HijackRender(ctx context.Context, w http.ResponseWr recs, err := csv.NewReader(&buf).ReadAll() if err != nil { logging.WithErr(ctx, err) - endpoint_utils.EncodeError(ctx, ctxerr.New(ctx, "failed to generate CSV file"), w) + encodeError(ctx, ctxerr.New(ctx, "failed to generate CSV file"), w) return } @@ -2352,7 +2351,7 @@ func (r hostsReportResponse) HijackRender(ctx context.Context, w http.ResponseWr // duplicating the list of columns from the Host's struct tags to a // map and keep this in sync, for what is essentially a programmer // mistake that should be caught and corrected early. - endpoint_utils.EncodeError(ctx, &fleet.BadRequestError{Message: fmt.Sprintf("invalid column name: %q", col)}, w) + encodeError(ctx, &fleet.BadRequestError{Message: fmt.Sprintf("invalid column name: %q", col)}, w) return } outRows[i] = append(outRows[i], rec[colIx]) diff --git a/server/service/middleware/auth/auth.go b/server/service/middleware/auth/auth.go index c96357c558..6393476e0f 100644 --- a/server/service/middleware/auth/auth.go +++ b/server/service/middleware/auth/auth.go @@ -9,6 +9,8 @@ import ( "github.com/fleetdm/fleet/v4/ee/server/service/hostidentity/httpsig" "github.com/fleetdm/fleet/v4/server/contexts/authz" + "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" + "github.com/fleetdm/fleet/v4/server/contexts/logging" "github.com/fleetdm/fleet/v4/server/contexts/token" "github.com/fleetdm/fleet/v4/server/contexts/viewer" "github.com/fleetdm/fleet/v4/server/fleet" @@ -76,6 +78,10 @@ func AuthenticatedUser(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpo } ctx = viewer.NewContext(ctx, *v) + // Register viewer as error context provider for ctxerr enrichment + ctx = ctxerr.AddErrorContextProvider(ctx, v) + // Register viewer as user emailer for logging + ctx = logging.WithUserEmailer(ctx, v) if ac, ok := authz.FromContext(ctx); ok { ac.SetAuthnMethod(authz.AuthnUserToken) } @@ -123,6 +129,10 @@ func AuthenticatedUserMiddleware(svc fleet.Service, errHandler errorHandler, nex } ctx := viewer.NewContext(r.Context(), *v) + // Register viewer as error context provider for ctxerr enrichment + ctx = ctxerr.AddErrorContextProvider(ctx, v) + // Register viewer as user emailer for logging + ctx = logging.WithUserEmailer(ctx, v) if ac, ok := authz.FromContext(r.Context()); ok { ac.SetAuthnMethod(authz.AuthnUserToken) } diff --git a/server/service/middleware/auth/http_auth.go b/server/service/middleware/auth/http_auth.go index 1034bc8ec1..927d6cc256 100644 --- a/server/service/middleware/auth/http_auth.go +++ b/server/service/middleware/auth/http_auth.go @@ -4,6 +4,7 @@ import ( "context" "net/http" + "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/contexts/logging" "github.com/fleetdm/fleet/v4/server/contexts/token" "github.com/fleetdm/fleet/v4/server/contexts/viewer" @@ -20,6 +21,10 @@ func SetRequestsContexts(svc fleet.Service) kithttp.RequestFunc { v, err := AuthViewer(ctx, string(bearer), svc) if err == nil { ctx = viewer.NewContext(ctx, *v) + // Register viewer as error context provider for ctxerr enrichment + ctx = ctxerr.AddErrorContextProvider(ctx, v) + // Register viewer as user emailer for logging + ctx = logging.WithUserEmailer(ctx, v) } } diff --git a/server/service/middleware/authzcheck/authzcheck.go b/server/service/middleware/authzcheck/authzcheck.go index d2baa022b4..27fe2dc112 100644 --- a/server/service/middleware/authzcheck/authzcheck.go +++ b/server/service/middleware/authzcheck/authzcheck.go @@ -8,9 +8,8 @@ import ( "context" "errors" - "github.com/fleetdm/fleet/v4/server/authz" authz_ctx "github.com/fleetdm/fleet/v4/server/contexts/authz" - "github.com/fleetdm/fleet/v4/server/fleet" + platform_http "github.com/fleetdm/fleet/v4/server/platform/http" "github.com/go-kit/kit/endpoint" ) @@ -32,13 +31,13 @@ func (m *Middleware) AuthzCheck() endpoint.Middleware { // If authentication check failed, return that error (so that we log // appropriately). - var authFailedError *fleet.AuthFailedError - var authRequiredError *fleet.AuthRequiredError - var authHeaderRequiredError *fleet.AuthHeaderRequiredError + var authFailedError *platform_http.AuthFailedError + var authRequiredError *platform_http.AuthRequiredError + var authHeaderRequiredError *platform_http.AuthHeaderRequiredError if errors.As(err, &authFailedError) || errors.As(err, &authRequiredError) || errors.As(err, &authHeaderRequiredError) || - errors.Is(err, fleet.ErrPasswordResetRequired) { + errors.Is(err, platform_http.ErrPasswordResetRequired) { return nil, err } @@ -52,7 +51,7 @@ func (m *Middleware) AuthzCheck() endpoint.Middleware { // marshal to a generic error and log that the check was missed. if !authzctx.Checked() { // Getting to here means there is an authorization-related bug in our code. - return nil, authz.CheckMissingWithResponse(response) + return nil, platform_http.CheckMissingWithResponse(response) } return response, err diff --git a/server/service/middleware/authzcheck/authzcheck_test.go b/server/service/middleware/authzcheck/authzcheck_test.go index d95aafdc86..89fd2b86f0 100644 --- a/server/service/middleware/authzcheck/authzcheck_test.go +++ b/server/service/middleware/authzcheck/authzcheck_test.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/fleetdm/fleet/v4/server/contexts/authz" - "github.com/fleetdm/fleet/v4/server/fleet" + platform_http "github.com/fleetdm/fleet/v4/server/platform/http" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -33,7 +33,7 @@ func TestAuthzCheckAuthFailed(t *testing.T) { checker := NewMiddleware() check := func(ctx context.Context, req interface{}) (interface{}, error) { - return nil, fleet.NewAuthFailedError("failed") + return nil, platform_http.NewAuthFailedError("failed") } check = checker.AuthzCheck()(check) @@ -48,7 +48,7 @@ func TestAuthzCheckAuthRequired(t *testing.T) { checker := NewMiddleware() check := func(ctx context.Context, req interface{}) (interface{}, error) { - return nil, fleet.NewAuthRequiredError("required") + return nil, platform_http.NewAuthRequiredError("required") } check = checker.AuthzCheck()(check) diff --git a/server/service/middleware/endpoint_utils/endpoint_utils.go b/server/service/middleware/endpoint_utils/endpoint_utils.go index 129ba14e44..2646c5b4ec 100644 --- a/server/service/middleware/endpoint_utils/endpoint_utils.go +++ b/server/service/middleware/endpoint_utils/endpoint_utils.go @@ -18,7 +18,7 @@ import ( "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/contexts/license" "github.com/fleetdm/fleet/v4/server/contexts/logging" - "github.com/fleetdm/fleet/v4/server/fleet" + platform_http "github.com/fleetdm/fleet/v4/server/platform/http" "github.com/fleetdm/fleet/v4/server/service/middleware/authzcheck" "github.com/fleetdm/fleet/v4/server/service/middleware/ratelimit" "github.com/go-kit/kit/endpoint" @@ -86,7 +86,7 @@ func BadRequestErr(publicMsg string, internalErr error) error { if errors.As(internalErr, &opErr) { return fmt.Errorf(publicMsg+", internal: %w", internalErr) } - return &fleet.BadRequestError{ + return &platform_http.BadRequestError{ Message: publicMsg, InternalErr: internalErr, } @@ -169,7 +169,11 @@ func DecodeURLTagValue(r *http.Request, field reflect.Value, urlTagValue string, return nil } -func DecodeQueryTagValue(r *http.Request, fp fieldPair) error { +// DomainQueryFieldDecoder decodes a query parameter value into the target field. +// It returns true if it handled the field, false if default handling should be used. +type DomainQueryFieldDecoder func(queryTagName, queryVal string, field reflect.Value) (handled bool, err error) + +func DecodeQueryTagValue(r *http.Request, fp fieldPair, customDecoder DomainQueryFieldDecoder) error { queryTagValue, ok := fp.Sf.Tag.Lookup("query") if ok { @@ -185,7 +189,7 @@ func DecodeQueryTagValue(r *http.Request, fp fieldPair) error { if optional { return nil } - return &fleet.BadRequestError{Message: fmt.Sprintf("Param %s is required", queryTagValue)} + return &platform_http.BadRequestError{Message: fmt.Sprintf("Param %s is required", queryTagValue)} } field := fp.V if field.Kind() == reflect.Ptr { @@ -193,6 +197,18 @@ func DecodeQueryTagValue(r *http.Request, fp fieldPair) error { field.Set(reflect.New(field.Type().Elem())) field = field.Elem() } + + // Try custom decoder first if provided + if customDecoder != nil { + handled, err := customDecoder(queryTagValue, queryVal, field) + if err != nil { + return err + } + if handled { + return nil + } + } + switch field.Kind() { case reflect.String: field.SetString(queryVal) @@ -211,22 +227,9 @@ func DecodeQueryTagValue(r *http.Request, fp fieldPair) error { case reflect.Bool: field.SetBool(queryVal == "1" || queryVal == "true") case reflect.Int: - queryValInt := 0 - switch queryTagValue { - case "order_direction", "inherited_order_direction": - switch queryVal { - case "desc": - queryValInt = int(fleet.OrderDescending) - case "asc": - queryValInt = int(fleet.OrderAscending) - default: - return &fleet.BadRequestError{Message: "unknown order_direction: " + queryVal} - } - default: - queryValInt, err = strconv.Atoi(queryVal) - if err != nil { - return BadRequestErr("parsing int from query", err) - } + queryValInt, err := strconv.Atoi(queryVal) + if err != nil { + return BadRequestErr("parsing int from query", err) } field.SetInt(int64(queryValInt)) default: @@ -297,17 +300,17 @@ func (h *ErrorHandler) Handle(ctx context.Context, err error) { logger = log.With(logger, "took", time.Since(startTime)) } - var ewi fleet.ErrWithInternal + var ewi platform_http.ErrWithInternal if errors.As(err, &ewi) { logger = log.With(logger, "internal", ewi.Internal()) } - var ewlf fleet.ErrWithLogFields + var ewlf platform_http.ErrWithLogFields if errors.As(err, &ewlf) { logger = log.With(logger, ewlf.LogFields()...) } - var uuider fleet.ErrorUUIDer + var uuider platform_http.ErrorUUIDer if errors.As(err, &uuider) { logger = log.With(logger, "uuid", uuider.UUID()) } @@ -338,17 +341,14 @@ type requestValidator interface { } // MakeDecoder creates a decoder for the type for the struct passed on. If the -// struct has at least 1 json tag it'll unmarshall the body. If the struct has -// a `url` tag with value list_options it'll gather fleet.ListOptions from the -// URL (similarly for host_options, carve_options, user_options that derive -// from the common list_options). Note that these behaviors do not work for embedded structs. +// struct has at least 1 json tag it'll unmarshall the body. Custom `url` tag +// values can be handled by providing a parseCustomTags function. Note that +// these behaviors do not work for embedded structs. // -// Finally, any other `url` tag will be treated as a path variable (of the form +// Any other `url` tag will be treated as a path variable (of the form // /path/{name} in the route's path) from the URL path pattern, and it'll be // decoded and set accordingly. Variables can be optional by setting the tag as // follows: `url:"some-id,optional"`. -// The "list_options" are optional by default and it'll ignore the optional -// portion of the tag. // // If iface implements the RequestDecoder interface, it returns a function that // calls iface.DecodeRequest(ctx, r) - i.e. the value itself fully controls its @@ -357,12 +357,16 @@ type requestValidator interface { // If iface implements the bodyDecoder interface, it calls iface.DecodeBody // after having decoded any non-body fields (such as url and query parameters) // into the struct. +// +// The customQueryDecoder parameter allows services to inject domain-specific +// query parameter decoding logic. func MakeDecoder( iface interface{}, jsonUnmarshal func(body io.Reader, req any) error, parseCustomTags func(urlTagValue string, r *http.Request, field reflect.Value) (bool, error), isBodyDecoder func(reflect.Value) bool, decodeBody func(ctx context.Context, r *http.Request, v reflect.Value, body io.Reader) error, + customQueryDecoder DomainQueryFieldDecoder, ) kithttp.DecodeRequestFunc { if iface == nil { return func(ctx context.Context, r *http.Request) (interface{}, error) { @@ -441,16 +445,16 @@ func MakeDecoder( _, jsonExpected := fp.Sf.Tag.Lookup("json") if jsonExpected && nilBody { - return nil, &fleet.BadRequestError{Message: "Expected JSON Body"} + return nil, &platform_http.BadRequestError{Message: "Expected JSON Body"} } isContentJson := r.Header.Get("Content-Type") == "application/json" isCrossSite := r.Header.Get("Origin") != "" || r.Header.Get("Referer") != "" if jsonExpected && isCrossSite && !isContentJson { - return nil, fleet.NewUserMessageError(errors.New("Expected Content-Type \"application/json\""), http.StatusUnsupportedMediaType) + return nil, platform_http.NewUserMessageError(errors.New("Expected Content-Type \"application/json\""), http.StatusUnsupportedMediaType) } - err = DecodeQueryTagValue(r, fp) + err = DecodeQueryTagValue(r, fp, customQueryDecoder) if err != nil { return nil, err } @@ -471,7 +475,7 @@ func MakeDecoder( return nil, err } if val && !fp.V.IsZero() { - return nil, &fleet.BadRequestError{Message: fmt.Sprintf( + return nil, &platform_http.BadRequestError{Message: fmt.Sprintf( "option %s requires a premium license", fp.Sf.Name, )} @@ -509,17 +513,19 @@ func WriteBrowserSecurityHeaders(w http.ResponseWriter) { } type CommonEndpointer[H any] struct { - EP Endpointer[H] - MakeDecoderFn func(iface interface{}) kithttp.DecodeRequestFunc - EncodeFn kithttp.EncodeResponseFunc - Opts []kithttp.ServerOption - AuthMiddleware endpoint.Middleware - Router *mux.Router - Versions []string + EP Endpointer[H] + MakeDecoderFn func(iface any) kithttp.DecodeRequestFunc + EncodeFn kithttp.EncodeResponseFunc + Opts []kithttp.ServerOption + Router *mux.Router + Versions []string - // CustomMiddleware are middlewares that run before AuthMiddleware. + // AuthMiddleware is a pre-built authentication middleware. + AuthMiddleware endpoint.Middleware + + // CustomMiddleware are middlewares that run before authentication. CustomMiddleware []endpoint.Middleware - // CustomMiddlewareAfterAuth are middlewares that run after AuthMiddleware. + // CustomMiddlewareAfterAuth are middlewares that run after authentication. CustomMiddlewareAfterAuth []endpoint.Middleware startingAtVersion string @@ -529,8 +535,8 @@ type CommonEndpointer[H any] struct { } type Endpointer[H any] interface { - CallHandlerFunc(f H, ctx context.Context, request interface{}, svc interface{}) (fleet.Errorer, error) - Service() interface{} + CallHandlerFunc(f H, ctx context.Context, request any, svc any) (platform_http.Errorer, error) + Service() any } func (e *CommonEndpointer[H]) POST(path string, f H, v interface{}) { @@ -730,6 +736,7 @@ func EncodeCommonResponse( w http.ResponseWriter, response interface{}, jsonMarshal func(w http.ResponseWriter, response interface{}) error, + domainErrorEncoder DomainErrorEncoder, ) error { if cs, ok := response.(cookieSetter); ok { cs.SetCookies(ctx, w) @@ -747,8 +754,8 @@ func EncodeCommonResponse( return err } - if e, ok := response.(fleet.Errorer); ok && e.Error() != nil { - EncodeError(ctx, e.Error(), w) + if e, ok := response.(platform_http.Errorer); ok && e.Error() != nil { + EncodeError(ctx, e.Error(), w, domainErrorEncoder) return nil } diff --git a/server/service/middleware/endpoint_utils/endpoint_utils_test.go b/server/service/middleware/endpoint_utils/endpoint_utils_test.go index 4ba5498f80..145b5b8336 100644 --- a/server/service/middleware/endpoint_utils/endpoint_utils_test.go +++ b/server/service/middleware/endpoint_utils/endpoint_utils_test.go @@ -8,7 +8,7 @@ import ( "testing" authz_ctx "github.com/fleetdm/fleet/v4/server/contexts/authz" - "github.com/fleetdm/fleet/v4/server/fleet" + platform_http "github.com/fleetdm/fleet/v4/server/platform/http" "github.com/go-kit/kit/endpoint" kithttp "github.com/go-kit/kit/transport/http" "github.com/gorilla/mux" @@ -16,7 +16,7 @@ import ( ) // testHandlerFunc is a handler function type used for testing. -type testHandlerFunc func(ctx context.Context, request any, svc any) (fleet.Errorer, error) +type testHandlerFunc func(ctx context.Context, request any) (platform_http.Errorer, error) func TestCustomMiddlewareAfterAuth(t *testing.T) { var ( @@ -82,7 +82,7 @@ func TestCustomMiddlewareAfterAuth(t *testing.T) { }, Router: r, } - ce.handleEndpoint("/", func(ctx context.Context, request interface{}, svc any) (fleet.Errorer, error) { + ce.handleEndpoint("/", func(ctx context.Context, request any) (platform_http.Errorer, error) { fmt.Printf("handler\n") return nopResponse{}, nil }, nil, "GET") @@ -116,10 +116,10 @@ func (n nopResponse) Error() error { type nopEP struct{} -func (n nopEP) CallHandlerFunc(_ testHandlerFunc, _ context.Context, _ any, _ any) (fleet.Errorer, error) { - return nopResponse{}, nil +func (n nopEP) CallHandlerFunc(f testHandlerFunc, ctx context.Context, request any, svc any) (platform_http.Errorer, error) { + return f(ctx, request) } -func (n nopEP) Service() interface{} { +func (n nopEP) Service() any { return nil } diff --git a/server/service/middleware/endpoint_utils/transport_error.go b/server/service/middleware/endpoint_utils/transport_error.go index 7fdb508805..1c1d65094b 100644 --- a/server/service/middleware/endpoint_utils/transport_error.go +++ b/server/service/middleware/endpoint_utils/transport_error.go @@ -4,13 +4,12 @@ import ( "context" "encoding/json" "errors" - "fmt" "net" "net/http" "strconv" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" - "github.com/fleetdm/fleet/v4/server/fleet" + platform_http "github.com/fleetdm/fleet/v4/server/platform/http" kithttp "github.com/go-kit/kit/transport/http" "github.com/go-sql-driver/mysql" ) @@ -18,6 +17,11 @@ import ( // ErrBadRoute is used for mux errors var ErrBadRoute = errors.New("bad route") +// DomainErrorEncoder handles domain-specific error encoding. +// It returns true if it handled the error, false if default handling should be used. +// The encoder should write the appropriate status code and response body. +type DomainErrorEncoder func(ctx context.Context, err error, w http.ResponseWriter, enc *json.Encoder, jsonErr *JsonError) (handled bool) + type JsonError struct { Message string `json:"message"` Code int `json:"code,omitempty"` @@ -67,8 +71,10 @@ type conflictErrorInterface interface { IsConflict() bool } -// EncodeError encodes error and status header to the client -func EncodeError(ctx context.Context, err error, w http.ResponseWriter) { +// EncodeError encodes error and status header to the client. +// The domainEncoder parameter allows services to inject domain-specific error +// handling. If nil, only generic error handling is performed. +func EncodeError(ctx context.Context, err error, w http.ResponseWriter, domainEncoder DomainErrorEncoder) { ctxerr.Handle(ctx, err) origErr := err @@ -78,7 +84,7 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) { err = ctxerr.Cause(err) var uuid string - if uuidErr, ok := err.(fleet.ErrorUUIDer); ok { + if uuidErr, ok := err.(platform_http.ErrorUUIDer); ok { uuid = uuidErr.UUID() } @@ -86,6 +92,13 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) { UUID: uuid, } + // Try domain-specific error encoder first + if domainEncoder != nil { + if handled := domainEncoder(ctx, err, w, enc, &jsonErr); handled { + return + } + } + switch e := err.(type) { case validationErrorInterface: if statusErr, ok := e.(interface{ Status() int }); ok { @@ -99,36 +112,6 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) { jsonErr.Message = "Permission Denied" jsonErr.Errors = e.PermissionError() w.WriteHeader(http.StatusForbidden) - case MailError: - jsonErr.Message = "Mail Error" - jsonErr.Errors = e.MailError() - w.WriteHeader(http.StatusInternalServerError) - case *OsqueryError: - // osquery expects to receive the node_invalid key when a TLS - // request provides an invalid node_key for authentication. It - // doesn't use the error message provided, but we provide this - // for debugging purposes (and perhaps osquery will use this - // error message in the future). - - errMap := map[string]interface{}{ - "error": e.Error(), - "uuid": uuid, - } - if e.NodeInvalid() { //nolint:gocritic // ignore ifElseChain - w.WriteHeader(http.StatusUnauthorized) - errMap["node_invalid"] = true - } else if e.Status() != 0 { - w.WriteHeader(e.Status()) - } else { - // TODO: osqueryError is not always the result of an internal error on - // our side, it is also used to represent a client error (invalid data, - // e.g. malformed json, carve too large, etc., so 4xx), are we returning - // a 500 because of some osquery-specific requirement? - w.WriteHeader(http.StatusInternalServerError) - } - - enc.Encode(errMap) //nolint:errcheck - return case NotFoundErrorInterface: jsonErr.Message = "Resource Not Found" jsonErr.Errors = baseError(e.Error()) @@ -153,7 +136,7 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) { statusCode = http.StatusConflict } w.WriteHeader(statusCode) - case *fleet.Error: + case *platform_http.Error: jsonErr.Message = e.Error() jsonErr.Code = e.Code w.WriteHeader(http.StatusUnprocessableEntity) @@ -168,7 +151,7 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) { enc.Encode(jsonErr) //nolint:errcheck return } - if fleet.IsForeignKey(err) { + if platform_http.IsForeignKey(err) { jsonErr.Message = "Validation Failed" jsonErr.Errors = baseError(err.Error()) w.WriteHeader(http.StatusUnprocessableEntity) @@ -186,14 +169,14 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) { // See header documentation // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After) - var ewra fleet.ErrWithRetryAfter + var ewra platform_http.ErrWithRetryAfter if errors.As(err, &ewra) { w.Header().Add("Retry-After", strconv.Itoa(ewra.RetryAfter())) } msg := err.Error() reason := err.Error() - var ume *fleet.UserMessageError + var ume *platform_http.UserMessageError if errors.As(err, &ume) { if text := http.StatusText(status); text != "" { msg = text @@ -208,53 +191,3 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) { enc.Encode(jsonErr) //nolint:errcheck } - -// MailError is set when an error performing mail operations -type MailError struct { - Message string -} - -func (e MailError) Error() string { - return fmt.Sprintf("a mail error occurred: %s", e.Message) -} - -func (e MailError) MailError() []map[string]string { - return []map[string]string{ - { - "name": "base", - "reason": e.Message, - }, - } -} - -// OsqueryError is the error returned to osquery agents. -type OsqueryError struct { - message string - nodeInvalid bool - StatusCode int - fleet.ErrorWithUUID -} - -var _ fleet.ErrorUUIDer = (*OsqueryError)(nil) - -// Error implements the error interface. -func (e *OsqueryError) Error() string { - return e.message -} - -// NodeInvalid returns whether the error returned to osquery -// should contain the node_invalid property. -func (e *OsqueryError) NodeInvalid() bool { - return e.nodeInvalid -} - -func (e *OsqueryError) Status() int { - return e.StatusCode -} - -func NewOsqueryError(message string, nodeInvalid bool) *OsqueryError { - return &OsqueryError{ - message: message, - nodeInvalid: nodeInvalid, - } -} diff --git a/server/service/middleware/endpoint_utils/transport_error_test.go b/server/service/middleware/endpoint_utils/transport_error_test.go index 002d52b62a..10a9c796e0 100644 --- a/server/service/middleware/endpoint_utils/transport_error_test.go +++ b/server/service/middleware/endpoint_utils/transport_error_test.go @@ -6,7 +6,7 @@ import ( "net/http/httptest" "testing" - "github.com/fleetdm/fleet/v4/server/fleet" + platform_http "github.com/fleetdm/fleet/v4/server/platform/http" "github.com/stretchr/testify/assert" ) @@ -25,7 +25,7 @@ type newAndExciting struct{} func (newAndExciting) Error() string { return "" } type notFoundError struct { - fleet.ErrorWithUUID + platform_http.ErrorWithUUID } func (e *notFoundError) Error() string { @@ -36,6 +36,32 @@ func (e *notFoundError) IsNotFound() bool { return true } +// validationError is a test implementation of validationErrorInterface. +type validationError struct { + errors []map[string]string +} + +func (e validationError) Error() string { + return "validation failed" +} + +func (e validationError) Invalid() []map[string]string { + return e.errors +} + +// permissionError is a test implementation of permissionErrorInterface. +type permissionError struct { + message string +} + +func (e permissionError) Error() string { + return e.message +} + +func (e permissionError) PermissionError() []map[string]string { + return nil +} + func TestHandlesErrorsCode(t *testing.T) { errorTests := []struct { name string @@ -44,12 +70,12 @@ func TestHandlesErrorsCode(t *testing.T) { }{ { "validation", - fleet.NewInvalidArgumentError("a", "b"), + validationError{errors: []map[string]string{{"name": "a", "reason": "b"}}}, http.StatusUnprocessableEntity, }, { "permission", - fleet.NewPermissionError("a"), + permissionError{message: "a"}, http.StatusForbidden, }, { @@ -57,21 +83,6 @@ func TestHandlesErrorsCode(t *testing.T) { foreignKeyError{}, http.StatusUnprocessableEntity, }, - { - "mail error", - MailError{}, - http.StatusInternalServerError, - }, - { - "osquery error - invalid node", - &OsqueryError{nodeInvalid: true}, - http.StatusUnauthorized, - }, - { - "osquery error - valid node", - &OsqueryError{}, - http.StatusInternalServerError, - }, { "data not found", ¬FoundError{}, @@ -84,7 +95,7 @@ func TestHandlesErrorsCode(t *testing.T) { }, { "status coder", - fleet.NewAuthFailedError(""), + platform_http.NewAuthFailedError(""), http.StatusUnauthorized, }, { @@ -97,7 +108,7 @@ func TestHandlesErrorsCode(t *testing.T) { for _, tt := range errorTests { t.Run(tt.name, func(t *testing.T) { recorder := httptest.NewRecorder() - EncodeError(context.Background(), tt.err, recorder) + EncodeError(context.Background(), tt.err, recorder, nil) assert.Equal(t, recorder.Code, tt.code) }) } diff --git a/server/service/orbit.go b/server/service/orbit.go index 0304a1b6ea..14223a9260 100644 --- a/server/service/orbit.go +++ b/server/service/orbit.go @@ -21,7 +21,6 @@ import ( microsoft_mdm "github.com/fleetdm/fleet/v4/server/mdm/microsoft" "github.com/fleetdm/fleet/v4/server/ptr" "github.com/fleetdm/fleet/v4/server/service/contract" - "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" "github.com/fleetdm/fleet/v4/server/worker" "github.com/go-kit/log/level" ) @@ -65,7 +64,7 @@ func (r EnrollOrbitResponse) HijackRender(ctx context.Context, w http.ResponseWr enc.SetIndent("", " ") if err := enc.Encode(r); err != nil { - endpoint_utils.EncodeError(ctx, newOsqueryError(fmt.Sprintf("orbit enroll failed: %s", err)), w) + encodeError(ctx, newOsqueryError(fmt.Sprintf("orbit enroll failed: %s", err)), w) } } diff --git a/server/service/osquery.go b/server/service/osquery.go index 3ca9075f01..9cb4505ed5 100644 --- a/server/service/osquery.go +++ b/server/service/osquery.go @@ -25,7 +25,6 @@ import ( "github.com/fleetdm/fleet/v4/server/pubsub" "github.com/fleetdm/fleet/v4/server/service/conditional_access_microsoft_proxy" "github.com/fleetdm/fleet/v4/server/service/contract" - "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" "github.com/fleetdm/fleet/v4/server/service/osquery_utils" kithttp "github.com/go-kit/kit/transport/http" "github.com/go-kit/log" @@ -34,12 +33,12 @@ import ( "golang.org/x/exp/slices" ) -func newOsqueryErrorWithInvalidNode(msg string) *endpoint_utils.OsqueryError { - return endpoint_utils.NewOsqueryError(msg, true) +func newOsqueryErrorWithInvalidNode(msg string) *OsqueryError { + return NewOsqueryError(msg, true) } -func newOsqueryError(msg string) *endpoint_utils.OsqueryError { - return endpoint_utils.NewOsqueryError(msg, false) +func newOsqueryError(msg string) *OsqueryError { + return NewOsqueryError(msg, false) } func (svc *Service) AuthenticateHost(ctx context.Context, nodeKey string) (*fleet.Host, bool, error) { @@ -138,7 +137,7 @@ func (svc *Service) EnrollOsquery(ctx context.Context, enrollSecret, hostIdentif if !canEnroll { deviceCount := "unknown" if lic, _ := license.FromContext(ctx); lic != nil { - deviceCount = strconv.Itoa(lic.DeviceCount) + deviceCount = strconv.Itoa(lic.GetDeviceCount()) } return "", newOsqueryErrorWithInvalidNode(fmt.Sprintf("enroll host failed: maximum number of hosts reached: %s", deviceCount)) } diff --git a/server/service/osquery_test.go b/server/service/osquery_test.go index 98ed477476..b048cb89a9 100644 --- a/server/service/osquery_test.go +++ b/server/service/osquery_test.go @@ -33,7 +33,6 @@ import ( "github.com/fleetdm/fleet/v4/server/ptr" "github.com/fleetdm/fleet/v4/server/pubsub" "github.com/fleetdm/fleet/v4/server/service/async" - "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" "github.com/fleetdm/fleet/v4/server/service/osquery_utils" "github.com/fleetdm/fleet/v4/server/service/redis_policy_set" "github.com/go-kit/log" @@ -992,7 +991,7 @@ func TestSubmitResultLogsFail(t *testing.T) { // Expect an error when unable to write to logging destination. err = svc.SubmitResultLogs(ctx, results) require.Error(t, err) - assert.Equal(t, http.StatusRequestEntityTooLarge, err.(*endpoint_utils.OsqueryError).Status()) + assert.Equal(t, http.StatusRequestEntityTooLarge, err.(*OsqueryError).Status()) } func TestGetQueryNameAndTeamIDFromResult(t *testing.T) { @@ -2908,7 +2907,7 @@ func TestAuthenticationErrors(t *testing.T) { _, _, err := svc.AuthenticateHost(ctx, "") require.Error(t, err) - require.True(t, err.(*endpoint_utils.OsqueryError).NodeInvalid()) + require.True(t, err.(*OsqueryError).NodeInvalid()) ms.LoadHostByNodeKeyFunc = func(ctx context.Context, nodeKey string) (*fleet.Host, error) { return &fleet.Host{ID: 1, HasHostIdentityCert: ptr.Bool(false)}, nil @@ -2929,7 +2928,7 @@ func TestAuthenticationErrors(t *testing.T) { _, _, err = svc.AuthenticateHost(ctx, "foo") require.Error(t, err) - require.True(t, err.(*endpoint_utils.OsqueryError).NodeInvalid()) + require.True(t, err.(*OsqueryError).NodeInvalid()) // return other error ms.LoadHostByNodeKeyFunc = func(ctx context.Context, nodeKey string) (*fleet.Host, error) { @@ -2938,7 +2937,7 @@ func TestAuthenticationErrors(t *testing.T) { _, _, err = svc.AuthenticateHost(ctx, "foo") require.NotNil(t, err) - require.False(t, err.(*endpoint_utils.OsqueryError).NodeInvalid()) + require.False(t, err.(*OsqueryError).NodeInvalid()) } func TestGetHostIdentifier(t *testing.T) { diff --git a/server/service/service_appconfig.go b/server/service/service_appconfig.go index f437f0c539..1e9a59dcaf 100644 --- a/server/service/service_appconfig.go +++ b/server/service/service_appconfig.go @@ -12,7 +12,6 @@ import ( "github.com/fleetdm/fleet/v4/server/contexts/viewer" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/mail" - "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" ) func (svc *Service) NewAppConfig(ctx context.Context, p fleet.AppConfig) (*fleet.AppConfig, error) { @@ -65,7 +64,7 @@ func (svc *Service) sendTestEmail(ctx context.Context, config *fleet.AppConfig) } if err := mail.Test(svc.mailService, testMail); err != nil { - return endpoint_utils.MailError{Message: err.Error()} + return MailError{Message: err.Error()} } return nil } @@ -83,12 +82,14 @@ func (svc *Service) License(ctx context.Context) (*fleet.LicenseInfo, error) { } } - lic, _ := license.FromContext(ctx) + licChecker, _ := license.FromContext(ctx) + // Type assert to get the concrete type for modification and return + lic, _ := licChecker.(*fleet.LicenseInfo) // Currently we use the presence of Microsoft Compliance Partner settings // (only configured in cloud instances) to determine if a Fleet instance // is a cloud managed instance. - if svc.config.MicrosoftCompliancePartner.IsSet() { + if lic != nil && svc.config.MicrosoftCompliancePartner.IsSet() { lic.ManagedCloud = true } diff --git a/server/service/service_users.go b/server/service/service_users.go index d13c503f70..c88e1b2c48 100644 --- a/server/service/service_users.go +++ b/server/service/service_users.go @@ -31,14 +31,15 @@ func (svc *Service) CreateInitialUser(ctx context.Context, p fleet.UserPayload) } func (svc *Service) NewUser(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) { - license, _ := license.FromContext(ctx) - if license == nil { + licChecker, _ := license.FromContext(ctx) + lic, _ := licChecker.(*fleet.LicenseInfo) + if lic == nil { return nil, ctxerr.New(ctx, "license not found") } - if err := fleet.ValidateUserRoles(true, p, *license); err != nil { + if err := fleet.ValidateUserRoles(true, p, *lic); err != nil { return nil, ctxerr.Wrap(ctx, err, "validate role") } - if !license.IsPremium() { + if !lic.IsPremium() { p.MFAEnabled = ptr.Bool(false) } diff --git a/server/service/transport.go b/server/service/transport.go index 238c89a21c..51523a72dc 100644 --- a/server/service/transport.go +++ b/server/service/transport.go @@ -16,7 +16,7 @@ import ( ) func encodeResponse(ctx context.Context, w http.ResponseWriter, response interface{}) error { - return endpoint_utils.EncodeCommonResponse(ctx, w, response, jsonMarshal) + return endpoint_utils.EncodeCommonResponse(ctx, w, response, jsonMarshal, FleetErrorEncoder) } func jsonMarshal(w http.ResponseWriter, response interface{}) error { diff --git a/server/service/transport_error.go b/server/service/transport_error.go new file mode 100644 index 0000000000..beef983a64 --- /dev/null +++ b/server/service/transport_error.go @@ -0,0 +1,110 @@ +package service + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + platform_http "github.com/fleetdm/fleet/v4/server/platform/http" + "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" +) + +// FleetErrorEncoder handles fleet-specific error encoding for MailError +// and OsqueryError. +func FleetErrorEncoder(ctx context.Context, err error, w http.ResponseWriter, enc *json.Encoder, jsonErr *endpoint_utils.JsonError) bool { + switch e := err.(type) { + case MailError: + jsonErr.Message = "Mail Error" + jsonErr.Errors = []map[string]string{ + { + "name": "base", + "reason": e.Message, + }, + } + w.WriteHeader(http.StatusInternalServerError) + enc.Encode(jsonErr) //nolint:errcheck + return true + + case *OsqueryError: + // osquery expects to receive the node_invalid key when a TLS + // request provides an invalid node_key for authentication. It + // doesn't use the error message provided, but we provide this + // for debugging purposes (and perhaps osquery will use this + // error message in the future). + + errMap := map[string]any{ + "error": e.Error(), + "uuid": jsonErr.UUID, + } + if e.NodeInvalid() { //nolint:gocritic // ignore ifElseChain + w.WriteHeader(http.StatusUnauthorized) + errMap["node_invalid"] = true + } else if e.Status() != 0 { + w.WriteHeader(e.Status()) + } else { + // TODO: osqueryError is not always the result of an internal error on + // our side, it is also used to represent a client error (invalid data, + // e.g. malformed json, carve too large, etc., so 4xx), are we returning + // a 500 because of some osquery-specific requirement? + w.WriteHeader(http.StatusInternalServerError) + } + enc.Encode(errMap) //nolint:errcheck + return true + } + + return false +} + +// MailError is set when an error performing mail operations +type MailError struct { + Message string +} + +func (e MailError) Error() string { + return fmt.Sprintf("a mail error occurred: %s", e.Message) +} + +// OsqueryError is the error returned to osquery agents. +type OsqueryError struct { + message string + nodeInvalid bool + StatusCode int + platform_http.ErrorWithUUID +} + +var _ platform_http.ErrorUUIDer = (*OsqueryError)(nil) + +// Error implements the error interface. +func (e *OsqueryError) Error() string { + return e.message +} + +// NodeInvalid returns whether the error returned to osquery +// should contain the node_invalid property. +func (e *OsqueryError) NodeInvalid() bool { + return e.nodeInvalid +} + +func (e *OsqueryError) Status() int { + return e.StatusCode +} + +func NewOsqueryError(message string, nodeInvalid bool) *OsqueryError { + return &OsqueryError{ + message: message, + nodeInvalid: nodeInvalid, + } +} + +// encodeError is a convenience function that calls endpoint_utils.EncodeError +// with the FleetErrorEncoder. Use this for direct error encoding in handlers. +func encodeError(ctx context.Context, err error, w http.ResponseWriter) { + endpoint_utils.EncodeError(ctx, err, w, FleetErrorEncoder) +} + +// fleetErrorEncoder is an adapter that wraps endpoint_utils.EncodeError with +// FleetErrorEncoder for use as a kithttp.ErrorEncoder. +func fleetErrorEncoder(ctx context.Context, err error, w http.ResponseWriter) { + endpoint_utils.EncodeError(ctx, err, w, FleetErrorEncoder) +} diff --git a/server/service/users.go b/server/service/users.go index a6103eff67..bd5f47a820 100644 --- a/server/service/users.go +++ b/server/service/users.go @@ -433,11 +433,12 @@ func (svc *Service) ModifyUser(ctx context.Context, userID uint, p fleet.UserPay if err := svc.authz.Authorize(ctx, user, fleet.ActionWriteRole); err != nil { return nil, err } - license, _ := license.FromContext(ctx) - if license == nil { + licChecker, _ := license.FromContext(ctx) + lic, _ := licChecker.(*fleet.LicenseInfo) + if lic == nil { return nil, ctxerr.New(ctx, "license not found") } - if err := fleet.ValidateUserRoles(false, p, *license); err != nil { + if err := fleet.ValidateUserRoles(false, p, *lic); err != nil { return nil, ctxerr.Wrap(ctx, err, "validate role") } }