Refactor endpoint_utils for modularization (#36484)

Resolves #37192

Separating generic endpoint_utils middleware logic from domain-specific
business logic. New bounded contexts would share the generic logic and
implement their own domain-specific logic. The two approaches used in
this PR are:
- Use common `platform` types
- Use interfaces

In the next PR we will move `endpointer_utils`, `authzcheck` and
`ratelimit` into `platform` directory.

# Checklist for submitter

- [x] Added changes file

## Testing

- [x] Added/updated tests
- [x] QA'd all new/changed functionality manually



<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Refactor**
* Restructured internal error handling and context management to support
bounded context architecture.
* Improved error context collection and telemetry observability through
a provider-based mechanism.
* Decoupled licensing and authentication concerns into interfaces for
better modularity.

* **Chores**
* Updated internal package dependencies to align with new architectural
boundaries.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
Victor Lyuboslavsky 2025-12-31 09:12:00 -06:00 committed by GitHub
parent 360a426224
commit c88cc953fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
47 changed files with 1126 additions and 637 deletions

View file

@ -0,0 +1 @@
Refactored common endpoint_utils package to support bounded contexts inside Fleet codebase.

View file

@ -30,8 +30,11 @@ type PackageTest struct {
forbiddenPkgs []string
}
// ModuleName is the module name for the fleet project.
const ModuleName = "github.com/fleetdm/fleet/v4"
// PackageTest will ignore dependency on this package.
const thisPackage = "github.com/fleetdm/fleet/v4/server/archtest"
const thisPackage = ModuleName + "/server/archtest"
type TestingT interface {
Errorf(format string, args ...any)

View file

@ -4,6 +4,7 @@ import (
"net/http"
"github.com/fleetdm/fleet/v4/server/fleet"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
)
const (
@ -66,27 +67,15 @@ func (e *Forbidden) LogFields() []interface{} {
// CheckMissing is the error to return when no authorization check was performed
// by the service.
type CheckMissing struct {
response interface{}
fleet.ErrorWithUUID
}
//
// Deprecated: Use platform_http.CheckMissing instead. This alias is kept for
// backward compatibility.
type CheckMissing = platform_http.CheckMissing
// CheckMissingWithResponse creates a new error indicating the authorization
// check was missed, and including the response for further analysis by the error
// encoder.
func CheckMissingWithResponse(response interface{}) *CheckMissing {
return &CheckMissing{response: response}
}
func (e *CheckMissing) Error() string {
return ForbiddenErrorMessage
}
func (e *CheckMissing) Internal() string {
return "Missing authorization check"
}
func (e *CheckMissing) Response() interface{} {
return e.response
}
//
// Deprecated: Use platform_http.CheckMissingWithResponse instead. This alias is
// kept for backward compatibility.
var CheckMissingWithResponse = platform_http.CheckMissingWithResponse

View file

@ -16,13 +16,12 @@ import (
"encoding/json"
"errors"
"fmt"
"maps"
"runtime"
"strings"
"time"
"github.com/fleetdm/fleet/v4/server/contexts/host"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/fleet"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
"github.com/getsentry/sentry-go"
"go.elastic.co/apm/v2"
"go.opentelemetry.io/otel/attribute"
@ -121,21 +120,9 @@ func setMetadata(ctx context.Context, data map[string]interface{}) map[string]in
data["timestamp"] = nowFn().Format(time.RFC3339)
if h, ok := host.FromContext(ctx); ok {
data["host"] = map[string]interface{}{
"platform": h.Platform,
"osquery_version": h.OsqueryVersion,
}
}
if v, ok := viewer.FromContext(ctx); ok {
vdata := map[string]interface{}{}
data["viewer"] = vdata
vdata["is_logged_in"] = v.IsLoggedIn()
if v.User != nil {
vdata["sso_enabled"] = v.User.SSOEnabled
}
// Get diagnostic context from all registered providers
for _, provider := range getErrorContextProviders(ctx) {
maps.Copy(data, provider.GetDiagnosticContext())
}
return data
@ -215,7 +202,7 @@ func Wrapf(ctx context.Context, cause error, format string, args ...interface{})
// Cause returns the root error in err's chain.
func Cause(err error) error {
return fleet.Cause(err)
return platform_http.Cause(err)
}
// FleetCause is similar to Cause, but returns the root-most
@ -319,6 +306,9 @@ func Handle(ctx context.Context, err error) {
cause = rootCause
}
// Collect telemetry context from registered providers
telemetryAttrs := collectTelemetryContext(ctx)
// send to OpenTelemetry if there's an active span
if span := trace.SpanFromContext(ctx); span != nil && span.IsRecording() {
// Mark the current span as failed by setting the error status.
@ -333,20 +323,25 @@ func Handle(ctx context.Context, err error) {
attribute.String("exception.stacktrace", strings.Join(cause.Stack(), "\n")),
}
// Add contextual information if available (same as Sentry)
v, _ := viewer.FromContext(ctx)
h, _ := host.FromContext(ctx)
if v.User != nil {
attrs = append(attrs,
// Not sending the email here as it may contain sensitive information (PII).
attribute.Int64("user.id", int64(v.User.ID)), //nolint:gosec
)
} else if h != nil {
attrs = append(attrs,
attribute.String("host.hostname", h.Hostname),
attribute.Int64("host.id", int64(h.ID)), //nolint:gosec
)
// Add contextual information from telemetry providers.
// OpenTelemetry requires typed attributes, so we convert the values to the appropriate type.
for k, v := range telemetryAttrs {
switch val := v.(type) {
case string:
attrs = append(attrs, attribute.String(k, val))
case int:
attrs = append(attrs, attribute.Int64(k, int64(val)))
case int64:
attrs = append(attrs, attribute.Int64(k, val))
case uint:
attrs = append(attrs, attribute.Int64(k, int64(val))) //nolint:gosec
case uint64:
attrs = append(attrs, attribute.Int64(k, int64(val))) //nolint:gosec
case bool:
attrs = append(attrs, attribute.Bool(k, val))
default:
attrs = append(attrs, attribute.String(k, fmt.Sprint(val)))
}
}
span.AddEvent("exception", trace.WithAttributes(attrs...))
@ -357,25 +352,14 @@ func Handle(ctx context.Context, err error) {
// if Sentry is configured, capture the error there
if sentryClient := sentry.CurrentHub().Client(); sentryClient != nil {
// sentry is configured, add contextual information if available
v, _ := viewer.FromContext(ctx)
h, _ := host.FromContext(ctx)
if v.User != nil || h != nil {
// we have a viewer (user) or a host in the context, use this to
// enrich the error with more context
if len(telemetryAttrs) > 0 {
// we have contextual information, use it to enrich the error
ctxHub := sentry.CurrentHub().Clone()
if v.User != nil {
ctxHub.ConfigureScope(func(scope *sentry.Scope) {
scope.SetTag("email", v.User.Email)
scope.SetTag("user_id", fmt.Sprint(v.User.ID))
})
} else if h != nil {
ctxHub.ConfigureScope(func(scope *sentry.Scope) {
scope.SetTag("hostname", h.Hostname)
scope.SetTag("host_id", fmt.Sprint(h.ID))
})
}
ctxHub.ConfigureScope(func(scope *sentry.Scope) {
for k, v := range telemetryAttrs {
scope.SetTag(k, fmt.Sprint(v))
}
})
ctxHub.CaptureException(cause)
} else {
sentry.CaptureException(cause)
@ -387,6 +371,17 @@ func Handle(ctx context.Context, err error) {
}
}
// collectTelemetryContext gathers telemetry context from all registered providers.
func collectTelemetryContext(ctx context.Context) map[string]any {
attrs := make(map[string]any)
for _, provider := range getErrorContextProviders(ctx) {
if telemetry := provider.GetTelemetryContext(); telemetry != nil {
maps.Copy(attrs, telemetry)
}
}
return attrs
}
// Retrieve retrieves an error from the registered error handler
func Retrieve(ctx context.Context) ([]*StoredError, error) {
eh := FromContext(ctx)

View file

@ -29,7 +29,11 @@ func TestHandleSendsContextToOTEL(t *testing.T) {
ID: 123,
Email: "test@example.com",
}
return viewer.NewContext(ctx, viewer.Viewer{User: testUser})
v := viewer.Viewer{User: testUser}
ctx = viewer.NewContext(ctx, v)
// Register the viewer as an error context provider
ctx = AddErrorContextProvider(ctx, &v)
return ctx
},
errorMessage: "test error with user context",
expectedAttrs: map[string]any{
@ -43,7 +47,10 @@ func TestHandleSendsContextToOTEL(t *testing.T) {
ID: 456,
Hostname: "test-host.example.com",
}
return host.NewContext(ctx, testHost)
ctx = host.NewContext(ctx, testHost)
// Register the host as an error context provider
ctx = AddErrorContextProvider(ctx, &host.HostAttributeProvider{Host: testHost})
return ctx
},
errorMessage: "test error with host context",
expectedAttrs: map[string]any{

View file

@ -322,7 +322,10 @@ func TestAdditionalMetadata(t *testing.T) {
t.Run("saves additional data about the host if present", func(t *testing.T) {
ctx, cleanup := setup()
defer cleanup()
hctx := host.NewContext(ctx, &fleet.Host{Platform: "test_platform", OsqueryVersion: "5.0"})
h := &fleet.Host{Platform: "test_platform", OsqueryVersion: "5.0"}
hctx := host.NewContext(ctx, h)
// Register the host as an error context provider
hctx = AddErrorContextProvider(hctx, &host.HostAttributeProvider{Host: h})
err := New(hctx, "with host context").(*FleetError)
require.JSONEq(t, string(err.data), `{"host":{"osquery_version":"5.0","platform":"test_platform"},"timestamp":"1969-06-19T21:44:05Z"}`)
@ -331,7 +334,10 @@ func TestAdditionalMetadata(t *testing.T) {
t.Run("saves additional data about the viewer if present", func(t *testing.T) {
ctx, cleanup := setup()
defer cleanup()
vctx := viewer.NewContext(ctx, viewer.Viewer{Session: &fleet.Session{ID: 1}, User: &fleet.User{SSOEnabled: true}})
v := viewer.Viewer{Session: &fleet.Session{ID: 1}, User: &fleet.User{SSOEnabled: true}}
vctx := viewer.NewContext(ctx, v)
// Register the viewer as an error context provider
vctx = AddErrorContextProvider(vctx, &v)
err := New(vctx, "with host context").(*FleetError)
require.JSONEq(t, string(err.data), `{"viewer":{"is_logged_in":true,"sso_enabled":true},"timestamp":"1969-06-19T21:44:05Z"}`)

View file

@ -0,0 +1,35 @@
package ctxerr
import "context"
// ErrorContextProvider provides contextual information for error handling.
// Implementations can provide data for both error storage and telemetry systems.
type ErrorContextProvider interface {
// GetDiagnosticContext returns attributes stored with errors for troubleshooting.
// Data is persisted to Redis and included in logs. Should contain diagnostic
// information like platform, versions, and status flags. Avoid including PII.
GetDiagnosticContext() map[string]any
// GetTelemetryContext returns attributes sent to observability systems
// (OpenTelemetry, Sentry). May include identifiers not stored with errors.
// Return nil if no telemetry context is available.
GetTelemetryContext() map[string]any
}
type errorContextProvidersKey struct{}
// AddErrorContextProvider returns a new context with the given provider added to
// the existing providers. This is useful when you want to add a provider
// without replacing existing ones.
func AddErrorContextProvider(ctx context.Context, provider ErrorContextProvider) context.Context {
existing := getErrorContextProviders(ctx)
providers := make([]ErrorContextProvider, len(existing)+1)
copy(providers, existing)
providers[len(existing)] = provider
return context.WithValue(ctx, errorContextProvidersKey{}, providers)
}
func getErrorContextProviders(ctx context.Context) []ErrorContextProvider {
providers, _ := ctx.Value(errorContextProvidersKey{}).([]ErrorContextProvider)
return providers
}

View file

@ -3,8 +3,6 @@ package ctxerr
import (
"context"
"encoding/json"
"github.com/fleetdm/fleet/v4/server/fleet"
)
type ErrorAgg struct {
@ -63,12 +61,21 @@ out:
return stack[:stackIdx]
}
// vitalErrorData represents the structure of vital fleetd error data.
type vitalErrorData struct {
ErrorSource string `json:"error_source"`
ErrorSourceVersion string `json:"error_source_version"`
ErrorMessage string `json:"error_message"`
ErrorAdditionalInfo map[string]any `json:"error_additional_info"`
Vital bool `json:"vital"`
}
func getVitalMetadata(chain []fleetErrorJSON) json.RawMessage {
for _, e := range chain {
if len(e.Data) > 0 {
// Currently, only vital fleetd errors contain metadata.
// Note: vital errors should not contain any sensitive info
var fleetdErr fleet.FleetdError
var fleetdErr vitalErrorData
var err error
if err = json.Unmarshal(e.Data, &fleetdErr); err != nil || !fleetdErr.Vital {
continue

View file

@ -22,3 +22,33 @@ func FromContext(ctx context.Context) (*fleet.Host, bool) {
host, ok := ctx.Value(hostKey).(*fleet.Host)
return host, ok
}
// HostAttributeProvider wraps a fleet.Host to provide error context.
// It implements ctxerr.ErrorContextProvider.
type HostAttributeProvider struct {
Host *fleet.Host
}
// GetDiagnosticContext implements ctxerr.ErrorContextProvider
func (p *HostAttributeProvider) GetDiagnosticContext() map[string]any {
if p.Host == nil {
return nil
}
return map[string]any{
"host": map[string]any{
"platform": p.Host.Platform,
"osquery_version": p.Host.OsqueryVersion,
},
}
}
// GetTelemetryContext implements ctxerr.ErrorContextProvider
func (p *HostAttributeProvider) GetTelemetryContext() map[string]any {
if p.Host == nil {
return nil
}
return map[string]any{
"host.hostname": p.Host.Hostname,
"host.id": p.Host.ID,
}
}

View file

@ -4,23 +4,34 @@ package license
import (
"context"
"github.com/fleetdm/fleet/v4/server/fleet"
)
// LicenseChecker is the interface for checking license properties.
// This interface is implemented by fleet.LicenseInfo
type LicenseChecker interface {
IsPremium() bool
IsAllowDisableTelemetry() bool
// GetTier returns the license tier (e.g., "free", "premium", "trial").
GetTier() string
// GetOrganization returns the name of the licensed organization.
GetOrganization() string
// GetDeviceCount returns the number of licensed devices.
GetDeviceCount() int
}
type key int
const licenseKey key = 0
// NewContext creates a new context.Context with the license.
func NewContext(ctx context.Context, lic *fleet.LicenseInfo) context.Context {
func NewContext(ctx context.Context, lic LicenseChecker) context.Context {
return context.WithValue(ctx, licenseKey, lic)
}
// FromContext returns the license from the context and true, or nil and false
// if there is no license.
func FromContext(ctx context.Context) (*fleet.LicenseInfo, bool) {
v, ok := ctx.Value(licenseKey).(*fleet.LicenseInfo)
// FromContext returns the license from the context as a LicenseChecker interface.
// Use this when you only need to check license properties via the interface methods.
func FromContext(ctx context.Context) (LicenseChecker, bool) {
v, ok := ctx.Value(licenseKey).(LicenseChecker)
return v, ok
}
@ -34,6 +45,8 @@ func IsPremium(ctx context.Context) bool {
return false
}
// IsAllowDisableTelemetry returns true if telemetry can be disabled based on
// the license in the context.
func IsAllowDisableTelemetry(ctx context.Context) bool {
if lic, ok := FromContext(ctx); ok {
return lic.IsAllowDisableTelemetry()

View file

@ -7,13 +7,31 @@ import (
"sync"
"time"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/fleet"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
kithttp "github.com/go-kit/kit/transport/http"
kitlog "github.com/go-kit/log"
"github.com/go-kit/log/level"
)
// UserEmailer provides the user's email for logging purposes.
type UserEmailer interface {
Email() string
}
type userEmailerKey struct{}
// WithUserEmailer returns a context with the UserEmailer stored for logging.
// This should be called by authentication middleware after the user is identified.
func WithUserEmailer(ctx context.Context, emailer UserEmailer) context.Context {
return context.WithValue(ctx, userEmailerKey{}, emailer)
}
// UserEmailerFromContext retrieves the UserEmailer from the context.
func UserEmailerFromContext(ctx context.Context) (UserEmailer, bool) {
v, ok := ctx.Value(userEmailerKey{}).(UserEmailer)
return v, ok
}
type key int
const loggingKey key = 0
@ -138,9 +156,8 @@ func (l *LoggingContext) Log(ctx context.Context, logger kitlog.Logger) {
if !l.SkipUser {
loggedInUser := "unauthenticated"
vc, ok := viewer.FromContext(ctx)
if ok {
loggedInUser = vc.Email()
if emailer, ok := UserEmailerFromContext(ctx); ok {
loggedInUser = emailer.Email()
}
keyvals = append(keyvals, "user", loggedInUser)
}
@ -168,7 +185,7 @@ func (l *LoggingContext) Log(ctx context.Context, logger kitlog.Logger) {
)
separator := " || "
for _, err := range l.Errs {
var ewi fleet.ErrWithInternal
var ewi platform_http.ErrWithInternal
if errors.As(err, &ewi) {
if internalErrs == "" {
internalErrs = ewi.Internal()
@ -182,7 +199,7 @@ func (l *LoggingContext) Log(ctx context.Context, logger kitlog.Logger) {
errs += separator + err.Error()
}
}
var ewuuid fleet.ErrorUUIDer
var ewuuid platform_http.ErrorUUIDer
if errors.As(err, &ewuuid) {
if uuid := ewuuid.UUID(); uuid != "" {
uuids = append(uuids, uuid)
@ -209,7 +226,7 @@ func (l *LoggingContext) setLevelError() bool {
}
if len(l.Errs) == 1 {
var ew fleet.ErrWithIsClientError
var ew platform_http.ErrWithIsClientError
if errors.As(l.Errs[0], &ew) && ew.IsClientError() {
return false
}

View file

@ -4,6 +4,7 @@ package viewer
import (
"context"
"strings"
"github.com/fleetdm/fleet/v4/server/fleet"
)
@ -102,3 +103,38 @@ func (v Viewer) CanPerformPasswordReset() bool {
}
return false
}
// GetDiagnosticContext implements ctxerr.ErrorContextProvider
func (v *Viewer) GetDiagnosticContext() map[string]any {
vdata := map[string]any{
"is_logged_in": v.IsLoggedIn(),
}
if v.User != nil {
vdata["sso_enabled"] = v.User.SSOEnabled
}
return map[string]any{
"viewer": vdata,
}
}
// GetTelemetryContext implements ctxerr.ErrorContextProvider
func (v *Viewer) GetTelemetryContext() map[string]any {
if v.User == nil {
return nil
}
return map[string]any{
"user.id": v.User.ID,
"user.email": maskEmail(v.User.Email),
}
}
// maskEmail anonymizes an email address for telemetry by showing only
// the first character of the local part and the full domain.
// Example: "john.doe@example.com" -> "j***@example.com"
func maskEmail(email string) string {
parts := strings.SplitN(email, "@", 2)
if len(parts) != 2 || len(parts[0]) == 0 {
return "***"
}
return string(parts[0][0]) + "***@" + parts[1]
}

View file

@ -153,3 +153,26 @@ func TestCanPerformActions(t *testing.T) {
// assert.Equal(t, false, needsPasswordResetAdminViewer.CanPerformWriteActionOnUser(1))
// assert.Equal(t, true, needsPasswordResetAdminViewer.CanPerformWriteActionOnUser(needsPasswordResetAdminViewer.User.ID))
// }
func TestMaskEmail(t *testing.T) {
cases := []struct {
name string
email string
expected string
}{
{"standard email", "john.doe@example.com", "j***@example.com"},
{"single char local", "j@example.com", "j***@example.com"},
{"subdomain", "user@mail.example.com", "u***@mail.example.com"},
{"empty string", "", "***"},
{"no at sign", "invalid", "***"},
{"empty local part", "@example.com", "***"},
{"only at sign", "@", "***"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
result := maskEmail(tc.email)
assert.Equal(t, tc.expected, result)
})
}
}

View file

@ -144,7 +144,7 @@ func (ds *Datastore) ShouldSendStatistics(ctx context.Context, frequency time.Du
stats.NumHostsNotResponding = amountHostsNotResponding
stats.Organization = "unknown"
if lic != nil && lic.IsPremium() {
stats.Organization = lic.Organization
stats.Organization = lic.GetOrganization()
}
stats.AIFeaturesDisabled = appConfig.ServerSettings.AIFeaturesDisabled
stats.MaintenanceWindowsConfigured = len(appConfig.Integrations.GoogleCalendar) > 0 && appConfig.Integrations.GoogleCalendar[0].Domain != "" && len(appConfig.Integrations.GoogleCalendar[0].ApiKey) > 0
@ -187,7 +187,7 @@ func (ds *Datastore) ShouldSendStatistics(ctx context.Context, frequency time.Du
LicenseTier: fleet.TierFree,
}
if lic != nil {
stats.LicenseTier = lic.Tier
stats.LicenseTier = lic.GetTier()
}
if err := computeStats(&stats, time.Now().Add(-frequency)); err != nil {
return fleet.StatisticsPayload{}, false, ctxerr.Wrap(ctx, err, "compute statistics")
@ -212,7 +212,7 @@ func (ds *Datastore) ShouldSendStatistics(ctx context.Context, frequency time.Du
LicenseTier: fleet.TierFree,
}
if lic != nil {
stats.LicenseTier = lic.Tier
stats.LicenseTier = lic.GetTier()
}
if err := computeStats(&stats, lastUpdated); err != nil {
return fleet.StatisticsPayload{}, false, ctxerr.Wrap(ctx, err, "compute statistics")

View file

@ -1476,6 +1476,24 @@ func (l *LicenseInfo) IsAllowDisableTelemetry() bool {
return !l.IsPremium() || l.AllowDisableTelemetry
}
// Tier returns the license tier.
// This method implements license.LicenseChecker.
func (l *LicenseInfo) GetTier() string {
return l.Tier
}
// Organization returns the name of the licensed organization.
// This method implements license.LicenseChecker.
func (l *LicenseInfo) GetOrganization() string {
return l.Organization
}
// DeviceCount returns the number of licensed devices.
// This method implements license.LicenseChecker.
func (l *LicenseInfo) GetDeviceCount() int {
return l.DeviceCount
}
const (
HeaderLicenseKey = "X-Fleet-License"
HeaderLicenseValueExpired = "Expired"

View file

@ -19,6 +19,7 @@ import (
"github.com/fleetdm/fleet/v4/server/mdm/nanodep/godep"
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm"
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/storage"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
"github.com/jmoiron/sqlx"
)
@ -2836,19 +2837,11 @@ type AlreadyExistsError interface {
IsExists() bool
}
// ForeignKeyError is returned when the operation fails due to foreign key constraints.
type ForeignKeyError interface {
error
IsForeignKey() bool
}
// ForeignKeyError is an alias for platform_http.ForeignKeyError.
type ForeignKeyError = platform_http.ForeignKeyError
func IsForeignKey(err error) bool {
var fke ForeignKeyError
if errors.As(err, &fke) {
return fke.IsForeignKey()
}
return false
}
// IsForeignKey is an alias for platform_http.IsForeignKey.
var IsForeignKey = platform_http.IsForeignKey
type OptionalArg func() interface{}

View file

@ -1,22 +1,18 @@
package fleet
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"reflect"
"regexp"
"strings"
"time"
"github.com/google/uuid"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
"github.com/rs/zerolog"
)
var (
ErrNoContext = errors.New("context key not set")
ErrPasswordResetRequired = &passwordResetRequiredError{}
ErrPasswordResetRequired = platform_http.ErrPasswordResetRequired
ErrMissingLicense = &licenseError{}
ErrMDMNotConfigured = &MDMNotConfiguredError{}
ErrWindowsMDMNotConfigured = &WindowsMDMNotConfiguredError{}
@ -46,40 +42,17 @@ type ErrWithStatusCode interface {
StatusCode() int
}
// ErrWithInternal is an interface for errors that include extra "internal"
// information that should be logged in server logs but not sent to clients.
type ErrWithInternal interface {
error
// Internal returns the error string that must only be logged internally,
// not returned to the client.
Internal() string
}
// ErrWithInternal is an alias for platform_http.ErrWithInternal.
type ErrWithInternal = platform_http.ErrWithInternal
// ErrWithLogFields is an interface for errors that include additional logging
// fields that should be logged in server logs but not sent to clients.
type ErrWithLogFields interface {
error
// LogFields returns the additional log fields to add, which should come in
// key, value pairs (as used in go-kit log).
LogFields() []interface{}
}
// ErrWithLogFields is an alias for platform_http.ErrWithLogFields.
type ErrWithLogFields = platform_http.ErrWithLogFields
// ErrWithRetryAfter is an interface for errors that should set a specific HTTP
// Header Retry-After value (see
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After)
type ErrWithRetryAfter interface {
error
// RetryAfter returns the number of seconds to wait before retry.
RetryAfter() int
}
// ErrWithRetryAfter is an alias for platform_http.ErrWithRetryAfter.
type ErrWithRetryAfter = platform_http.ErrWithRetryAfter
// ErrWithIsClientError is an interface for errors that explicitly specify
// whether they are client errors or not. By default, errors are treated as
// server errors.
type ErrWithIsClientError interface {
error
IsClientError() bool
}
// ErrWithIsClientError is an alias for platform_http.ErrWithIsClientError.
type ErrWithIsClientError = platform_http.ErrWithIsClientError
type invalidArgWithStatusError struct {
InvalidArgumentError
@ -94,30 +67,11 @@ func (e invalidArgWithStatusError) Status() int {
return e.code
}
// ErrorUUIDer is the interface for errors that contain a UUID.
type ErrorUUIDer interface {
// UUID returns the error's UUID.
UUID() string
}
// ErrorUUIDer is an alias for platform_http.ErrorUUIDer.
type ErrorUUIDer = platform_http.ErrorUUIDer
// ErrorWithUUID can be embedded to error types to implement ErrorUUIDer.
type ErrorWithUUID struct {
uuid string
}
var _ ErrorUUIDer = (*ErrorWithUUID)(nil)
// UUID implements the ErrorUUIDer interface.
func (e *ErrorWithUUID) UUID() string {
if e.uuid == "" {
uuid, err := uuid.NewRandom()
if err != nil {
panic(err)
}
e.uuid = uuid.String()
}
return e.uuid
}
// ErrorWithUUID is an alias for platform_http.ErrorWithUUID.
type ErrorWithUUID = platform_http.ErrorWithUUID
// InvalidArgumentError is the error returned when invalid data is presented to
// a service method. It is a client error.
@ -187,102 +141,26 @@ func (e InvalidArgumentError) Invalid() []map[string]string {
return invalid
}
// BadRequestError is an error type that generates a 400 status code.
type BadRequestError struct {
Message string
InternalErr error
// BadRequestError is an alias for platform_http.BadRequestError.
type BadRequestError = platform_http.BadRequestError
ErrorWithUUID
}
// AuthFailedError is an alias for platform_http.AuthFailedError.
type AuthFailedError = platform_http.AuthFailedError
// Error returns the error message.
func (e *BadRequestError) Error() string {
return e.Message
}
// NewAuthFailedError is an alias for platform_http.NewAuthFailedError.
var NewAuthFailedError = platform_http.NewAuthFailedError
// This implements the interface required by the server/service package logic
// to determine the status code to return to the client.
func (e *BadRequestError) BadRequestError() []map[string]string {
return nil
}
// AuthRequiredError is an alias for platform_http.AuthRequiredError.
type AuthRequiredError = platform_http.AuthRequiredError
func (e BadRequestError) Internal() string {
if e.InternalErr == nil {
return ""
}
return e.InternalErr.Error()
}
// NewAuthRequiredError is an alias for platform_http.NewAuthRequiredError.
var NewAuthRequiredError = platform_http.NewAuthRequiredError
type AuthFailedError struct {
// internal is the reason that should only be logged internally
internal string
// AuthHeaderRequiredError is an alias for platform_http.AuthHeaderRequiredError.
type AuthHeaderRequiredError = platform_http.AuthHeaderRequiredError
ErrorWithUUID
}
func NewAuthFailedError(internal string) *AuthFailedError {
return &AuthFailedError{internal: internal}
}
func (e AuthFailedError) Error() string {
return "Authentication failed"
}
func (e AuthFailedError) Internal() string {
return e.internal
}
func (e AuthFailedError) StatusCode() int {
return http.StatusUnauthorized
}
type AuthRequiredError struct {
// internal is the reason that should only be logged internally
internal string
ErrorWithUUID
}
func NewAuthRequiredError(internal string) *AuthRequiredError {
return &AuthRequiredError{internal: internal}
}
func (e AuthRequiredError) Error() string {
return "Authentication required"
}
func (e AuthRequiredError) Internal() string {
return e.internal
}
func (e AuthRequiredError) StatusCode() int {
return http.StatusUnauthorized
}
type AuthHeaderRequiredError struct {
// internal is the reason that should only be logged internally
internal string
ErrorWithUUID
}
func NewAuthHeaderRequiredError(internal string) *AuthHeaderRequiredError {
return &AuthHeaderRequiredError{
internal: internal,
}
}
func (e AuthHeaderRequiredError) Error() string {
return "Authorization header required"
}
func (e AuthHeaderRequiredError) Internal() string {
return e.internal
}
func (e AuthHeaderRequiredError) StatusCode() int {
return http.StatusUnauthorized
}
// NewAuthHeaderRequiredError is an alias for platform_http.NewAuthHeaderRequiredError.
var NewAuthHeaderRequiredError = platform_http.NewAuthHeaderRequiredError
// PermissionError, set when user is authenticated, but not allowed to perform action
type PermissionError struct {
@ -347,18 +225,6 @@ func (e licenseError) StatusCode() int {
return http.StatusPaymentRequired
}
type passwordResetRequiredError struct {
ErrorWithUUID
}
func (e passwordResetRequiredError) Error() string {
return "password reset required"
}
func (e passwordResetRequiredError) StatusCode() int {
return http.StatusUnauthorized
}
// MDMNotConfiguredError is used when an MDM endpoint or resource is accessed
// without having MDM correctly configured.
type MDMNotConfiguredError struct{}
@ -453,15 +319,9 @@ func (e *GatewayError) Error() string {
return msg
}
// Error is a user facing error (API user). It's meant to be used for errors that are
// related to fleet logic specifically. Other errors, such as mysql errors, shouldn't
// be translated to this.
type Error struct {
Code int `json:"code,omitempty"`
Message string `json:"message,omitempty"`
ErrorWithUUID
}
// Error is an alias for platform_http.Error.
// It's meant to be used for errors that are related to fleet logic specifically.
type Error = platform_http.Error
const (
// ErrNoRoleNeeded is the error number for valid role needed
@ -491,101 +351,26 @@ func NewErrorf(code int, format string, args ...interface{}) error {
}
}
func (ge *Error) Error() string {
return ge.Message
}
// UserMessageError is an alias for platform_http.UserMessageError.
type UserMessageError = platform_http.UserMessageError
// UserMessageError is an error that adds the UserMessage interface
// implementation.
type UserMessageError struct {
error
statusCode int
ErrorWithUUID
}
// NewUserMessageError creates a UserMessageError that will translate the
// error message of err to a user-friendly form. If statusCode is > 0, it
// will be used as the HTTP status code for the error, otherwise it defaults
// to http.StatusUnprocessableEntity (422).
func NewUserMessageError(err error, statusCode int) *UserMessageError {
if err == nil {
return nil
}
return &UserMessageError{
error: err,
statusCode: statusCode,
}
}
var rxJSONUnknownField = regexp.MustCompile(`^json: unknown field "(.+)"$`)
// NewUserMessageError is an alias for platform_http.NewUserMessageError.
var NewUserMessageError = platform_http.NewUserMessageError
// IsJSONUnknownFieldError returns true if err is a JSON unknown field error.
// There is no exported type or value for this error, so we have to match the
// error message.
func IsJSONUnknownFieldError(err error) bool {
return rxJSONUnknownField.MatchString(err.Error())
return platform_http.IsJSONUnknownFieldError(err)
}
// GetJSONUnknownField returns the unknown field name from a JSON unknown field error.
func GetJSONUnknownField(err error) *string {
errCause := Cause(err)
if IsJSONUnknownFieldError(errCause) {
substr := rxJSONUnknownField.FindStringSubmatch(errCause.Error())
return &substr[1]
}
return nil
}
// UserMessage implements the user-friendly translation of the error if its
// root cause is one of the supported types, otherwise it returns the error
// message.
func (e UserMessageError) UserMessage() string {
cause := Cause(e.error)
switch cause := cause.(type) {
case *json.UnmarshalTypeError:
var sb strings.Builder
curType := cause.Type
for curType.Kind() == reflect.Slice || curType.Kind() == reflect.Array {
sb.WriteString("array of ")
curType = curType.Elem()
}
sb.WriteString(curType.Name())
if curType != cause.Type {
// it was an array
sb.WriteString("s")
}
return fmt.Sprintf("invalid value type at '%s': expected %s but got %s", cause.Field, sb.String(), cause.Value)
default:
// there's no specific error type for the strict json mode
// (DisallowUnknownFields), so resort to message-matching.
if matches := rxJSONUnknownField.FindStringSubmatch(cause.Error()); matches != nil {
return fmt.Sprintf("unsupported key provided: %q", matches[1])
}
return e.Error()
}
}
// StatusCode implements the kithttp.StatusCoder interface to return the status
// code to use in HTTP API responses.
func (e UserMessageError) StatusCode() int {
if e.statusCode > 0 {
return e.statusCode
}
return http.StatusUnprocessableEntity
return platform_http.GetJSONUnknownField(err)
}
// Cause returns the root error in err's chain.
func Cause(err error) error {
for {
uerr := errors.Unwrap(err)
if uerr == nil {
return err
}
err = uerr
}
}
var Cause = platform_http.Cause
// FleetdError is an error that can be reported by any of the fleetd
// components.

View file

@ -7,32 +7,34 @@ import (
"github.com/fleetdm/fleet/v4/server/archtest"
)
const m = archtest.ModuleName
// TestAllAndroidPackageDependencies checks that android packages are not dependent on other Fleet packages
// to maintain decoupling and modularity.
// If coupling is necessary, it should be done in the main server/fleet, server/service, or other package.
func TestAllAndroidPackageDependencies(t *testing.T) {
t.Parallel()
archtest.NewPackageTest(t, "github.com/fleetdm/fleet/v4/server/mdm/android...").
archtest.NewPackageTest(t, m+"/server/mdm/android...").
OnlyInclude(regexp.MustCompile(`^github\.com/fleetdm/`)).
ShouldNotDependOn(
"github.com/fleetdm/fleet/v4/server/service...",
"github.com/fleetdm/fleet/v4/server/datastore/mysql...",
m+"/server/service...",
m+"/server/datastore/mysql...",
).
IgnoreRecursively(
"github.com/fleetdm/fleet/v4/server/mdm/android/tests",
m+"/server/mdm/android/tests",
).
IgnoreDeps(
// Android packages
"github.com/fleetdm/fleet/v4/server/mdm/android...",
m+"/server/mdm/android...",
// Other/infra packages
"github.com/fleetdm/fleet/v4/server/datastore/mysql/common_mysql",
"github.com/fleetdm/fleet/v4/server/service/externalsvc", // dependency on Jira and Zendesk
"github.com/fleetdm/fleet/v4/server/service/middleware/auth",
"github.com/fleetdm/fleet/v4/server/service/middleware/authzcheck",
"github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils",
"github.com/fleetdm/fleet/v4/server/service/middleware/log",
"github.com/fleetdm/fleet/v4/server/service/middleware/ratelimit",
"github.com/fleetdm/fleet/v4/server/service/modules/activities",
m+"/server/datastore/mysql/common_mysql",
m+"/server/service/externalsvc", // dependency on Jira and Zendesk
m+"/server/service/middleware/auth",
m+"/server/service/middleware/authzcheck",
m+"/server/service/middleware/endpoint_utils",
m+"/server/service/middleware/log",
m+"/server/service/middleware/ratelimit",
m+"/server/service/modules/activities",
).
Check()
}
@ -42,7 +44,8 @@ func TestAllAndroidPackageDependencies(t *testing.T) {
// If coupling is necessary, it should be done in the main server/fleet or another package.
func TestAndroidPackageDependencies(t *testing.T) {
t.Parallel()
archtest.NewPackageTest(t, "github.com/fleetdm/fleet/v4/server/mdm/android").
archtest.NewPackageTest(t, m+"/server/mdm/android").
OnlyInclude(regexp.MustCompile(`^github\.com/fleetdm/`)).
ShouldNotDependOn("github.com/fleetdm/fleet/v4/...")
ShouldNotDependOn(m + "/...").
Check()
}

View file

@ -7,6 +7,7 @@ import (
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mdm/android"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
"github.com/fleetdm/fleet/v4/server/service/middleware/auth"
eu "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
"github.com/go-json-experiment/json"
@ -21,13 +22,14 @@ func encodeResponse(ctx context.Context, w http.ResponseWriter, response interfa
func(w http.ResponseWriter, response interface{}) error {
return json.MarshalWrite(w, response, jsontext.WithIndent(" "))
},
nil, // no domain-specific error encoder
)
}
func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
return eu.MakeDecoder(iface, func(body io.Reader, req any) error {
return json.UnmarshalRead(body, req)
}, nil, nil, nil)
}, nil, nil, nil, nil)
}
// handlerFunc is the handler function type for Android service endpoints.
@ -40,12 +42,12 @@ type endpointer struct {
svc android.Service
}
func (e *endpointer) CallHandlerFunc(f handlerFunc, ctx context.Context, request interface{},
svc interface{}) (fleet.Errorer, error) {
func (e *endpointer) CallHandlerFunc(f handlerFunc, ctx context.Context, request any,
svc any) (platform_http.Errorer, error) {
return f(ctx, request, svc.(android.Service)), nil
}
func (e *endpointer) Service() interface{} {
func (e *endpointer) Service() any {
return e.svc
}

View file

@ -207,13 +207,18 @@ func (m *mockService) NewActivity(ctx context.Context, user *fleet.User, details
}
func runServerForTests(t *testing.T, logger kitlog.Logger, fleetSvc fleet.Service, androidSvc android.Service) *httptest.Server {
// androidErrorEncoder wraps EncodeError with nil domain encoder for android tests
androidErrorEncoder := func(ctx context.Context, err error, w http.ResponseWriter) {
endpoint_utils.EncodeError(ctx, err, w, nil)
}
fleetAPIOptions := []kithttp.ServerOption{
kithttp.ServerBefore(
kithttp.PopulateRequestContext,
auth.SetRequestsContexts(fleetSvc),
),
kithttp.ServerErrorHandler(&endpoint_utils.ErrorHandler{Logger: logger}),
kithttp.ServerErrorEncoder(endpoint_utils.EncodeError),
kithttp.ServerErrorEncoder(androidErrorEncoder),
kithttp.ServerAfter(
kithttp.SetContentType("application/json; charset=utf-8"),
log.LogRequestEnd(logger),

View file

@ -0,0 +1,44 @@
package platform_test
import (
"regexp"
"testing"
"github.com/fleetdm/fleet/v4/server/archtest"
)
const m = archtest.ModuleName
// TestEndpointerPackageDependencies checks that endpointer package is not dependent on other Fleet domain packages
// to maintain decoupling and modularity.
func TestEndpointerPackageDependencies(t *testing.T) {
t.Parallel()
archtest.NewPackageTest(t, m+"/server/service/middleware/endpoint_utils").
OnlyInclude(regexp.MustCompile(`^github\.com/fleetdm/`)).
WithTests().
ShouldNotDependOn(m+"/...").
IgnoreDeps(
// Platform packages
m+"/server/platform...",
// Other infra packages
m+"/server/contexts/authz",
m+"/server/contexts/ctxerr",
m+"/server/contexts/license",
m+"/server/contexts/logging",
m+"/server/contexts/publicip",
m+"/server/service/middleware/authzcheck",
m+"/server/service/middleware/ratelimit",
).
Check()
}
// TestPlatformPackageDependencies checks that platform packages are NOT dependent on ANY other Fleet packages
// to maintain decoupling and modularity.
func TestPlatformPackageDependencies(t *testing.T) {
t.Parallel()
archtest.NewPackageTest(t, m+"/server/platform...").
OnlyInclude(regexp.MustCompile(`^github\.com/fleetdm/`)).
WithTests().
ShouldNotDependOn(m + "/...").
Check()
}

View file

@ -0,0 +1,357 @@
package http
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"reflect"
"regexp"
"strings"
"github.com/google/uuid"
)
// ErrWithInternal defines an interface for errors that have an internal message
// that should only be logged, not returned to the client.
type ErrWithInternal interface {
error
// Internal returns the error string that must only be logged internally,
// not returned to the client.
Internal() string
}
// ErrWithLogFields defines an interface for errors that have additional log fields.
type ErrWithLogFields interface {
error
// LogFields returns the additional log fields to add, which should come in
// key, value pairs (as used in go-kit log).
LogFields() []any
}
// ErrorUUIDer defines an interface for errors that have a UUID for tracking.
type ErrorUUIDer interface {
// UUID returns the error's UUID.
UUID() string
}
// ErrorWithUUID can be embedded in error types to implement ErrorUUIDer.
type ErrorWithUUID struct {
uuid string
}
var _ ErrorUUIDer = (*ErrorWithUUID)(nil)
// UUID implements the ErrorUUIDer interface.
func (e *ErrorWithUUID) UUID() string {
if e.uuid == "" {
u, err := uuid.NewRandom()
if err != nil {
panic(err)
}
e.uuid = u.String()
}
return e.uuid
}
// BadRequestError is the error returned when the request is invalid.
type BadRequestError struct {
Message string
InternalErr error
ErrorWithUUID
}
// Error returns the error message.
func (e *BadRequestError) Error() string {
return e.Message
}
// BadRequestError implements the interface required by the server/service package logic
// to determine the status code to return to the client.
func (e *BadRequestError) BadRequestError() []map[string]string {
return nil
}
// Internal implements the ErrWithInternal interface.
func (e BadRequestError) Internal() string {
if e.InternalErr != nil {
return e.InternalErr.Error()
}
return ""
}
// UserMessageError is an error that wraps another error with a user-friendly message.
type UserMessageError struct {
error
statusCode int
ErrorWithUUID
}
// NewUserMessageError creates a UserMessageError that will translate the
// error message of err to a user-friendly form. If statusCode is > 0, it
// will be used as the HTTP status code for the error, otherwise it defaults
// to http.StatusUnprocessableEntity (422).
func NewUserMessageError(err error, statusCode int) *UserMessageError {
if err == nil {
return nil
}
return &UserMessageError{
error: err,
statusCode: statusCode,
}
}
// StatusCode returns the HTTP status code for this error.
func (e UserMessageError) StatusCode() int {
if e.statusCode > 0 {
return e.statusCode
}
return http.StatusUnprocessableEntity
}
var rxJSONUnknownField = regexp.MustCompile(`^json: unknown field "(.+)"$`)
// IsJSONUnknownFieldError returns true if err is a JSON unknown field error.
// There is no exported type or value for this error, so we have to match the
// error message.
func IsJSONUnknownFieldError(err error) bool {
return rxJSONUnknownField.MatchString(err.Error())
}
// GetJSONUnknownField returns the unknown field name from a JSON unknown field error.
func GetJSONUnknownField(err error) *string {
errCause := Cause(err)
if IsJSONUnknownFieldError(errCause) {
substr := rxJSONUnknownField.FindStringSubmatch(errCause.Error())
return &substr[1]
}
return nil
}
// UserMessage implements the user-friendly translation of the error if its
// root cause is one of the supported types, otherwise it returns the error
// message.
func (e UserMessageError) UserMessage() string {
cause := Cause(e.error)
switch cause := cause.(type) {
case *json.UnmarshalTypeError:
var sb strings.Builder
curType := cause.Type
for curType.Kind() == reflect.Slice || curType.Kind() == reflect.Array {
sb.WriteString("array of ")
curType = curType.Elem()
}
sb.WriteString(curType.Name())
if curType != cause.Type {
// it was an array
sb.WriteString("s")
}
return fmt.Sprintf("invalid value type at '%s': expected %s but got %s", cause.Field, sb.String(), cause.Value)
default:
// there's no specific error type for the strict json mode
// (DisallowUnknownFields), so resort to message-matching.
if matches := rxJSONUnknownField.FindStringSubmatch(cause.Error()); matches != nil {
return fmt.Sprintf("unsupported key provided: %q", matches[1])
}
return e.Error()
}
}
// Cause returns the root error in err's chain.
func Cause(err error) error {
for {
uerr := errors.Unwrap(err)
if uerr == nil {
return err
}
err = uerr
}
}
// ErrWithRetryAfter is an interface for errors that should set a specific HTTP
// Header Retry-After value (see
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After)
type ErrWithRetryAfter interface {
error
// RetryAfter returns the number of seconds to wait before retry.
RetryAfter() int
}
// ForeignKeyError is an interface for errors caused by foreign key constraint violations.
type ForeignKeyError interface {
error
IsForeignKey() bool
}
// IsForeignKey returns true if err is a foreign key constraint violation.
func IsForeignKey(err error) bool {
var fke ForeignKeyError
if errors.As(err, &fke) {
return fke.IsForeignKey()
}
return false
}
// Error is a generic error type with a code and message.
type Error struct {
Code int `json:"code,omitempty"`
Message string `json:"message,omitempty"`
ErrorWithUUID
}
// Error returns the error message.
func (e *Error) Error() string {
return e.Message
}
// ErrWithIsClientError is an interface for errors that explicitly specify
// whether they are client errors or not. By default, errors are treated as
// server errors.
type ErrWithIsClientError interface {
error
IsClientError() bool
}
// AuthFailedError is returned when authentication fails.
type AuthFailedError struct {
// internal is the reason that should only be logged internally
internal string
ErrorWithUUID
}
// NewAuthFailedError creates a new AuthFailedError.
func NewAuthFailedError(internal string) *AuthFailedError {
return &AuthFailedError{internal: internal}
}
// Error implements the error interface.
func (e AuthFailedError) Error() string {
return "Authentication failed"
}
// Internal implements ErrWithInternal.
func (e AuthFailedError) Internal() string {
return e.internal
}
// StatusCode implements kithttp.StatusCoder.
func (e AuthFailedError) StatusCode() int {
return http.StatusUnauthorized
}
// AuthRequiredError is returned when authentication is required.
type AuthRequiredError struct {
// internal is the reason that should only be logged internally
internal string
ErrorWithUUID
}
// NewAuthRequiredError creates a new AuthRequiredError.
func NewAuthRequiredError(internal string) *AuthRequiredError {
return &AuthRequiredError{internal: internal}
}
// Error implements the error interface.
func (e AuthRequiredError) Error() string {
return "Authentication required"
}
// Internal implements ErrWithInternal.
func (e AuthRequiredError) Internal() string {
return e.internal
}
// StatusCode implements kithttp.StatusCoder.
func (e AuthRequiredError) StatusCode() int {
return http.StatusUnauthorized
}
// AuthHeaderRequiredError is returned when an authorization header is required.
type AuthHeaderRequiredError struct {
// internal is the reason that should only be logged internally
internal string
ErrorWithUUID
}
// NewAuthHeaderRequiredError creates a new AuthHeaderRequiredError.
func NewAuthHeaderRequiredError(internal string) *AuthHeaderRequiredError {
return &AuthHeaderRequiredError{
internal: internal,
}
}
// Error implements the error interface.
func (e AuthHeaderRequiredError) Error() string {
return "Authorization header required"
}
// Internal implements ErrWithInternal.
func (e AuthHeaderRequiredError) Internal() string {
return e.internal
}
// StatusCode implements kithttp.StatusCoder.
func (e AuthHeaderRequiredError) StatusCode() int {
return http.StatusUnauthorized
}
// ErrPasswordResetRequired is returned when a password reset is required.
var ErrPasswordResetRequired = &passwordResetRequiredError{}
type passwordResetRequiredError struct {
ErrorWithUUID
}
// Error implements the error interface.
func (e passwordResetRequiredError) Error() string {
return "password reset required"
}
// StatusCode implements kithttp.StatusCoder.
func (e passwordResetRequiredError) StatusCode() int {
return http.StatusUnauthorized
}
// ForbiddenErrorMessage is the error message that should be returned to
// clients when an action is forbidden. It is intentionally vague to prevent
// disclosing information that a client should not have access to.
const ForbiddenErrorMessage = "forbidden"
// CheckMissing is the error to return when no authorization check was performed
// by the service.
type CheckMissing struct {
response any
ErrorWithUUID
}
// CheckMissingWithResponse creates a new error indicating the authorization
// check was missed, and including the response for further analysis by the error
// encoder.
func CheckMissingWithResponse(response any) *CheckMissing {
return &CheckMissing{response: response}
}
// Error implements the error interface.
func (e *CheckMissing) Error() string {
return ForbiddenErrorMessage
}
// Internal implements the ErrWithInternal interface.
func (e *CheckMissing) Internal() string {
return "Missing authorization check"
}
// Response returns the response that was generated before the authorization
// check was found to be missing.
func (e *CheckMissing) Response() any {
return e.response
}

View file

@ -0,0 +1,7 @@
// Package http provides HTTP types for bounded contexts.
package http
// Errorer is implemented by response types that may contain errors.
type Errorer interface {
Error() error
}

View file

@ -108,7 +108,7 @@ func getAppConfigEndpoint(ctx context.Context, request interface{}, svc fleet.Se
if err != nil {
return nil, err
}
license, err := svc.License(ctx)
lic, err := svc.License(ctx)
if err != nil {
return nil, err
}
@ -183,7 +183,7 @@ func getAppConfigEndpoint(ctx context.Context, request interface{}, svc fleet.Se
transparencyURL := fleet.DefaultTransparencyURL
// Fleet Premium license is required for custom transparency url
if license.IsPremium() && appConfig.FleetDesktop.TransparencyURL != "" {
if lic.IsPremium() && appConfig.FleetDesktop.TransparencyURL != "" {
transparencyURL = appConfig.FleetDesktop.TransparencyURL
}
fleetDesktop := fleet.FleetDesktopSettings{TransparencyURL: transparencyURL}
@ -218,7 +218,7 @@ func getAppConfigEndpoint(ctx context.Context, request interface{}, svc fleet.Se
appConfigResponseFields: appConfigResponseFields{
UpdateInterval: updateIntervalConfig,
Vulnerabilities: vulnConfig,
License: license,
License: lic,
Logging: loggingConfig,
Email: emailConfig,
SandboxEnabled: svc.SandboxEnabled(),
@ -274,7 +274,8 @@ func modifyAppConfigEndpoint(ctx context.Context, request interface{}, svc fleet
}
// We do not use svc.License(ctx) to allow roles (like GitOps) write but not read access to AppConfig.
license, _ := license.FromContext(ctx)
licChecker, _ := license.FromContext(ctx)
lic, _ := licChecker.(*fleet.LicenseInfo)
loggingConfig, err := svc.LoggingConfig(ctx)
if err != nil {
@ -283,14 +284,14 @@ func modifyAppConfigEndpoint(ctx context.Context, request interface{}, svc fleet
response := appConfigResponse{
AppConfig: *appConfig,
appConfigResponseFields: appConfigResponseFields{
License: license,
License: lic,
Logging: loggingConfig,
},
}
response.Obfuscate()
if (!license.IsPremium()) || response.FleetDesktop.TransparencyURL == "" {
if lic == nil || (!lic.IsPremium()) || response.FleetDesktop.TransparencyURL == "" {
response.FleetDesktop.TransparencyURL = fleet.DefaultTransparencyURL
}
@ -319,7 +320,8 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
oldAppConfig := appConfig.Copy()
// We do not use svc.License(ctx) to allow roles (like GitOps) write but not read access to AppConfig.
license, _ := license.FromContext(ctx)
licChecker, _ := license.FromContext(ctx)
lic, _ := licChecker.(*fleet.LicenseInfo)
var oldSMTPSettings fleet.SMTPSettings
if appConfig.SMTPSettings != nil {
@ -359,7 +361,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
// default transparency URL is https://fleetdm.com/transparency so you are allowed to apply as long as it's not changing
if newAppConfig.FleetDesktop.TransparencyURL != "" && newAppConfig.FleetDesktop.TransparencyURL != fleet.DefaultTransparencyURL {
if !license.IsPremium() {
if !lic.IsPremium() {
invalid.Append("transparency_url", ErrMissingLicense.Error())
return nil, ctxerr.Wrap(ctx, invalid)
}
@ -370,7 +372,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
}
if newAppConfig.SSOSettings != nil {
validateSSOSettings(newAppConfig, appConfig, invalid, license)
validateSSOSettings(newAppConfig, appConfig, invalid, lic)
if invalid.HasErrors() {
return nil, ctxerr.Wrap(ctx, invalid)
}
@ -437,7 +439,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
appConfig.MDM.MacOSSetup.EnableReleaseDeviceManually = oldAppConfig.MDM.MacOSSetup.EnableReleaseDeviceManually
}
if appConfig.MDM.MacOSSetup.ManualAgentInstall.Valid && appConfig.MDM.MacOSSetup.ManualAgentInstall.Value {
if !license.IsPremium() {
if !lic.IsPremium() {
invalid.Append("macos_setup.manual_agent_install", ErrMissingLicense.Error())
return nil, ctxerr.Wrap(ctx, invalid)
}
@ -477,7 +479,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
if newAppConfig.AgentOptions != nil {
// if there were Agent Options in the new app config, then it replaced the
// agent options in the resulting app config, so validate those.
if err := fleet.ValidateJSONAgentOptions(ctx, svc.ds, *appConfig.AgentOptions, license.IsPremium(), 0); err != nil {
if err := fleet.ValidateJSONAgentOptions(ctx, svc.ds, *appConfig.AgentOptions, lic.IsPremium(), 0); err != nil {
err = fleet.SuggestAgentOptionsCorrection(err)
err = fleet.NewUserMessageError(err, http.StatusBadRequest)
if applyOpts.Force && !applyOpts.DryRun {
@ -490,7 +492,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
}
// If the license is Premium, we should always send usage statisics.
if !license.IsAllowDisableTelemetry() {
if !lic.IsAllowDisableTelemetry() {
appConfig.ServerSettings.EnableAnalytics = true
}
@ -532,7 +534,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
isNonEmpty(newAppConfig.ConditionalAccess.OktaAudienceURI) ||
isNonEmpty(newAppConfig.ConditionalAccess.OktaCertificate)
if oktaFieldsBeingSet && !license.IsPremium() {
if oktaFieldsBeingSet && !lic.IsPremium() {
invalid.Append("conditional_access", ErrMissingLicense.Error())
return nil, ctxerr.Wrap(ctx, invalid)
}
@ -626,11 +628,11 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
appConfig.Integrations.ConditionalAccessEnabled = newAppConfig.Integrations.ConditionalAccessEnabled
}
if err := svc.validateMDM(ctx, license, &oldAppConfig.MDM, &appConfig.MDM, invalid); err != nil {
if err := svc.validateMDM(ctx, lic, &oldAppConfig.MDM, &appConfig.MDM, invalid); err != nil {
return nil, ctxerr.Wrap(ctx, err, "validating MDM config")
}
abmAssignments, err := svc.validateABMAssignments(ctx, &newAppConfig.MDM, &oldAppConfig.MDM, invalid, license)
abmAssignments, err := svc.validateABMAssignments(ctx, &newAppConfig.MDM, &oldAppConfig.MDM, invalid, lic)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "validating ABM token assignments")
}
@ -638,7 +640,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
var vppAssignments map[uint][]uint
vppAssignmentsDefined := newAppConfig.MDM.VolumePurchasingProgram.Set && newAppConfig.MDM.VolumePurchasingProgram.Valid
if vppAssignmentsDefined {
vppAssignments, err = svc.validateVPPAssignments(ctx, newAppConfig.MDM.VolumePurchasingProgram.Value, invalid, license)
vppAssignments, err = svc.validateVPPAssignments(ctx, newAppConfig.MDM.VolumePurchasingProgram.Value, invalid, lic)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "validating VPP token assignments")
}
@ -738,7 +740,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
gitopsModeEnabled, gitopsRepoURL := appConfig.UIGitOpsMode.GitopsModeEnabled, appConfig.UIGitOpsMode.RepositoryURL
if gitopsModeEnabled {
if !license.IsPremium() {
if !lic.IsPremium() {
return nil, fleet.NewInvalidArgumentError("UI GitOpsMode: ", ErrMissingLicense.Error())
}
if gitopsRepoURL == "" {
@ -767,7 +769,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
}
if !license.IsPremium() {
if !lic.IsPremium() {
// reset transparency url to empty for downgraded licenses
appConfig.FleetDesktop.TransparencyURL = ""
}
@ -909,19 +911,19 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
//
// Process OS updates config changes for Apple devices.
//
if err := svc.processAppleOSUpdateSettings(ctx, license, fleet.MacOS,
if err := svc.processAppleOSUpdateSettings(ctx, lic, fleet.MacOS,
oldAppConfig.MDM.MacOSUpdates,
appConfig.MDM.MacOSUpdates,
); err != nil {
return nil, ctxerr.Wrap(ctx, err, "process macOS OS updates config change")
}
if err := svc.processAppleOSUpdateSettings(ctx, license, fleet.IOS,
if err := svc.processAppleOSUpdateSettings(ctx, lic, fleet.IOS,
oldAppConfig.MDM.IOSUpdates,
appConfig.MDM.IOSUpdates,
); err != nil {
return nil, ctxerr.Wrap(ctx, err, "process iOS OS updates config change")
}
if err := svc.processAppleOSUpdateSettings(ctx, license, fleet.IPadOS,
if err := svc.processAppleOSUpdateSettings(ctx, lic, fleet.IPadOS,
oldAppConfig.MDM.IPadOSUpdates,
appConfig.MDM.IPadOSUpdates,
); err != nil {
@ -1002,7 +1004,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
appConfig.MDM.EndUserAuthentication.SSOProviderSettings
serverURLChanged := oldAppConfig.ServerSettings.ServerURL != appConfig.ServerSettings.ServerURL
appleMDMUrlChanged := oldAppConfig.MDMUrl() != appConfig.MDMUrl()
if (mdmEnableEndUserAuthChanged || mdmSSOSettingsChanged || serverURLChanged || appleMDMUrlChanged) && license.IsPremium() {
if (mdmEnableEndUserAuthChanged || mdmSSOSettingsChanged || serverURLChanged || appleMDMUrlChanged) && lic.IsPremium() {
if err := svc.EnterpriseOverrides.MDMAppleSyncDEPProfiles(ctx); err != nil {
return nil, ctxerr.Wrap(ctx, err, "sync DEP profiles")
}
@ -1100,14 +1102,14 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
// processAppleOSUpdateSettings updates the OS updates configuration if the minimum version+deadline are updated.
func (svc *Service) processAppleOSUpdateSettings(
ctx context.Context,
license *fleet.LicenseInfo,
lic *fleet.LicenseInfo,
appleDevice fleet.AppleDevice,
oldOSUpdateSettings fleet.AppleOSUpdateSettings,
newOSUpdateSettings fleet.AppleOSUpdateSettings,
) error {
if oldOSUpdateSettings.MinimumVersion.Value != newOSUpdateSettings.MinimumVersion.Value ||
oldOSUpdateSettings.Deadline.Value != newOSUpdateSettings.Deadline.Value {
if license.IsPremium() {
if lic.IsPremium() {
if err := svc.EnterpriseOverrides.MDMAppleEditedAppleOSUpdates(ctx, nil, appleDevice, newOSUpdateSettings); err != nil {
return ctxerr.Wrap(ctx, err, "update DDM profile after Apple OS updates change")
}
@ -1174,33 +1176,33 @@ func (svc *Service) HasCustomSetupAssistantConfigurationWebURL(ctx context.Conte
func (svc *Service) validateMDM(
ctx context.Context,
license *fleet.LicenseInfo,
lic *fleet.LicenseInfo,
oldMdm *fleet.MDM,
mdm *fleet.MDM,
invalid *fleet.InvalidArgumentError,
) error {
if mdm.EnableDiskEncryption.Value && !license.IsPremium() {
if mdm.EnableDiskEncryption.Value && !lic.IsPremium() {
invalid.Append("macos_settings.enable_disk_encryption", ErrMissingLicense.Error())
}
if mdm.MacOSSetup.MacOSSetupAssistant.Value != "" && oldMdm.MacOSSetup.MacOSSetupAssistant.Value != mdm.MacOSSetup.MacOSSetupAssistant.Value && !license.IsPremium() {
if mdm.MacOSSetup.MacOSSetupAssistant.Value != "" && oldMdm.MacOSSetup.MacOSSetupAssistant.Value != mdm.MacOSSetup.MacOSSetupAssistant.Value && !lic.IsPremium() {
invalid.Append("macos_setup.macos_setup_assistant", ErrMissingLicense.Error())
}
if mdm.MacOSSetup.EnableReleaseDeviceManually.Value && oldMdm.MacOSSetup.EnableReleaseDeviceManually.Value != mdm.MacOSSetup.EnableReleaseDeviceManually.Value && !license.IsPremium() {
if mdm.MacOSSetup.EnableReleaseDeviceManually.Value && oldMdm.MacOSSetup.EnableReleaseDeviceManually.Value != mdm.MacOSSetup.EnableReleaseDeviceManually.Value && !lic.IsPremium() {
invalid.Append("macos_setup.enable_release_device_manually", ErrMissingLicense.Error())
}
if mdm.MacOSSetup.BootstrapPackage.Value != "" && oldMdm.MacOSSetup.BootstrapPackage.Value != mdm.MacOSSetup.BootstrapPackage.Value && !license.IsPremium() {
if mdm.MacOSSetup.BootstrapPackage.Value != "" && oldMdm.MacOSSetup.BootstrapPackage.Value != mdm.MacOSSetup.BootstrapPackage.Value && !lic.IsPremium() {
invalid.Append("macos_setup.bootstrap_package", ErrMissingLicense.Error())
}
if mdm.MacOSSetup.EnableEndUserAuthentication && oldMdm.MacOSSetup.EnableEndUserAuthentication != mdm.MacOSSetup.EnableEndUserAuthentication && !license.IsPremium() {
if mdm.MacOSSetup.EnableEndUserAuthentication && oldMdm.MacOSSetup.EnableEndUserAuthentication != mdm.MacOSSetup.EnableEndUserAuthentication && !lic.IsPremium() {
invalid.Append("macos_setup.enable_end_user_authentication", ErrMissingLicense.Error())
}
if mdm.MacOSSetup.ManualAgentInstall.Valid && oldMdm.MacOSSetup.ManualAgentInstall.Value != mdm.MacOSSetup.ManualAgentInstall.Value && !license.IsPremium() {
if mdm.MacOSSetup.ManualAgentInstall.Valid && oldMdm.MacOSSetup.ManualAgentInstall.Value != mdm.MacOSSetup.ManualAgentInstall.Value && !lic.IsPremium() {
invalid.Append("macos_setup.manual_agent_install", ErrMissingLicense.Error())
}
if mdm.WindowsMigrationEnabled && !license.IsPremium() {
if mdm.WindowsMigrationEnabled && !lic.IsPremium() {
invalid.Append("windows_migration_enabled", ErrMissingLicense.Error())
}
if mdm.EnableTurnOnWindowsMDMManually && !license.IsPremium() {
if mdm.EnableTurnOnWindowsMDMManually && !lic.IsPremium() {
invalid.Append("enable_turn_on_windows_mdm_manually", ErrMissingLicense.Error())
}
@ -1298,7 +1300,7 @@ func (svc *Service) validateMDM(
updatingIPadOSVersion || updatingIPadOSDeadline {
// TODO: Should we validate MDM configured on here too?
if !license.IsPremium() {
if !lic.IsPremium() {
invalid.Append("macos_updates.minimum_version", ErrMissingLicense.Error())
return nil
}
@ -1318,7 +1320,7 @@ func (svc *Service) validateMDM(
if updatingWindowsUpdates {
// TODO: Should we validate MDM configured on here too?
if !license.IsPremium() {
if !lic.IsPremium() {
invalid.Append("windows_updates.deadline_days", ErrMissingLicense.Error())
return nil
}
@ -1330,7 +1332,7 @@ func (svc *Service) validateMDM(
// EndUserAuthentication
// only validate SSO settings if they changed
if mdm.EndUserAuthentication.SSOProviderSettings != oldMdm.EndUserAuthentication.SSOProviderSettings {
if !license.IsPremium() {
if !lic.IsPremium() {
invalid.Append("end_user_authentication", ErrMissingLicense.Error())
return nil
}
@ -1366,7 +1368,7 @@ func (svc *Service) validateMDM(
// TODO: Should we validate MDM configured on here too?
if mdm.MacOSMigration.Enable {
if !license.IsPremium() {
if !lic.IsPremium() {
invalid.Append("macos_migration.enable", ErrMissingLicense.Error())
return nil
}
@ -1423,7 +1425,7 @@ func (svc *Service) validateABMAssignments(
ctx context.Context,
mdm, oldMdm *fleet.MDM,
invalid *fleet.InvalidArgumentError,
license *fleet.LicenseInfo,
lic *fleet.LicenseInfo,
) ([]*fleet.ABMToken, error) {
if mdm.DeprecatedAppleBMDefaultTeam != "" && mdm.AppleBusinessManager.Set && mdm.AppleBusinessManager.Valid {
invalid.Append("mdm.apple_bm_default_team", fleet.AppleABMDefaultTeamDeprecatedMessage)
@ -1431,7 +1433,7 @@ func (svc *Service) validateABMAssignments(
}
if name := mdm.DeprecatedAppleBMDefaultTeam; name != "" && name != oldMdm.DeprecatedAppleBMDefaultTeam {
if !license.IsPremium() {
if !lic.IsPremium() {
invalid.Append("mdm.apple_bm_default_team", ErrMissingLicense.Error())
return nil, nil
}
@ -1464,7 +1466,7 @@ func (svc *Service) validateABMAssignments(
}
if mdm.AppleBusinessManager.Set && len(mdm.AppleBusinessManager.Value) > 0 {
if !license.IsPremium() {
if !lic.IsPremium() {
invalid.Append("mdm.apple_business_manager", ErrMissingLicense.Error())
return nil, nil
}
@ -1525,14 +1527,14 @@ func (svc *Service) validateVPPAssignments(
ctx context.Context,
volumePurchasingProgramInfo []fleet.MDMAppleVolumePurchasingProgramInfo,
invalid *fleet.InvalidArgumentError,
license *fleet.LicenseInfo,
lic *fleet.LicenseInfo,
) (map[uint][]uint, error) {
// Allow clearing VPP assignments in free and premium.
if len(volumePurchasingProgramInfo) == 0 {
return nil, nil
}
if !license.IsPremium() {
if !lic.IsPremium() {
invalid.Append("mdm.volume_purchasing_program", ErrMissingLicense.Error())
return nil, nil
}
@ -1622,7 +1624,7 @@ func validateSSOProviderSettings(incoming, existing fleet.SSOProviderSettings, i
}
}
func validateSSOSettings(p fleet.AppConfig, existing *fleet.AppConfig, invalid *fleet.InvalidArgumentError, license *fleet.LicenseInfo) {
func validateSSOSettings(p fleet.AppConfig, existing *fleet.AppConfig, invalid *fleet.InvalidArgumentError, lic *fleet.LicenseInfo) {
if p.SSOSettings != nil && p.SSOSettings.EnableSSO {
var existingSSOProviderSettings fleet.SSOProviderSettings
@ -1631,7 +1633,7 @@ func validateSSOSettings(p fleet.AppConfig, existing *fleet.AppConfig, invalid *
}
validateSSOProviderSettings(p.SSOSettings.SSOProviderSettings, existingSSOProviderSettings, invalid)
if !license.IsPremium() {
if !lic.IsPremium() {
if p.SSOSettings.EnableJITProvisioning {
invalid.Append("enable_jit_provisioning", ErrMissingLicense.Error())
}

View file

@ -1793,7 +1793,7 @@ func (r mdmAppleEnrollResponse) HijackRender(ctx context.Context, w http.Respons
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
if err := json.NewEncoder(w).Encode(r.SoftwareUpdateRequired); err != nil {
endpoint_utils.EncodeError(ctx, ctxerr.New(ctx, "failed to encode software update required"), w)
encodeError(ctx, ctxerr.New(ctx, "failed to encode software update required"), w)
}
return
}

View file

@ -10,6 +10,7 @@ import (
"strings"
"github.com/fleetdm/fleet/v4/server/contexts/certserial"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/contexts/logging"
"github.com/fleetdm/fleet/v4/server/fleet"
middleware_log "github.com/fleetdm/fleet/v4/server/service/middleware/log"
@ -104,6 +105,10 @@ func authenticatedDevice(svc fleet.Service, logger log.Logger, next endpoint.End
}
ctx = hostctx.NewContext(ctx, host)
// Register host as error context provider for ctxerr enrichment
hostProvider := &hostctx.HostAttributeProvider{Host: host}
ctx = ctxerr.AddErrorContextProvider(ctx, hostProvider)
instrumentHostLogger(ctx, host.ID)
if ac, ok := authz_ctx.FromContext(ctx); ok {
ac.SetAuthnMethod(authnMethod)
@ -151,6 +156,10 @@ func authenticatedHost(svc fleet.Service, logger log.Logger, next endpoint.Endpo
}
ctx = hostctx.NewContext(ctx, host)
// Register host as error context provider for ctxerr enrichment
hostProvider := &hostctx.HostAttributeProvider{Host: host}
ctx = ctxerr.AddErrorContextProvider(ctx, hostProvider)
instrumentHostLogger(ctx, host.ID)
if ac, ok := authz_ctx.FromContext(ctx); ok {
ac.SetAuthnMethod(authz_ctx.AuthnHostToken)
@ -193,6 +202,10 @@ func authenticatedOrbitHost(
}
ctx = hostctx.NewContext(ctx, host)
// Register host as error context provider for ctxerr enrichment
hostProvider := &hostctx.HostAttributeProvider{Host: host}
ctx = ctxerr.AddErrorContextProvider(ctx, hostProvider)
instrumentHostLogger(ctx, host.ID)
if ac, ok := authz_ctx.FromContext(ctx); ok {
ac.SetAuthnMethod(authz_ctx.AuthnOrbitToken)

View file

@ -11,7 +11,6 @@ import (
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/service/middleware/auth"
"github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
kitlog "github.com/go-kit/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -188,7 +187,7 @@ func TestAuthenticatedHost(t *testing.T) {
r := &testNodeKeyRequest{NodeKey: tt.nodeKey}
_, err := endpoint(ctx, r)
if tt.shouldErr {
assert.IsType(t, &endpoint_utils.OsqueryError{}, err)
assert.IsType(t, &OsqueryError{}, err)
} else {
assert.Nil(t, err)
}

View file

@ -12,6 +12,7 @@ import (
"github.com/fleetdm/fleet/v4/server/contexts/capabilities"
"github.com/fleetdm/fleet/v4/server/fleet"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
"github.com/fleetdm/fleet/v4/server/service/middleware/auth"
eu "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
"github.com/go-kit/kit/endpoint"
@ -21,7 +22,31 @@ import (
)
func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
return eu.MakeDecoder(iface, jsonDecode, parseCustomTags, isBodyDecoder, decodeBody)
return eu.MakeDecoder(iface, jsonDecode, parseCustomTags, isBodyDecoder, decodeBody, fleetQueryDecoder)
}
// fleetQueryDecoder handles fleet-specific query parameter decoding, such as
// converting the order_direction string to the fleet.OrderDirection int type.
func fleetQueryDecoder(queryTagName, queryVal string, field reflect.Value) (bool, error) {
// Only handle int fields for order_direction
if field.Kind() != reflect.Int {
return false, nil
}
switch queryTagName {
case "order_direction", "inherited_order_direction":
var direction int
switch queryVal {
case "desc":
direction = int(fleet.OrderDescending)
case "asc":
direction = int(fleet.OrderAscending)
default:
return false, &fleet.BadRequestError{Message: "unknown order_direction: " + queryVal}
}
field.SetInt(int64(direction))
return true, nil
}
return false, nil
}
// A value that implements bodyDecoder takes control of decoding the request body.
@ -100,7 +125,7 @@ type endpointer struct {
func (e *endpointer) CallHandlerFunc(f handlerFunc, ctx context.Context, request interface{},
svc interface{},
) (fleet.Errorer, error) {
) (platform_http.Errorer, error) {
return f(ctx, request, svc.(fleet.Service))
}

View file

@ -290,7 +290,7 @@ func TestEndpointer(t *testing.T) {
auth.SetRequestsContexts(svc),
),
kithttp.ServerErrorHandler(&endpoint_utils.ErrorHandler{Logger: kitlog.NewNopLogger()}),
kithttp.ServerErrorEncoder(endpoint_utils.EncodeError),
kithttp.ServerErrorEncoder(fleetErrorEncoder),
kithttp.ServerAfter(
kithttp.SetContentType("application/json; charset=utf-8"),
log.LogRequestEnd(kitlog.NewNopLogger()),
@ -410,7 +410,7 @@ func TestEndpointerCustomMiddleware(t *testing.T) {
auth.SetRequestsContexts(svc),
),
kithttp.ServerErrorHandler(&endpoint_utils.ErrorHandler{Logger: kitlog.NewNopLogger()}),
kithttp.ServerErrorEncoder(endpoint_utils.EncodeError),
kithttp.ServerErrorEncoder(fleetErrorEncoder),
kithttp.ServerAfter(
kithttp.SetContentType("application/json; charset=utf-8"),
log.LogRequestEnd(kitlog.NewNopLogger()),

View file

@ -110,7 +110,7 @@ func MakeHandler(
auth.SetRequestsContexts(svc),
),
kithttp.ServerErrorHandler(&endpoint_utils.ErrorHandler{Logger: logger}),
kithttp.ServerErrorEncoder(endpoint_utils.EncodeError),
kithttp.ServerErrorEncoder(fleetErrorEncoder),
kithttp.ServerAfter(
kithttp.SetContentType("application/json; charset=utf-8"),
log.LogRequestEnd(logger),

View file

@ -30,7 +30,6 @@ import (
"github.com/fleetdm/fleet/v4/server/mdm/assets"
mdmlifecycle "github.com/fleetdm/fleet/v4/server/mdm/lifecycle"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
"github.com/fleetdm/fleet/v4/server/worker"
"github.com/go-kit/log/level"
"github.com/gocarina/gocsv"
@ -2319,7 +2318,7 @@ func (r hostsReportResponse) HijackRender(ctx context.Context, w http.ResponseWr
var buf bytes.Buffer
if err := gocsv.Marshal(r.Hosts, &buf); err != nil {
logging.WithErr(ctx, err)
endpoint_utils.EncodeError(ctx, ctxerr.New(ctx, "failed to generate CSV file"), w)
encodeError(ctx, ctxerr.New(ctx, "failed to generate CSV file"), w)
return
}
@ -2331,7 +2330,7 @@ func (r hostsReportResponse) HijackRender(ctx context.Context, w http.ResponseWr
recs, err := csv.NewReader(&buf).ReadAll()
if err != nil {
logging.WithErr(ctx, err)
endpoint_utils.EncodeError(ctx, ctxerr.New(ctx, "failed to generate CSV file"), w)
encodeError(ctx, ctxerr.New(ctx, "failed to generate CSV file"), w)
return
}
@ -2352,7 +2351,7 @@ func (r hostsReportResponse) HijackRender(ctx context.Context, w http.ResponseWr
// duplicating the list of columns from the Host's struct tags to a
// map and keep this in sync, for what is essentially a programmer
// mistake that should be caught and corrected early.
endpoint_utils.EncodeError(ctx, &fleet.BadRequestError{Message: fmt.Sprintf("invalid column name: %q", col)}, w)
encodeError(ctx, &fleet.BadRequestError{Message: fmt.Sprintf("invalid column name: %q", col)}, w)
return
}
outRows[i] = append(outRows[i], rec[colIx])

View file

@ -9,6 +9,8 @@ import (
"github.com/fleetdm/fleet/v4/ee/server/service/hostidentity/httpsig"
"github.com/fleetdm/fleet/v4/server/contexts/authz"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/contexts/logging"
"github.com/fleetdm/fleet/v4/server/contexts/token"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/fleet"
@ -76,6 +78,10 @@ func AuthenticatedUser(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpo
}
ctx = viewer.NewContext(ctx, *v)
// Register viewer as error context provider for ctxerr enrichment
ctx = ctxerr.AddErrorContextProvider(ctx, v)
// Register viewer as user emailer for logging
ctx = logging.WithUserEmailer(ctx, v)
if ac, ok := authz.FromContext(ctx); ok {
ac.SetAuthnMethod(authz.AuthnUserToken)
}
@ -123,6 +129,10 @@ func AuthenticatedUserMiddleware(svc fleet.Service, errHandler errorHandler, nex
}
ctx := viewer.NewContext(r.Context(), *v)
// Register viewer as error context provider for ctxerr enrichment
ctx = ctxerr.AddErrorContextProvider(ctx, v)
// Register viewer as user emailer for logging
ctx = logging.WithUserEmailer(ctx, v)
if ac, ok := authz.FromContext(r.Context()); ok {
ac.SetAuthnMethod(authz.AuthnUserToken)
}

View file

@ -4,6 +4,7 @@ import (
"context"
"net/http"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/contexts/logging"
"github.com/fleetdm/fleet/v4/server/contexts/token"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
@ -20,6 +21,10 @@ func SetRequestsContexts(svc fleet.Service) kithttp.RequestFunc {
v, err := AuthViewer(ctx, string(bearer), svc)
if err == nil {
ctx = viewer.NewContext(ctx, *v)
// Register viewer as error context provider for ctxerr enrichment
ctx = ctxerr.AddErrorContextProvider(ctx, v)
// Register viewer as user emailer for logging
ctx = logging.WithUserEmailer(ctx, v)
}
}

View file

@ -8,9 +8,8 @@ import (
"context"
"errors"
"github.com/fleetdm/fleet/v4/server/authz"
authz_ctx "github.com/fleetdm/fleet/v4/server/contexts/authz"
"github.com/fleetdm/fleet/v4/server/fleet"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
"github.com/go-kit/kit/endpoint"
)
@ -32,13 +31,13 @@ func (m *Middleware) AuthzCheck() endpoint.Middleware {
// If authentication check failed, return that error (so that we log
// appropriately).
var authFailedError *fleet.AuthFailedError
var authRequiredError *fleet.AuthRequiredError
var authHeaderRequiredError *fleet.AuthHeaderRequiredError
var authFailedError *platform_http.AuthFailedError
var authRequiredError *platform_http.AuthRequiredError
var authHeaderRequiredError *platform_http.AuthHeaderRequiredError
if errors.As(err, &authFailedError) ||
errors.As(err, &authRequiredError) ||
errors.As(err, &authHeaderRequiredError) ||
errors.Is(err, fleet.ErrPasswordResetRequired) {
errors.Is(err, platform_http.ErrPasswordResetRequired) {
return nil, err
}
@ -52,7 +51,7 @@ func (m *Middleware) AuthzCheck() endpoint.Middleware {
// marshal to a generic error and log that the check was missed.
if !authzctx.Checked() {
// Getting to here means there is an authorization-related bug in our code.
return nil, authz.CheckMissingWithResponse(response)
return nil, platform_http.CheckMissingWithResponse(response)
}
return response, err

View file

@ -5,7 +5,7 @@ import (
"testing"
"github.com/fleetdm/fleet/v4/server/contexts/authz"
"github.com/fleetdm/fleet/v4/server/fleet"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -33,7 +33,7 @@ func TestAuthzCheckAuthFailed(t *testing.T) {
checker := NewMiddleware()
check := func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, fleet.NewAuthFailedError("failed")
return nil, platform_http.NewAuthFailedError("failed")
}
check = checker.AuthzCheck()(check)
@ -48,7 +48,7 @@ func TestAuthzCheckAuthRequired(t *testing.T) {
checker := NewMiddleware()
check := func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, fleet.NewAuthRequiredError("required")
return nil, platform_http.NewAuthRequiredError("required")
}
check = checker.AuthzCheck()(check)

View file

@ -18,7 +18,7 @@ import (
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/contexts/license"
"github.com/fleetdm/fleet/v4/server/contexts/logging"
"github.com/fleetdm/fleet/v4/server/fleet"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
"github.com/fleetdm/fleet/v4/server/service/middleware/authzcheck"
"github.com/fleetdm/fleet/v4/server/service/middleware/ratelimit"
"github.com/go-kit/kit/endpoint"
@ -86,7 +86,7 @@ func BadRequestErr(publicMsg string, internalErr error) error {
if errors.As(internalErr, &opErr) {
return fmt.Errorf(publicMsg+", internal: %w", internalErr)
}
return &fleet.BadRequestError{
return &platform_http.BadRequestError{
Message: publicMsg,
InternalErr: internalErr,
}
@ -169,7 +169,11 @@ func DecodeURLTagValue(r *http.Request, field reflect.Value, urlTagValue string,
return nil
}
func DecodeQueryTagValue(r *http.Request, fp fieldPair) error {
// DomainQueryFieldDecoder decodes a query parameter value into the target field.
// It returns true if it handled the field, false if default handling should be used.
type DomainQueryFieldDecoder func(queryTagName, queryVal string, field reflect.Value) (handled bool, err error)
func DecodeQueryTagValue(r *http.Request, fp fieldPair, customDecoder DomainQueryFieldDecoder) error {
queryTagValue, ok := fp.Sf.Tag.Lookup("query")
if ok {
@ -185,7 +189,7 @@ func DecodeQueryTagValue(r *http.Request, fp fieldPair) error {
if optional {
return nil
}
return &fleet.BadRequestError{Message: fmt.Sprintf("Param %s is required", queryTagValue)}
return &platform_http.BadRequestError{Message: fmt.Sprintf("Param %s is required", queryTagValue)}
}
field := fp.V
if field.Kind() == reflect.Ptr {
@ -193,6 +197,18 @@ func DecodeQueryTagValue(r *http.Request, fp fieldPair) error {
field.Set(reflect.New(field.Type().Elem()))
field = field.Elem()
}
// Try custom decoder first if provided
if customDecoder != nil {
handled, err := customDecoder(queryTagValue, queryVal, field)
if err != nil {
return err
}
if handled {
return nil
}
}
switch field.Kind() {
case reflect.String:
field.SetString(queryVal)
@ -211,22 +227,9 @@ func DecodeQueryTagValue(r *http.Request, fp fieldPair) error {
case reflect.Bool:
field.SetBool(queryVal == "1" || queryVal == "true")
case reflect.Int:
queryValInt := 0
switch queryTagValue {
case "order_direction", "inherited_order_direction":
switch queryVal {
case "desc":
queryValInt = int(fleet.OrderDescending)
case "asc":
queryValInt = int(fleet.OrderAscending)
default:
return &fleet.BadRequestError{Message: "unknown order_direction: " + queryVal}
}
default:
queryValInt, err = strconv.Atoi(queryVal)
if err != nil {
return BadRequestErr("parsing int from query", err)
}
queryValInt, err := strconv.Atoi(queryVal)
if err != nil {
return BadRequestErr("parsing int from query", err)
}
field.SetInt(int64(queryValInt))
default:
@ -297,17 +300,17 @@ func (h *ErrorHandler) Handle(ctx context.Context, err error) {
logger = log.With(logger, "took", time.Since(startTime))
}
var ewi fleet.ErrWithInternal
var ewi platform_http.ErrWithInternal
if errors.As(err, &ewi) {
logger = log.With(logger, "internal", ewi.Internal())
}
var ewlf fleet.ErrWithLogFields
var ewlf platform_http.ErrWithLogFields
if errors.As(err, &ewlf) {
logger = log.With(logger, ewlf.LogFields()...)
}
var uuider fleet.ErrorUUIDer
var uuider platform_http.ErrorUUIDer
if errors.As(err, &uuider) {
logger = log.With(logger, "uuid", uuider.UUID())
}
@ -338,17 +341,14 @@ type requestValidator interface {
}
// MakeDecoder creates a decoder for the type for the struct passed on. If the
// struct has at least 1 json tag it'll unmarshall the body. If the struct has
// a `url` tag with value list_options it'll gather fleet.ListOptions from the
// URL (similarly for host_options, carve_options, user_options that derive
// from the common list_options). Note that these behaviors do not work for embedded structs.
// struct has at least 1 json tag it'll unmarshall the body. Custom `url` tag
// values can be handled by providing a parseCustomTags function. Note that
// these behaviors do not work for embedded structs.
//
// Finally, any other `url` tag will be treated as a path variable (of the form
// Any other `url` tag will be treated as a path variable (of the form
// /path/{name} in the route's path) from the URL path pattern, and it'll be
// decoded and set accordingly. Variables can be optional by setting the tag as
// follows: `url:"some-id,optional"`.
// The "list_options" are optional by default and it'll ignore the optional
// portion of the tag.
//
// If iface implements the RequestDecoder interface, it returns a function that
// calls iface.DecodeRequest(ctx, r) - i.e. the value itself fully controls its
@ -357,12 +357,16 @@ type requestValidator interface {
// If iface implements the bodyDecoder interface, it calls iface.DecodeBody
// after having decoded any non-body fields (such as url and query parameters)
// into the struct.
//
// The customQueryDecoder parameter allows services to inject domain-specific
// query parameter decoding logic.
func MakeDecoder(
iface interface{},
jsonUnmarshal func(body io.Reader, req any) error,
parseCustomTags func(urlTagValue string, r *http.Request, field reflect.Value) (bool, error),
isBodyDecoder func(reflect.Value) bool,
decodeBody func(ctx context.Context, r *http.Request, v reflect.Value, body io.Reader) error,
customQueryDecoder DomainQueryFieldDecoder,
) kithttp.DecodeRequestFunc {
if iface == nil {
return func(ctx context.Context, r *http.Request) (interface{}, error) {
@ -441,16 +445,16 @@ func MakeDecoder(
_, jsonExpected := fp.Sf.Tag.Lookup("json")
if jsonExpected && nilBody {
return nil, &fleet.BadRequestError{Message: "Expected JSON Body"}
return nil, &platform_http.BadRequestError{Message: "Expected JSON Body"}
}
isContentJson := r.Header.Get("Content-Type") == "application/json"
isCrossSite := r.Header.Get("Origin") != "" || r.Header.Get("Referer") != ""
if jsonExpected && isCrossSite && !isContentJson {
return nil, fleet.NewUserMessageError(errors.New("Expected Content-Type \"application/json\""), http.StatusUnsupportedMediaType)
return nil, platform_http.NewUserMessageError(errors.New("Expected Content-Type \"application/json\""), http.StatusUnsupportedMediaType)
}
err = DecodeQueryTagValue(r, fp)
err = DecodeQueryTagValue(r, fp, customQueryDecoder)
if err != nil {
return nil, err
}
@ -471,7 +475,7 @@ func MakeDecoder(
return nil, err
}
if val && !fp.V.IsZero() {
return nil, &fleet.BadRequestError{Message: fmt.Sprintf(
return nil, &platform_http.BadRequestError{Message: fmt.Sprintf(
"option %s requires a premium license",
fp.Sf.Name,
)}
@ -509,17 +513,19 @@ func WriteBrowserSecurityHeaders(w http.ResponseWriter) {
}
type CommonEndpointer[H any] struct {
EP Endpointer[H]
MakeDecoderFn func(iface interface{}) kithttp.DecodeRequestFunc
EncodeFn kithttp.EncodeResponseFunc
Opts []kithttp.ServerOption
AuthMiddleware endpoint.Middleware
Router *mux.Router
Versions []string
EP Endpointer[H]
MakeDecoderFn func(iface any) kithttp.DecodeRequestFunc
EncodeFn kithttp.EncodeResponseFunc
Opts []kithttp.ServerOption
Router *mux.Router
Versions []string
// CustomMiddleware are middlewares that run before AuthMiddleware.
// AuthMiddleware is a pre-built authentication middleware.
AuthMiddleware endpoint.Middleware
// CustomMiddleware are middlewares that run before authentication.
CustomMiddleware []endpoint.Middleware
// CustomMiddlewareAfterAuth are middlewares that run after AuthMiddleware.
// CustomMiddlewareAfterAuth are middlewares that run after authentication.
CustomMiddlewareAfterAuth []endpoint.Middleware
startingAtVersion string
@ -529,8 +535,8 @@ type CommonEndpointer[H any] struct {
}
type Endpointer[H any] interface {
CallHandlerFunc(f H, ctx context.Context, request interface{}, svc interface{}) (fleet.Errorer, error)
Service() interface{}
CallHandlerFunc(f H, ctx context.Context, request any, svc any) (platform_http.Errorer, error)
Service() any
}
func (e *CommonEndpointer[H]) POST(path string, f H, v interface{}) {
@ -730,6 +736,7 @@ func EncodeCommonResponse(
w http.ResponseWriter,
response interface{},
jsonMarshal func(w http.ResponseWriter, response interface{}) error,
domainErrorEncoder DomainErrorEncoder,
) error {
if cs, ok := response.(cookieSetter); ok {
cs.SetCookies(ctx, w)
@ -747,8 +754,8 @@ func EncodeCommonResponse(
return err
}
if e, ok := response.(fleet.Errorer); ok && e.Error() != nil {
EncodeError(ctx, e.Error(), w)
if e, ok := response.(platform_http.Errorer); ok && e.Error() != nil {
EncodeError(ctx, e.Error(), w, domainErrorEncoder)
return nil
}

View file

@ -8,7 +8,7 @@ import (
"testing"
authz_ctx "github.com/fleetdm/fleet/v4/server/contexts/authz"
"github.com/fleetdm/fleet/v4/server/fleet"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
"github.com/go-kit/kit/endpoint"
kithttp "github.com/go-kit/kit/transport/http"
"github.com/gorilla/mux"
@ -16,7 +16,7 @@ import (
)
// testHandlerFunc is a handler function type used for testing.
type testHandlerFunc func(ctx context.Context, request any, svc any) (fleet.Errorer, error)
type testHandlerFunc func(ctx context.Context, request any) (platform_http.Errorer, error)
func TestCustomMiddlewareAfterAuth(t *testing.T) {
var (
@ -82,7 +82,7 @@ func TestCustomMiddlewareAfterAuth(t *testing.T) {
},
Router: r,
}
ce.handleEndpoint("/", func(ctx context.Context, request interface{}, svc any) (fleet.Errorer, error) {
ce.handleEndpoint("/", func(ctx context.Context, request any) (platform_http.Errorer, error) {
fmt.Printf("handler\n")
return nopResponse{}, nil
}, nil, "GET")
@ -116,10 +116,10 @@ func (n nopResponse) Error() error {
type nopEP struct{}
func (n nopEP) CallHandlerFunc(_ testHandlerFunc, _ context.Context, _ any, _ any) (fleet.Errorer, error) {
return nopResponse{}, nil
func (n nopEP) CallHandlerFunc(f testHandlerFunc, ctx context.Context, request any, svc any) (platform_http.Errorer, error) {
return f(ctx, request)
}
func (n nopEP) Service() interface{} {
func (n nopEP) Service() any {
return nil
}

View file

@ -4,13 +4,12 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"strconv"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
kithttp "github.com/go-kit/kit/transport/http"
"github.com/go-sql-driver/mysql"
)
@ -18,6 +17,11 @@ import (
// ErrBadRoute is used for mux errors
var ErrBadRoute = errors.New("bad route")
// DomainErrorEncoder handles domain-specific error encoding.
// It returns true if it handled the error, false if default handling should be used.
// The encoder should write the appropriate status code and response body.
type DomainErrorEncoder func(ctx context.Context, err error, w http.ResponseWriter, enc *json.Encoder, jsonErr *JsonError) (handled bool)
type JsonError struct {
Message string `json:"message"`
Code int `json:"code,omitempty"`
@ -67,8 +71,10 @@ type conflictErrorInterface interface {
IsConflict() bool
}
// EncodeError encodes error and status header to the client
func EncodeError(ctx context.Context, err error, w http.ResponseWriter) {
// EncodeError encodes error and status header to the client.
// The domainEncoder parameter allows services to inject domain-specific error
// handling. If nil, only generic error handling is performed.
func EncodeError(ctx context.Context, err error, w http.ResponseWriter, domainEncoder DomainErrorEncoder) {
ctxerr.Handle(ctx, err)
origErr := err
@ -78,7 +84,7 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) {
err = ctxerr.Cause(err)
var uuid string
if uuidErr, ok := err.(fleet.ErrorUUIDer); ok {
if uuidErr, ok := err.(platform_http.ErrorUUIDer); ok {
uuid = uuidErr.UUID()
}
@ -86,6 +92,13 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) {
UUID: uuid,
}
// Try domain-specific error encoder first
if domainEncoder != nil {
if handled := domainEncoder(ctx, err, w, enc, &jsonErr); handled {
return
}
}
switch e := err.(type) {
case validationErrorInterface:
if statusErr, ok := e.(interface{ Status() int }); ok {
@ -99,36 +112,6 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) {
jsonErr.Message = "Permission Denied"
jsonErr.Errors = e.PermissionError()
w.WriteHeader(http.StatusForbidden)
case MailError:
jsonErr.Message = "Mail Error"
jsonErr.Errors = e.MailError()
w.WriteHeader(http.StatusInternalServerError)
case *OsqueryError:
// osquery expects to receive the node_invalid key when a TLS
// request provides an invalid node_key for authentication. It
// doesn't use the error message provided, but we provide this
// for debugging purposes (and perhaps osquery will use this
// error message in the future).
errMap := map[string]interface{}{
"error": e.Error(),
"uuid": uuid,
}
if e.NodeInvalid() { //nolint:gocritic // ignore ifElseChain
w.WriteHeader(http.StatusUnauthorized)
errMap["node_invalid"] = true
} else if e.Status() != 0 {
w.WriteHeader(e.Status())
} else {
// TODO: osqueryError is not always the result of an internal error on
// our side, it is also used to represent a client error (invalid data,
// e.g. malformed json, carve too large, etc., so 4xx), are we returning
// a 500 because of some osquery-specific requirement?
w.WriteHeader(http.StatusInternalServerError)
}
enc.Encode(errMap) //nolint:errcheck
return
case NotFoundErrorInterface:
jsonErr.Message = "Resource Not Found"
jsonErr.Errors = baseError(e.Error())
@ -153,7 +136,7 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) {
statusCode = http.StatusConflict
}
w.WriteHeader(statusCode)
case *fleet.Error:
case *platform_http.Error:
jsonErr.Message = e.Error()
jsonErr.Code = e.Code
w.WriteHeader(http.StatusUnprocessableEntity)
@ -168,7 +151,7 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) {
enc.Encode(jsonErr) //nolint:errcheck
return
}
if fleet.IsForeignKey(err) {
if platform_http.IsForeignKey(err) {
jsonErr.Message = "Validation Failed"
jsonErr.Errors = baseError(err.Error())
w.WriteHeader(http.StatusUnprocessableEntity)
@ -186,14 +169,14 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) {
// See header documentation
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After)
var ewra fleet.ErrWithRetryAfter
var ewra platform_http.ErrWithRetryAfter
if errors.As(err, &ewra) {
w.Header().Add("Retry-After", strconv.Itoa(ewra.RetryAfter()))
}
msg := err.Error()
reason := err.Error()
var ume *fleet.UserMessageError
var ume *platform_http.UserMessageError
if errors.As(err, &ume) {
if text := http.StatusText(status); text != "" {
msg = text
@ -208,53 +191,3 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) {
enc.Encode(jsonErr) //nolint:errcheck
}
// MailError is set when an error performing mail operations
type MailError struct {
Message string
}
func (e MailError) Error() string {
return fmt.Sprintf("a mail error occurred: %s", e.Message)
}
func (e MailError) MailError() []map[string]string {
return []map[string]string{
{
"name": "base",
"reason": e.Message,
},
}
}
// OsqueryError is the error returned to osquery agents.
type OsqueryError struct {
message string
nodeInvalid bool
StatusCode int
fleet.ErrorWithUUID
}
var _ fleet.ErrorUUIDer = (*OsqueryError)(nil)
// Error implements the error interface.
func (e *OsqueryError) Error() string {
return e.message
}
// NodeInvalid returns whether the error returned to osquery
// should contain the node_invalid property.
func (e *OsqueryError) NodeInvalid() bool {
return e.nodeInvalid
}
func (e *OsqueryError) Status() int {
return e.StatusCode
}
func NewOsqueryError(message string, nodeInvalid bool) *OsqueryError {
return &OsqueryError{
message: message,
nodeInvalid: nodeInvalid,
}
}

View file

@ -6,7 +6,7 @@ import (
"net/http/httptest"
"testing"
"github.com/fleetdm/fleet/v4/server/fleet"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
"github.com/stretchr/testify/assert"
)
@ -25,7 +25,7 @@ type newAndExciting struct{}
func (newAndExciting) Error() string { return "" }
type notFoundError struct {
fleet.ErrorWithUUID
platform_http.ErrorWithUUID
}
func (e *notFoundError) Error() string {
@ -36,6 +36,32 @@ func (e *notFoundError) IsNotFound() bool {
return true
}
// validationError is a test implementation of validationErrorInterface.
type validationError struct {
errors []map[string]string
}
func (e validationError) Error() string {
return "validation failed"
}
func (e validationError) Invalid() []map[string]string {
return e.errors
}
// permissionError is a test implementation of permissionErrorInterface.
type permissionError struct {
message string
}
func (e permissionError) Error() string {
return e.message
}
func (e permissionError) PermissionError() []map[string]string {
return nil
}
func TestHandlesErrorsCode(t *testing.T) {
errorTests := []struct {
name string
@ -44,12 +70,12 @@ func TestHandlesErrorsCode(t *testing.T) {
}{
{
"validation",
fleet.NewInvalidArgumentError("a", "b"),
validationError{errors: []map[string]string{{"name": "a", "reason": "b"}}},
http.StatusUnprocessableEntity,
},
{
"permission",
fleet.NewPermissionError("a"),
permissionError{message: "a"},
http.StatusForbidden,
},
{
@ -57,21 +83,6 @@ func TestHandlesErrorsCode(t *testing.T) {
foreignKeyError{},
http.StatusUnprocessableEntity,
},
{
"mail error",
MailError{},
http.StatusInternalServerError,
},
{
"osquery error - invalid node",
&OsqueryError{nodeInvalid: true},
http.StatusUnauthorized,
},
{
"osquery error - valid node",
&OsqueryError{},
http.StatusInternalServerError,
},
{
"data not found",
&notFoundError{},
@ -84,7 +95,7 @@ func TestHandlesErrorsCode(t *testing.T) {
},
{
"status coder",
fleet.NewAuthFailedError(""),
platform_http.NewAuthFailedError(""),
http.StatusUnauthorized,
},
{
@ -97,7 +108,7 @@ func TestHandlesErrorsCode(t *testing.T) {
for _, tt := range errorTests {
t.Run(tt.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
EncodeError(context.Background(), tt.err, recorder)
EncodeError(context.Background(), tt.err, recorder, nil)
assert.Equal(t, recorder.Code, tt.code)
})
}

View file

@ -21,7 +21,6 @@ import (
microsoft_mdm "github.com/fleetdm/fleet/v4/server/mdm/microsoft"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/service/contract"
"github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
"github.com/fleetdm/fleet/v4/server/worker"
"github.com/go-kit/log/level"
)
@ -65,7 +64,7 @@ func (r EnrollOrbitResponse) HijackRender(ctx context.Context, w http.ResponseWr
enc.SetIndent("", " ")
if err := enc.Encode(r); err != nil {
endpoint_utils.EncodeError(ctx, newOsqueryError(fmt.Sprintf("orbit enroll failed: %s", err)), w)
encodeError(ctx, newOsqueryError(fmt.Sprintf("orbit enroll failed: %s", err)), w)
}
}

View file

@ -25,7 +25,6 @@ import (
"github.com/fleetdm/fleet/v4/server/pubsub"
"github.com/fleetdm/fleet/v4/server/service/conditional_access_microsoft_proxy"
"github.com/fleetdm/fleet/v4/server/service/contract"
"github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
"github.com/fleetdm/fleet/v4/server/service/osquery_utils"
kithttp "github.com/go-kit/kit/transport/http"
"github.com/go-kit/log"
@ -34,12 +33,12 @@ import (
"golang.org/x/exp/slices"
)
func newOsqueryErrorWithInvalidNode(msg string) *endpoint_utils.OsqueryError {
return endpoint_utils.NewOsqueryError(msg, true)
func newOsqueryErrorWithInvalidNode(msg string) *OsqueryError {
return NewOsqueryError(msg, true)
}
func newOsqueryError(msg string) *endpoint_utils.OsqueryError {
return endpoint_utils.NewOsqueryError(msg, false)
func newOsqueryError(msg string) *OsqueryError {
return NewOsqueryError(msg, false)
}
func (svc *Service) AuthenticateHost(ctx context.Context, nodeKey string) (*fleet.Host, bool, error) {
@ -138,7 +137,7 @@ func (svc *Service) EnrollOsquery(ctx context.Context, enrollSecret, hostIdentif
if !canEnroll {
deviceCount := "unknown"
if lic, _ := license.FromContext(ctx); lic != nil {
deviceCount = strconv.Itoa(lic.DeviceCount)
deviceCount = strconv.Itoa(lic.GetDeviceCount())
}
return "", newOsqueryErrorWithInvalidNode(fmt.Sprintf("enroll host failed: maximum number of hosts reached: %s", deviceCount))
}

View file

@ -33,7 +33,6 @@ import (
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/pubsub"
"github.com/fleetdm/fleet/v4/server/service/async"
"github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
"github.com/fleetdm/fleet/v4/server/service/osquery_utils"
"github.com/fleetdm/fleet/v4/server/service/redis_policy_set"
"github.com/go-kit/log"
@ -992,7 +991,7 @@ func TestSubmitResultLogsFail(t *testing.T) {
// Expect an error when unable to write to logging destination.
err = svc.SubmitResultLogs(ctx, results)
require.Error(t, err)
assert.Equal(t, http.StatusRequestEntityTooLarge, err.(*endpoint_utils.OsqueryError).Status())
assert.Equal(t, http.StatusRequestEntityTooLarge, err.(*OsqueryError).Status())
}
func TestGetQueryNameAndTeamIDFromResult(t *testing.T) {
@ -2908,7 +2907,7 @@ func TestAuthenticationErrors(t *testing.T) {
_, _, err := svc.AuthenticateHost(ctx, "")
require.Error(t, err)
require.True(t, err.(*endpoint_utils.OsqueryError).NodeInvalid())
require.True(t, err.(*OsqueryError).NodeInvalid())
ms.LoadHostByNodeKeyFunc = func(ctx context.Context, nodeKey string) (*fleet.Host, error) {
return &fleet.Host{ID: 1, HasHostIdentityCert: ptr.Bool(false)}, nil
@ -2929,7 +2928,7 @@ func TestAuthenticationErrors(t *testing.T) {
_, _, err = svc.AuthenticateHost(ctx, "foo")
require.Error(t, err)
require.True(t, err.(*endpoint_utils.OsqueryError).NodeInvalid())
require.True(t, err.(*OsqueryError).NodeInvalid())
// return other error
ms.LoadHostByNodeKeyFunc = func(ctx context.Context, nodeKey string) (*fleet.Host, error) {
@ -2938,7 +2937,7 @@ func TestAuthenticationErrors(t *testing.T) {
_, _, err = svc.AuthenticateHost(ctx, "foo")
require.NotNil(t, err)
require.False(t, err.(*endpoint_utils.OsqueryError).NodeInvalid())
require.False(t, err.(*OsqueryError).NodeInvalid())
}
func TestGetHostIdentifier(t *testing.T) {

View file

@ -12,7 +12,6 @@ import (
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mail"
"github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
)
func (svc *Service) NewAppConfig(ctx context.Context, p fleet.AppConfig) (*fleet.AppConfig, error) {
@ -65,7 +64,7 @@ func (svc *Service) sendTestEmail(ctx context.Context, config *fleet.AppConfig)
}
if err := mail.Test(svc.mailService, testMail); err != nil {
return endpoint_utils.MailError{Message: err.Error()}
return MailError{Message: err.Error()}
}
return nil
}
@ -83,12 +82,14 @@ func (svc *Service) License(ctx context.Context) (*fleet.LicenseInfo, error) {
}
}
lic, _ := license.FromContext(ctx)
licChecker, _ := license.FromContext(ctx)
// Type assert to get the concrete type for modification and return
lic, _ := licChecker.(*fleet.LicenseInfo)
// Currently we use the presence of Microsoft Compliance Partner settings
// (only configured in cloud instances) to determine if a Fleet instance
// is a cloud managed instance.
if svc.config.MicrosoftCompliancePartner.IsSet() {
if lic != nil && svc.config.MicrosoftCompliancePartner.IsSet() {
lic.ManagedCloud = true
}

View file

@ -31,14 +31,15 @@ func (svc *Service) CreateInitialUser(ctx context.Context, p fleet.UserPayload)
}
func (svc *Service) NewUser(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) {
license, _ := license.FromContext(ctx)
if license == nil {
licChecker, _ := license.FromContext(ctx)
lic, _ := licChecker.(*fleet.LicenseInfo)
if lic == nil {
return nil, ctxerr.New(ctx, "license not found")
}
if err := fleet.ValidateUserRoles(true, p, *license); err != nil {
if err := fleet.ValidateUserRoles(true, p, *lic); err != nil {
return nil, ctxerr.Wrap(ctx, err, "validate role")
}
if !license.IsPremium() {
if !lic.IsPremium() {
p.MFAEnabled = ptr.Bool(false)
}

View file

@ -16,7 +16,7 @@ import (
)
func encodeResponse(ctx context.Context, w http.ResponseWriter, response interface{}) error {
return endpoint_utils.EncodeCommonResponse(ctx, w, response, jsonMarshal)
return endpoint_utils.EncodeCommonResponse(ctx, w, response, jsonMarshal, FleetErrorEncoder)
}
func jsonMarshal(w http.ResponseWriter, response interface{}) error {

View file

@ -0,0 +1,110 @@
package service
import (
"context"
"encoding/json"
"fmt"
"net/http"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
"github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
)
// FleetErrorEncoder handles fleet-specific error encoding for MailError
// and OsqueryError.
func FleetErrorEncoder(ctx context.Context, err error, w http.ResponseWriter, enc *json.Encoder, jsonErr *endpoint_utils.JsonError) bool {
switch e := err.(type) {
case MailError:
jsonErr.Message = "Mail Error"
jsonErr.Errors = []map[string]string{
{
"name": "base",
"reason": e.Message,
},
}
w.WriteHeader(http.StatusInternalServerError)
enc.Encode(jsonErr) //nolint:errcheck
return true
case *OsqueryError:
// osquery expects to receive the node_invalid key when a TLS
// request provides an invalid node_key for authentication. It
// doesn't use the error message provided, but we provide this
// for debugging purposes (and perhaps osquery will use this
// error message in the future).
errMap := map[string]any{
"error": e.Error(),
"uuid": jsonErr.UUID,
}
if e.NodeInvalid() { //nolint:gocritic // ignore ifElseChain
w.WriteHeader(http.StatusUnauthorized)
errMap["node_invalid"] = true
} else if e.Status() != 0 {
w.WriteHeader(e.Status())
} else {
// TODO: osqueryError is not always the result of an internal error on
// our side, it is also used to represent a client error (invalid data,
// e.g. malformed json, carve too large, etc., so 4xx), are we returning
// a 500 because of some osquery-specific requirement?
w.WriteHeader(http.StatusInternalServerError)
}
enc.Encode(errMap) //nolint:errcheck
return true
}
return false
}
// MailError is set when an error performing mail operations
type MailError struct {
Message string
}
func (e MailError) Error() string {
return fmt.Sprintf("a mail error occurred: %s", e.Message)
}
// OsqueryError is the error returned to osquery agents.
type OsqueryError struct {
message string
nodeInvalid bool
StatusCode int
platform_http.ErrorWithUUID
}
var _ platform_http.ErrorUUIDer = (*OsqueryError)(nil)
// Error implements the error interface.
func (e *OsqueryError) Error() string {
return e.message
}
// NodeInvalid returns whether the error returned to osquery
// should contain the node_invalid property.
func (e *OsqueryError) NodeInvalid() bool {
return e.nodeInvalid
}
func (e *OsqueryError) Status() int {
return e.StatusCode
}
func NewOsqueryError(message string, nodeInvalid bool) *OsqueryError {
return &OsqueryError{
message: message,
nodeInvalid: nodeInvalid,
}
}
// encodeError is a convenience function that calls endpoint_utils.EncodeError
// with the FleetErrorEncoder. Use this for direct error encoding in handlers.
func encodeError(ctx context.Context, err error, w http.ResponseWriter) {
endpoint_utils.EncodeError(ctx, err, w, FleetErrorEncoder)
}
// fleetErrorEncoder is an adapter that wraps endpoint_utils.EncodeError with
// FleetErrorEncoder for use as a kithttp.ErrorEncoder.
func fleetErrorEncoder(ctx context.Context, err error, w http.ResponseWriter) {
endpoint_utils.EncodeError(ctx, err, w, FleetErrorEncoder)
}

View file

@ -433,11 +433,12 @@ func (svc *Service) ModifyUser(ctx context.Context, userID uint, p fleet.UserPay
if err := svc.authz.Authorize(ctx, user, fleet.ActionWriteRole); err != nil {
return nil, err
}
license, _ := license.FromContext(ctx)
if license == nil {
licChecker, _ := license.FromContext(ctx)
lic, _ := licChecker.(*fleet.LicenseInfo)
if lic == nil {
return nil, ctxerr.New(ctx, "license not found")
}
if err := fleet.ValidateUserRoles(false, p, *license); err != nil {
if err := fleet.ValidateUserRoles(false, p, *lic); err != nil {
return nil, ctxerr.Wrap(ctx, err, "validate role")
}
}