mirror of
https://github.com/fleetdm/fleet
synced 2026-04-21 13:37:30 +00:00
Refactor endpoint_utils for modularization (#36484)
Resolves #37192 Separating generic endpoint_utils middleware logic from domain-specific business logic. New bounded contexts would share the generic logic and implement their own domain-specific logic. The two approaches used in this PR are: - Use common `platform` types - Use interfaces In the next PR we will move `endpointer_utils`, `authzcheck` and `ratelimit` into `platform` directory. # Checklist for submitter - [x] Added changes file ## Testing - [x] Added/updated tests - [x] QA'd all new/changed functionality manually <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Restructured internal error handling and context management to support bounded context architecture. * Improved error context collection and telemetry observability through a provider-based mechanism. * Decoupled licensing and authentication concerns into interfaces for better modularity. * **Chores** * Updated internal package dependencies to align with new architectural boundaries. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
parent
360a426224
commit
c88cc953fb
47 changed files with 1126 additions and 637 deletions
1
changes/37192-refactor-endpoint_utils
Normal file
1
changes/37192-refactor-endpoint_utils
Normal file
|
|
@ -0,0 +1 @@
|
|||
Refactored common endpoint_utils package to support bounded contexts inside Fleet codebase.
|
||||
|
|
@ -30,8 +30,11 @@ type PackageTest struct {
|
|||
forbiddenPkgs []string
|
||||
}
|
||||
|
||||
// ModuleName is the module name for the fleet project.
|
||||
const ModuleName = "github.com/fleetdm/fleet/v4"
|
||||
|
||||
// PackageTest will ignore dependency on this package.
|
||||
const thisPackage = "github.com/fleetdm/fleet/v4/server/archtest"
|
||||
const thisPackage = ModuleName + "/server/archtest"
|
||||
|
||||
type TestingT interface {
|
||||
Errorf(format string, args ...any)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -66,27 +67,15 @@ func (e *Forbidden) LogFields() []interface{} {
|
|||
|
||||
// CheckMissing is the error to return when no authorization check was performed
|
||||
// by the service.
|
||||
type CheckMissing struct {
|
||||
response interface{}
|
||||
|
||||
fleet.ErrorWithUUID
|
||||
}
|
||||
//
|
||||
// Deprecated: Use platform_http.CheckMissing instead. This alias is kept for
|
||||
// backward compatibility.
|
||||
type CheckMissing = platform_http.CheckMissing
|
||||
|
||||
// CheckMissingWithResponse creates a new error indicating the authorization
|
||||
// check was missed, and including the response for further analysis by the error
|
||||
// encoder.
|
||||
func CheckMissingWithResponse(response interface{}) *CheckMissing {
|
||||
return &CheckMissing{response: response}
|
||||
}
|
||||
|
||||
func (e *CheckMissing) Error() string {
|
||||
return ForbiddenErrorMessage
|
||||
}
|
||||
|
||||
func (e *CheckMissing) Internal() string {
|
||||
return "Missing authorization check"
|
||||
}
|
||||
|
||||
func (e *CheckMissing) Response() interface{} {
|
||||
return e.response
|
||||
}
|
||||
//
|
||||
// Deprecated: Use platform_http.CheckMissingWithResponse instead. This alias is
|
||||
// kept for backward compatibility.
|
||||
var CheckMissingWithResponse = platform_http.CheckMissingWithResponse
|
||||
|
|
|
|||
|
|
@ -16,13 +16,12 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/host"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
|
||||
"github.com/getsentry/sentry-go"
|
||||
"go.elastic.co/apm/v2"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
|
|
@ -121,21 +120,9 @@ func setMetadata(ctx context.Context, data map[string]interface{}) map[string]in
|
|||
|
||||
data["timestamp"] = nowFn().Format(time.RFC3339)
|
||||
|
||||
if h, ok := host.FromContext(ctx); ok {
|
||||
data["host"] = map[string]interface{}{
|
||||
"platform": h.Platform,
|
||||
"osquery_version": h.OsqueryVersion,
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := viewer.FromContext(ctx); ok {
|
||||
vdata := map[string]interface{}{}
|
||||
data["viewer"] = vdata
|
||||
vdata["is_logged_in"] = v.IsLoggedIn()
|
||||
|
||||
if v.User != nil {
|
||||
vdata["sso_enabled"] = v.User.SSOEnabled
|
||||
}
|
||||
// Get diagnostic context from all registered providers
|
||||
for _, provider := range getErrorContextProviders(ctx) {
|
||||
maps.Copy(data, provider.GetDiagnosticContext())
|
||||
}
|
||||
|
||||
return data
|
||||
|
|
@ -215,7 +202,7 @@ func Wrapf(ctx context.Context, cause error, format string, args ...interface{})
|
|||
|
||||
// Cause returns the root error in err's chain.
|
||||
func Cause(err error) error {
|
||||
return fleet.Cause(err)
|
||||
return platform_http.Cause(err)
|
||||
}
|
||||
|
||||
// FleetCause is similar to Cause, but returns the root-most
|
||||
|
|
@ -319,6 +306,9 @@ func Handle(ctx context.Context, err error) {
|
|||
cause = rootCause
|
||||
}
|
||||
|
||||
// Collect telemetry context from registered providers
|
||||
telemetryAttrs := collectTelemetryContext(ctx)
|
||||
|
||||
// send to OpenTelemetry if there's an active span
|
||||
if span := trace.SpanFromContext(ctx); span != nil && span.IsRecording() {
|
||||
// Mark the current span as failed by setting the error status.
|
||||
|
|
@ -333,20 +323,25 @@ func Handle(ctx context.Context, err error) {
|
|||
attribute.String("exception.stacktrace", strings.Join(cause.Stack(), "\n")),
|
||||
}
|
||||
|
||||
// Add contextual information if available (same as Sentry)
|
||||
v, _ := viewer.FromContext(ctx)
|
||||
h, _ := host.FromContext(ctx)
|
||||
|
||||
if v.User != nil {
|
||||
attrs = append(attrs,
|
||||
// Not sending the email here as it may contain sensitive information (PII).
|
||||
attribute.Int64("user.id", int64(v.User.ID)), //nolint:gosec
|
||||
)
|
||||
} else if h != nil {
|
||||
attrs = append(attrs,
|
||||
attribute.String("host.hostname", h.Hostname),
|
||||
attribute.Int64("host.id", int64(h.ID)), //nolint:gosec
|
||||
)
|
||||
// Add contextual information from telemetry providers.
|
||||
// OpenTelemetry requires typed attributes, so we convert the values to the appropriate type.
|
||||
for k, v := range telemetryAttrs {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
attrs = append(attrs, attribute.String(k, val))
|
||||
case int:
|
||||
attrs = append(attrs, attribute.Int64(k, int64(val)))
|
||||
case int64:
|
||||
attrs = append(attrs, attribute.Int64(k, val))
|
||||
case uint:
|
||||
attrs = append(attrs, attribute.Int64(k, int64(val))) //nolint:gosec
|
||||
case uint64:
|
||||
attrs = append(attrs, attribute.Int64(k, int64(val))) //nolint:gosec
|
||||
case bool:
|
||||
attrs = append(attrs, attribute.Bool(k, val))
|
||||
default:
|
||||
attrs = append(attrs, attribute.String(k, fmt.Sprint(val)))
|
||||
}
|
||||
}
|
||||
|
||||
span.AddEvent("exception", trace.WithAttributes(attrs...))
|
||||
|
|
@ -357,25 +352,14 @@ func Handle(ctx context.Context, err error) {
|
|||
|
||||
// if Sentry is configured, capture the error there
|
||||
if sentryClient := sentry.CurrentHub().Client(); sentryClient != nil {
|
||||
// sentry is configured, add contextual information if available
|
||||
v, _ := viewer.FromContext(ctx)
|
||||
h, _ := host.FromContext(ctx)
|
||||
|
||||
if v.User != nil || h != nil {
|
||||
// we have a viewer (user) or a host in the context, use this to
|
||||
// enrich the error with more context
|
||||
if len(telemetryAttrs) > 0 {
|
||||
// we have contextual information, use it to enrich the error
|
||||
ctxHub := sentry.CurrentHub().Clone()
|
||||
if v.User != nil {
|
||||
ctxHub.ConfigureScope(func(scope *sentry.Scope) {
|
||||
scope.SetTag("email", v.User.Email)
|
||||
scope.SetTag("user_id", fmt.Sprint(v.User.ID))
|
||||
})
|
||||
} else if h != nil {
|
||||
ctxHub.ConfigureScope(func(scope *sentry.Scope) {
|
||||
scope.SetTag("hostname", h.Hostname)
|
||||
scope.SetTag("host_id", fmt.Sprint(h.ID))
|
||||
})
|
||||
}
|
||||
ctxHub.ConfigureScope(func(scope *sentry.Scope) {
|
||||
for k, v := range telemetryAttrs {
|
||||
scope.SetTag(k, fmt.Sprint(v))
|
||||
}
|
||||
})
|
||||
ctxHub.CaptureException(cause)
|
||||
} else {
|
||||
sentry.CaptureException(cause)
|
||||
|
|
@ -387,6 +371,17 @@ func Handle(ctx context.Context, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
// collectTelemetryContext gathers telemetry context from all registered providers.
|
||||
func collectTelemetryContext(ctx context.Context) map[string]any {
|
||||
attrs := make(map[string]any)
|
||||
for _, provider := range getErrorContextProviders(ctx) {
|
||||
if telemetry := provider.GetTelemetryContext(); telemetry != nil {
|
||||
maps.Copy(attrs, telemetry)
|
||||
}
|
||||
}
|
||||
return attrs
|
||||
}
|
||||
|
||||
// Retrieve retrieves an error from the registered error handler
|
||||
func Retrieve(ctx context.Context) ([]*StoredError, error) {
|
||||
eh := FromContext(ctx)
|
||||
|
|
|
|||
|
|
@ -29,7 +29,11 @@ func TestHandleSendsContextToOTEL(t *testing.T) {
|
|||
ID: 123,
|
||||
Email: "test@example.com",
|
||||
}
|
||||
return viewer.NewContext(ctx, viewer.Viewer{User: testUser})
|
||||
v := viewer.Viewer{User: testUser}
|
||||
ctx = viewer.NewContext(ctx, v)
|
||||
// Register the viewer as an error context provider
|
||||
ctx = AddErrorContextProvider(ctx, &v)
|
||||
return ctx
|
||||
},
|
||||
errorMessage: "test error with user context",
|
||||
expectedAttrs: map[string]any{
|
||||
|
|
@ -43,7 +47,10 @@ func TestHandleSendsContextToOTEL(t *testing.T) {
|
|||
ID: 456,
|
||||
Hostname: "test-host.example.com",
|
||||
}
|
||||
return host.NewContext(ctx, testHost)
|
||||
ctx = host.NewContext(ctx, testHost)
|
||||
// Register the host as an error context provider
|
||||
ctx = AddErrorContextProvider(ctx, &host.HostAttributeProvider{Host: testHost})
|
||||
return ctx
|
||||
},
|
||||
errorMessage: "test error with host context",
|
||||
expectedAttrs: map[string]any{
|
||||
|
|
|
|||
|
|
@ -322,7 +322,10 @@ func TestAdditionalMetadata(t *testing.T) {
|
|||
t.Run("saves additional data about the host if present", func(t *testing.T) {
|
||||
ctx, cleanup := setup()
|
||||
defer cleanup()
|
||||
hctx := host.NewContext(ctx, &fleet.Host{Platform: "test_platform", OsqueryVersion: "5.0"})
|
||||
h := &fleet.Host{Platform: "test_platform", OsqueryVersion: "5.0"}
|
||||
hctx := host.NewContext(ctx, h)
|
||||
// Register the host as an error context provider
|
||||
hctx = AddErrorContextProvider(hctx, &host.HostAttributeProvider{Host: h})
|
||||
err := New(hctx, "with host context").(*FleetError)
|
||||
|
||||
require.JSONEq(t, string(err.data), `{"host":{"osquery_version":"5.0","platform":"test_platform"},"timestamp":"1969-06-19T21:44:05Z"}`)
|
||||
|
|
@ -331,7 +334,10 @@ func TestAdditionalMetadata(t *testing.T) {
|
|||
t.Run("saves additional data about the viewer if present", func(t *testing.T) {
|
||||
ctx, cleanup := setup()
|
||||
defer cleanup()
|
||||
vctx := viewer.NewContext(ctx, viewer.Viewer{Session: &fleet.Session{ID: 1}, User: &fleet.User{SSOEnabled: true}})
|
||||
v := viewer.Viewer{Session: &fleet.Session{ID: 1}, User: &fleet.User{SSOEnabled: true}}
|
||||
vctx := viewer.NewContext(ctx, v)
|
||||
// Register the viewer as an error context provider
|
||||
vctx = AddErrorContextProvider(vctx, &v)
|
||||
err := New(vctx, "with host context").(*FleetError)
|
||||
|
||||
require.JSONEq(t, string(err.data), `{"viewer":{"is_logged_in":true,"sso_enabled":true},"timestamp":"1969-06-19T21:44:05Z"}`)
|
||||
|
|
|
|||
35
server/contexts/ctxerr/metadata.go
Normal file
35
server/contexts/ctxerr/metadata.go
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
package ctxerr
|
||||
|
||||
import "context"
|
||||
|
||||
// ErrorContextProvider provides contextual information for error handling.
|
||||
// Implementations can provide data for both error storage and telemetry systems.
|
||||
type ErrorContextProvider interface {
|
||||
// GetDiagnosticContext returns attributes stored with errors for troubleshooting.
|
||||
// Data is persisted to Redis and included in logs. Should contain diagnostic
|
||||
// information like platform, versions, and status flags. Avoid including PII.
|
||||
GetDiagnosticContext() map[string]any
|
||||
|
||||
// GetTelemetryContext returns attributes sent to observability systems
|
||||
// (OpenTelemetry, Sentry). May include identifiers not stored with errors.
|
||||
// Return nil if no telemetry context is available.
|
||||
GetTelemetryContext() map[string]any
|
||||
}
|
||||
|
||||
type errorContextProvidersKey struct{}
|
||||
|
||||
// AddErrorContextProvider returns a new context with the given provider added to
|
||||
// the existing providers. This is useful when you want to add a provider
|
||||
// without replacing existing ones.
|
||||
func AddErrorContextProvider(ctx context.Context, provider ErrorContextProvider) context.Context {
|
||||
existing := getErrorContextProviders(ctx)
|
||||
providers := make([]ErrorContextProvider, len(existing)+1)
|
||||
copy(providers, existing)
|
||||
providers[len(existing)] = provider
|
||||
return context.WithValue(ctx, errorContextProvidersKey{}, providers)
|
||||
}
|
||||
|
||||
func getErrorContextProviders(ctx context.Context) []ErrorContextProvider {
|
||||
providers, _ := ctx.Value(errorContextProvidersKey{}).([]ErrorContextProvider)
|
||||
return providers
|
||||
}
|
||||
|
|
@ -3,8 +3,6 @@ package ctxerr
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
)
|
||||
|
||||
type ErrorAgg struct {
|
||||
|
|
@ -63,12 +61,21 @@ out:
|
|||
return stack[:stackIdx]
|
||||
}
|
||||
|
||||
// vitalErrorData represents the structure of vital fleetd error data.
|
||||
type vitalErrorData struct {
|
||||
ErrorSource string `json:"error_source"`
|
||||
ErrorSourceVersion string `json:"error_source_version"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
ErrorAdditionalInfo map[string]any `json:"error_additional_info"`
|
||||
Vital bool `json:"vital"`
|
||||
}
|
||||
|
||||
func getVitalMetadata(chain []fleetErrorJSON) json.RawMessage {
|
||||
for _, e := range chain {
|
||||
if len(e.Data) > 0 {
|
||||
// Currently, only vital fleetd errors contain metadata.
|
||||
// Note: vital errors should not contain any sensitive info
|
||||
var fleetdErr fleet.FleetdError
|
||||
var fleetdErr vitalErrorData
|
||||
var err error
|
||||
if err = json.Unmarshal(e.Data, &fleetdErr); err != nil || !fleetdErr.Vital {
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -22,3 +22,33 @@ func FromContext(ctx context.Context) (*fleet.Host, bool) {
|
|||
host, ok := ctx.Value(hostKey).(*fleet.Host)
|
||||
return host, ok
|
||||
}
|
||||
|
||||
// HostAttributeProvider wraps a fleet.Host to provide error context.
|
||||
// It implements ctxerr.ErrorContextProvider.
|
||||
type HostAttributeProvider struct {
|
||||
Host *fleet.Host
|
||||
}
|
||||
|
||||
// GetDiagnosticContext implements ctxerr.ErrorContextProvider
|
||||
func (p *HostAttributeProvider) GetDiagnosticContext() map[string]any {
|
||||
if p.Host == nil {
|
||||
return nil
|
||||
}
|
||||
return map[string]any{
|
||||
"host": map[string]any{
|
||||
"platform": p.Host.Platform,
|
||||
"osquery_version": p.Host.OsqueryVersion,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetTelemetryContext implements ctxerr.ErrorContextProvider
|
||||
func (p *HostAttributeProvider) GetTelemetryContext() map[string]any {
|
||||
if p.Host == nil {
|
||||
return nil
|
||||
}
|
||||
return map[string]any{
|
||||
"host.hostname": p.Host.Hostname,
|
||||
"host.id": p.Host.ID,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,23 +4,34 @@ package license
|
|||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
)
|
||||
|
||||
// LicenseChecker is the interface for checking license properties.
|
||||
// This interface is implemented by fleet.LicenseInfo
|
||||
type LicenseChecker interface {
|
||||
IsPremium() bool
|
||||
IsAllowDisableTelemetry() bool
|
||||
// GetTier returns the license tier (e.g., "free", "premium", "trial").
|
||||
GetTier() string
|
||||
// GetOrganization returns the name of the licensed organization.
|
||||
GetOrganization() string
|
||||
// GetDeviceCount returns the number of licensed devices.
|
||||
GetDeviceCount() int
|
||||
}
|
||||
|
||||
type key int
|
||||
|
||||
const licenseKey key = 0
|
||||
|
||||
// NewContext creates a new context.Context with the license.
|
||||
func NewContext(ctx context.Context, lic *fleet.LicenseInfo) context.Context {
|
||||
func NewContext(ctx context.Context, lic LicenseChecker) context.Context {
|
||||
return context.WithValue(ctx, licenseKey, lic)
|
||||
}
|
||||
|
||||
// FromContext returns the license from the context and true, or nil and false
|
||||
// if there is no license.
|
||||
func FromContext(ctx context.Context) (*fleet.LicenseInfo, bool) {
|
||||
v, ok := ctx.Value(licenseKey).(*fleet.LicenseInfo)
|
||||
// FromContext returns the license from the context as a LicenseChecker interface.
|
||||
// Use this when you only need to check license properties via the interface methods.
|
||||
func FromContext(ctx context.Context) (LicenseChecker, bool) {
|
||||
v, ok := ctx.Value(licenseKey).(LicenseChecker)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
|
|
@ -34,6 +45,8 @@ func IsPremium(ctx context.Context) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// IsAllowDisableTelemetry returns true if telemetry can be disabled based on
|
||||
// the license in the context.
|
||||
func IsAllowDisableTelemetry(ctx context.Context) bool {
|
||||
if lic, ok := FromContext(ctx); ok {
|
||||
return lic.IsAllowDisableTelemetry()
|
||||
|
|
|
|||
|
|
@ -7,13 +7,31 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
|
||||
kithttp "github.com/go-kit/kit/transport/http"
|
||||
kitlog "github.com/go-kit/log"
|
||||
"github.com/go-kit/log/level"
|
||||
)
|
||||
|
||||
// UserEmailer provides the user's email for logging purposes.
|
||||
type UserEmailer interface {
|
||||
Email() string
|
||||
}
|
||||
|
||||
type userEmailerKey struct{}
|
||||
|
||||
// WithUserEmailer returns a context with the UserEmailer stored for logging.
|
||||
// This should be called by authentication middleware after the user is identified.
|
||||
func WithUserEmailer(ctx context.Context, emailer UserEmailer) context.Context {
|
||||
return context.WithValue(ctx, userEmailerKey{}, emailer)
|
||||
}
|
||||
|
||||
// UserEmailerFromContext retrieves the UserEmailer from the context.
|
||||
func UserEmailerFromContext(ctx context.Context) (UserEmailer, bool) {
|
||||
v, ok := ctx.Value(userEmailerKey{}).(UserEmailer)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
type key int
|
||||
|
||||
const loggingKey key = 0
|
||||
|
|
@ -138,9 +156,8 @@ func (l *LoggingContext) Log(ctx context.Context, logger kitlog.Logger) {
|
|||
|
||||
if !l.SkipUser {
|
||||
loggedInUser := "unauthenticated"
|
||||
vc, ok := viewer.FromContext(ctx)
|
||||
if ok {
|
||||
loggedInUser = vc.Email()
|
||||
if emailer, ok := UserEmailerFromContext(ctx); ok {
|
||||
loggedInUser = emailer.Email()
|
||||
}
|
||||
keyvals = append(keyvals, "user", loggedInUser)
|
||||
}
|
||||
|
|
@ -168,7 +185,7 @@ func (l *LoggingContext) Log(ctx context.Context, logger kitlog.Logger) {
|
|||
)
|
||||
separator := " || "
|
||||
for _, err := range l.Errs {
|
||||
var ewi fleet.ErrWithInternal
|
||||
var ewi platform_http.ErrWithInternal
|
||||
if errors.As(err, &ewi) {
|
||||
if internalErrs == "" {
|
||||
internalErrs = ewi.Internal()
|
||||
|
|
@ -182,7 +199,7 @@ func (l *LoggingContext) Log(ctx context.Context, logger kitlog.Logger) {
|
|||
errs += separator + err.Error()
|
||||
}
|
||||
}
|
||||
var ewuuid fleet.ErrorUUIDer
|
||||
var ewuuid platform_http.ErrorUUIDer
|
||||
if errors.As(err, &ewuuid) {
|
||||
if uuid := ewuuid.UUID(); uuid != "" {
|
||||
uuids = append(uuids, uuid)
|
||||
|
|
@ -209,7 +226,7 @@ func (l *LoggingContext) setLevelError() bool {
|
|||
}
|
||||
|
||||
if len(l.Errs) == 1 {
|
||||
var ew fleet.ErrWithIsClientError
|
||||
var ew platform_http.ErrWithIsClientError
|
||||
if errors.As(l.Errs[0], &ew) && ew.IsClientError() {
|
||||
return false
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ package viewer
|
|||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
)
|
||||
|
|
@ -102,3 +103,38 @@ func (v Viewer) CanPerformPasswordReset() bool {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetDiagnosticContext implements ctxerr.ErrorContextProvider
|
||||
func (v *Viewer) GetDiagnosticContext() map[string]any {
|
||||
vdata := map[string]any{
|
||||
"is_logged_in": v.IsLoggedIn(),
|
||||
}
|
||||
if v.User != nil {
|
||||
vdata["sso_enabled"] = v.User.SSOEnabled
|
||||
}
|
||||
return map[string]any{
|
||||
"viewer": vdata,
|
||||
}
|
||||
}
|
||||
|
||||
// GetTelemetryContext implements ctxerr.ErrorContextProvider
|
||||
func (v *Viewer) GetTelemetryContext() map[string]any {
|
||||
if v.User == nil {
|
||||
return nil
|
||||
}
|
||||
return map[string]any{
|
||||
"user.id": v.User.ID,
|
||||
"user.email": maskEmail(v.User.Email),
|
||||
}
|
||||
}
|
||||
|
||||
// maskEmail anonymizes an email address for telemetry by showing only
|
||||
// the first character of the local part and the full domain.
|
||||
// Example: "john.doe@example.com" -> "j***@example.com"
|
||||
func maskEmail(email string) string {
|
||||
parts := strings.SplitN(email, "@", 2)
|
||||
if len(parts) != 2 || len(parts[0]) == 0 {
|
||||
return "***"
|
||||
}
|
||||
return string(parts[0][0]) + "***@" + parts[1]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -153,3 +153,26 @@ func TestCanPerformActions(t *testing.T) {
|
|||
// assert.Equal(t, false, needsPasswordResetAdminViewer.CanPerformWriteActionOnUser(1))
|
||||
// assert.Equal(t, true, needsPasswordResetAdminViewer.CanPerformWriteActionOnUser(needsPasswordResetAdminViewer.User.ID))
|
||||
// }
|
||||
|
||||
func TestMaskEmail(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
email string
|
||||
expected string
|
||||
}{
|
||||
{"standard email", "john.doe@example.com", "j***@example.com"},
|
||||
{"single char local", "j@example.com", "j***@example.com"},
|
||||
{"subdomain", "user@mail.example.com", "u***@mail.example.com"},
|
||||
{"empty string", "", "***"},
|
||||
{"no at sign", "invalid", "***"},
|
||||
{"empty local part", "@example.com", "***"},
|
||||
{"only at sign", "@", "***"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := maskEmail(tc.email)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -144,7 +144,7 @@ func (ds *Datastore) ShouldSendStatistics(ctx context.Context, frequency time.Du
|
|||
stats.NumHostsNotResponding = amountHostsNotResponding
|
||||
stats.Organization = "unknown"
|
||||
if lic != nil && lic.IsPremium() {
|
||||
stats.Organization = lic.Organization
|
||||
stats.Organization = lic.GetOrganization()
|
||||
}
|
||||
stats.AIFeaturesDisabled = appConfig.ServerSettings.AIFeaturesDisabled
|
||||
stats.MaintenanceWindowsConfigured = len(appConfig.Integrations.GoogleCalendar) > 0 && appConfig.Integrations.GoogleCalendar[0].Domain != "" && len(appConfig.Integrations.GoogleCalendar[0].ApiKey) > 0
|
||||
|
|
@ -187,7 +187,7 @@ func (ds *Datastore) ShouldSendStatistics(ctx context.Context, frequency time.Du
|
|||
LicenseTier: fleet.TierFree,
|
||||
}
|
||||
if lic != nil {
|
||||
stats.LicenseTier = lic.Tier
|
||||
stats.LicenseTier = lic.GetTier()
|
||||
}
|
||||
if err := computeStats(&stats, time.Now().Add(-frequency)); err != nil {
|
||||
return fleet.StatisticsPayload{}, false, ctxerr.Wrap(ctx, err, "compute statistics")
|
||||
|
|
@ -212,7 +212,7 @@ func (ds *Datastore) ShouldSendStatistics(ctx context.Context, frequency time.Du
|
|||
LicenseTier: fleet.TierFree,
|
||||
}
|
||||
if lic != nil {
|
||||
stats.LicenseTier = lic.Tier
|
||||
stats.LicenseTier = lic.GetTier()
|
||||
}
|
||||
if err := computeStats(&stats, lastUpdated); err != nil {
|
||||
return fleet.StatisticsPayload{}, false, ctxerr.Wrap(ctx, err, "compute statistics")
|
||||
|
|
|
|||
|
|
@ -1476,6 +1476,24 @@ func (l *LicenseInfo) IsAllowDisableTelemetry() bool {
|
|||
return !l.IsPremium() || l.AllowDisableTelemetry
|
||||
}
|
||||
|
||||
// Tier returns the license tier.
|
||||
// This method implements license.LicenseChecker.
|
||||
func (l *LicenseInfo) GetTier() string {
|
||||
return l.Tier
|
||||
}
|
||||
|
||||
// Organization returns the name of the licensed organization.
|
||||
// This method implements license.LicenseChecker.
|
||||
func (l *LicenseInfo) GetOrganization() string {
|
||||
return l.Organization
|
||||
}
|
||||
|
||||
// DeviceCount returns the number of licensed devices.
|
||||
// This method implements license.LicenseChecker.
|
||||
func (l *LicenseInfo) GetDeviceCount() int {
|
||||
return l.DeviceCount
|
||||
}
|
||||
|
||||
const (
|
||||
HeaderLicenseKey = "X-Fleet-License"
|
||||
HeaderLicenseValueExpired = "Expired"
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/mdm/nanodep/godep"
|
||||
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm"
|
||||
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/storage"
|
||||
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
|
|
@ -2836,19 +2837,11 @@ type AlreadyExistsError interface {
|
|||
IsExists() bool
|
||||
}
|
||||
|
||||
// ForeignKeyError is returned when the operation fails due to foreign key constraints.
|
||||
type ForeignKeyError interface {
|
||||
error
|
||||
IsForeignKey() bool
|
||||
}
|
||||
// ForeignKeyError is an alias for platform_http.ForeignKeyError.
|
||||
type ForeignKeyError = platform_http.ForeignKeyError
|
||||
|
||||
func IsForeignKey(err error) bool {
|
||||
var fke ForeignKeyError
|
||||
if errors.As(err, &fke) {
|
||||
return fke.IsForeignKey()
|
||||
}
|
||||
return false
|
||||
}
|
||||
// IsForeignKey is an alias for platform_http.IsForeignKey.
|
||||
var IsForeignKey = platform_http.IsForeignKey
|
||||
|
||||
type OptionalArg func() interface{}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,22 +1,18 @@
|
|||
package fleet
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoContext = errors.New("context key not set")
|
||||
ErrPasswordResetRequired = &passwordResetRequiredError{}
|
||||
ErrPasswordResetRequired = platform_http.ErrPasswordResetRequired
|
||||
ErrMissingLicense = &licenseError{}
|
||||
ErrMDMNotConfigured = &MDMNotConfiguredError{}
|
||||
ErrWindowsMDMNotConfigured = &WindowsMDMNotConfiguredError{}
|
||||
|
|
@ -46,40 +42,17 @@ type ErrWithStatusCode interface {
|
|||
StatusCode() int
|
||||
}
|
||||
|
||||
// ErrWithInternal is an interface for errors that include extra "internal"
|
||||
// information that should be logged in server logs but not sent to clients.
|
||||
type ErrWithInternal interface {
|
||||
error
|
||||
// Internal returns the error string that must only be logged internally,
|
||||
// not returned to the client.
|
||||
Internal() string
|
||||
}
|
||||
// ErrWithInternal is an alias for platform_http.ErrWithInternal.
|
||||
type ErrWithInternal = platform_http.ErrWithInternal
|
||||
|
||||
// ErrWithLogFields is an interface for errors that include additional logging
|
||||
// fields that should be logged in server logs but not sent to clients.
|
||||
type ErrWithLogFields interface {
|
||||
error
|
||||
// LogFields returns the additional log fields to add, which should come in
|
||||
// key, value pairs (as used in go-kit log).
|
||||
LogFields() []interface{}
|
||||
}
|
||||
// ErrWithLogFields is an alias for platform_http.ErrWithLogFields.
|
||||
type ErrWithLogFields = platform_http.ErrWithLogFields
|
||||
|
||||
// ErrWithRetryAfter is an interface for errors that should set a specific HTTP
|
||||
// Header Retry-After value (see
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After)
|
||||
type ErrWithRetryAfter interface {
|
||||
error
|
||||
// RetryAfter returns the number of seconds to wait before retry.
|
||||
RetryAfter() int
|
||||
}
|
||||
// ErrWithRetryAfter is an alias for platform_http.ErrWithRetryAfter.
|
||||
type ErrWithRetryAfter = platform_http.ErrWithRetryAfter
|
||||
|
||||
// ErrWithIsClientError is an interface for errors that explicitly specify
|
||||
// whether they are client errors or not. By default, errors are treated as
|
||||
// server errors.
|
||||
type ErrWithIsClientError interface {
|
||||
error
|
||||
IsClientError() bool
|
||||
}
|
||||
// ErrWithIsClientError is an alias for platform_http.ErrWithIsClientError.
|
||||
type ErrWithIsClientError = platform_http.ErrWithIsClientError
|
||||
|
||||
type invalidArgWithStatusError struct {
|
||||
InvalidArgumentError
|
||||
|
|
@ -94,30 +67,11 @@ func (e invalidArgWithStatusError) Status() int {
|
|||
return e.code
|
||||
}
|
||||
|
||||
// ErrorUUIDer is the interface for errors that contain a UUID.
|
||||
type ErrorUUIDer interface {
|
||||
// UUID returns the error's UUID.
|
||||
UUID() string
|
||||
}
|
||||
// ErrorUUIDer is an alias for platform_http.ErrorUUIDer.
|
||||
type ErrorUUIDer = platform_http.ErrorUUIDer
|
||||
|
||||
// ErrorWithUUID can be embedded to error types to implement ErrorUUIDer.
|
||||
type ErrorWithUUID struct {
|
||||
uuid string
|
||||
}
|
||||
|
||||
var _ ErrorUUIDer = (*ErrorWithUUID)(nil)
|
||||
|
||||
// UUID implements the ErrorUUIDer interface.
|
||||
func (e *ErrorWithUUID) UUID() string {
|
||||
if e.uuid == "" {
|
||||
uuid, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
e.uuid = uuid.String()
|
||||
}
|
||||
return e.uuid
|
||||
}
|
||||
// ErrorWithUUID is an alias for platform_http.ErrorWithUUID.
|
||||
type ErrorWithUUID = platform_http.ErrorWithUUID
|
||||
|
||||
// InvalidArgumentError is the error returned when invalid data is presented to
|
||||
// a service method. It is a client error.
|
||||
|
|
@ -187,102 +141,26 @@ func (e InvalidArgumentError) Invalid() []map[string]string {
|
|||
return invalid
|
||||
}
|
||||
|
||||
// BadRequestError is an error type that generates a 400 status code.
|
||||
type BadRequestError struct {
|
||||
Message string
|
||||
InternalErr error
|
||||
// BadRequestError is an alias for platform_http.BadRequestError.
|
||||
type BadRequestError = platform_http.BadRequestError
|
||||
|
||||
ErrorWithUUID
|
||||
}
|
||||
// AuthFailedError is an alias for platform_http.AuthFailedError.
|
||||
type AuthFailedError = platform_http.AuthFailedError
|
||||
|
||||
// Error returns the error message.
|
||||
func (e *BadRequestError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
// NewAuthFailedError is an alias for platform_http.NewAuthFailedError.
|
||||
var NewAuthFailedError = platform_http.NewAuthFailedError
|
||||
|
||||
// This implements the interface required by the server/service package logic
|
||||
// to determine the status code to return to the client.
|
||||
func (e *BadRequestError) BadRequestError() []map[string]string {
|
||||
return nil
|
||||
}
|
||||
// AuthRequiredError is an alias for platform_http.AuthRequiredError.
|
||||
type AuthRequiredError = platform_http.AuthRequiredError
|
||||
|
||||
func (e BadRequestError) Internal() string {
|
||||
if e.InternalErr == nil {
|
||||
return ""
|
||||
}
|
||||
return e.InternalErr.Error()
|
||||
}
|
||||
// NewAuthRequiredError is an alias for platform_http.NewAuthRequiredError.
|
||||
var NewAuthRequiredError = platform_http.NewAuthRequiredError
|
||||
|
||||
type AuthFailedError struct {
|
||||
// internal is the reason that should only be logged internally
|
||||
internal string
|
||||
// AuthHeaderRequiredError is an alias for platform_http.AuthHeaderRequiredError.
|
||||
type AuthHeaderRequiredError = platform_http.AuthHeaderRequiredError
|
||||
|
||||
ErrorWithUUID
|
||||
}
|
||||
|
||||
func NewAuthFailedError(internal string) *AuthFailedError {
|
||||
return &AuthFailedError{internal: internal}
|
||||
}
|
||||
|
||||
func (e AuthFailedError) Error() string {
|
||||
return "Authentication failed"
|
||||
}
|
||||
|
||||
func (e AuthFailedError) Internal() string {
|
||||
return e.internal
|
||||
}
|
||||
|
||||
func (e AuthFailedError) StatusCode() int {
|
||||
return http.StatusUnauthorized
|
||||
}
|
||||
|
||||
type AuthRequiredError struct {
|
||||
// internal is the reason that should only be logged internally
|
||||
internal string
|
||||
|
||||
ErrorWithUUID
|
||||
}
|
||||
|
||||
func NewAuthRequiredError(internal string) *AuthRequiredError {
|
||||
return &AuthRequiredError{internal: internal}
|
||||
}
|
||||
|
||||
func (e AuthRequiredError) Error() string {
|
||||
return "Authentication required"
|
||||
}
|
||||
|
||||
func (e AuthRequiredError) Internal() string {
|
||||
return e.internal
|
||||
}
|
||||
|
||||
func (e AuthRequiredError) StatusCode() int {
|
||||
return http.StatusUnauthorized
|
||||
}
|
||||
|
||||
type AuthHeaderRequiredError struct {
|
||||
// internal is the reason that should only be logged internally
|
||||
internal string
|
||||
|
||||
ErrorWithUUID
|
||||
}
|
||||
|
||||
func NewAuthHeaderRequiredError(internal string) *AuthHeaderRequiredError {
|
||||
return &AuthHeaderRequiredError{
|
||||
internal: internal,
|
||||
}
|
||||
}
|
||||
|
||||
func (e AuthHeaderRequiredError) Error() string {
|
||||
return "Authorization header required"
|
||||
}
|
||||
|
||||
func (e AuthHeaderRequiredError) Internal() string {
|
||||
return e.internal
|
||||
}
|
||||
|
||||
func (e AuthHeaderRequiredError) StatusCode() int {
|
||||
return http.StatusUnauthorized
|
||||
}
|
||||
// NewAuthHeaderRequiredError is an alias for platform_http.NewAuthHeaderRequiredError.
|
||||
var NewAuthHeaderRequiredError = platform_http.NewAuthHeaderRequiredError
|
||||
|
||||
// PermissionError, set when user is authenticated, but not allowed to perform action
|
||||
type PermissionError struct {
|
||||
|
|
@ -347,18 +225,6 @@ func (e licenseError) StatusCode() int {
|
|||
return http.StatusPaymentRequired
|
||||
}
|
||||
|
||||
type passwordResetRequiredError struct {
|
||||
ErrorWithUUID
|
||||
}
|
||||
|
||||
func (e passwordResetRequiredError) Error() string {
|
||||
return "password reset required"
|
||||
}
|
||||
|
||||
func (e passwordResetRequiredError) StatusCode() int {
|
||||
return http.StatusUnauthorized
|
||||
}
|
||||
|
||||
// MDMNotConfiguredError is used when an MDM endpoint or resource is accessed
|
||||
// without having MDM correctly configured.
|
||||
type MDMNotConfiguredError struct{}
|
||||
|
|
@ -453,15 +319,9 @@ func (e *GatewayError) Error() string {
|
|||
return msg
|
||||
}
|
||||
|
||||
// Error is a user facing error (API user). It's meant to be used for errors that are
|
||||
// related to fleet logic specifically. Other errors, such as mysql errors, shouldn't
|
||||
// be translated to this.
|
||||
type Error struct {
|
||||
Code int `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
|
||||
ErrorWithUUID
|
||||
}
|
||||
// Error is an alias for platform_http.Error.
|
||||
// It's meant to be used for errors that are related to fleet logic specifically.
|
||||
type Error = platform_http.Error
|
||||
|
||||
const (
|
||||
// ErrNoRoleNeeded is the error number for valid role needed
|
||||
|
|
@ -491,101 +351,26 @@ func NewErrorf(code int, format string, args ...interface{}) error {
|
|||
}
|
||||
}
|
||||
|
||||
func (ge *Error) Error() string {
|
||||
return ge.Message
|
||||
}
|
||||
// UserMessageError is an alias for platform_http.UserMessageError.
|
||||
type UserMessageError = platform_http.UserMessageError
|
||||
|
||||
// UserMessageError is an error that adds the UserMessage interface
|
||||
// implementation.
|
||||
type UserMessageError struct {
|
||||
error
|
||||
statusCode int
|
||||
|
||||
ErrorWithUUID
|
||||
}
|
||||
|
||||
// NewUserMessageError creates a UserMessageError that will translate the
|
||||
// error message of err to a user-friendly form. If statusCode is > 0, it
|
||||
// will be used as the HTTP status code for the error, otherwise it defaults
|
||||
// to http.StatusUnprocessableEntity (422).
|
||||
func NewUserMessageError(err error, statusCode int) *UserMessageError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserMessageError{
|
||||
error: err,
|
||||
statusCode: statusCode,
|
||||
}
|
||||
}
|
||||
|
||||
var rxJSONUnknownField = regexp.MustCompile(`^json: unknown field "(.+)"$`)
|
||||
// NewUserMessageError is an alias for platform_http.NewUserMessageError.
|
||||
var NewUserMessageError = platform_http.NewUserMessageError
|
||||
|
||||
// IsJSONUnknownFieldError returns true if err is a JSON unknown field error.
|
||||
// There is no exported type or value for this error, so we have to match the
|
||||
// error message.
|
||||
func IsJSONUnknownFieldError(err error) bool {
|
||||
return rxJSONUnknownField.MatchString(err.Error())
|
||||
return platform_http.IsJSONUnknownFieldError(err)
|
||||
}
|
||||
|
||||
// GetJSONUnknownField returns the unknown field name from a JSON unknown field error.
|
||||
func GetJSONUnknownField(err error) *string {
|
||||
errCause := Cause(err)
|
||||
if IsJSONUnknownFieldError(errCause) {
|
||||
substr := rxJSONUnknownField.FindStringSubmatch(errCause.Error())
|
||||
return &substr[1]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UserMessage implements the user-friendly translation of the error if its
|
||||
// root cause is one of the supported types, otherwise it returns the error
|
||||
// message.
|
||||
func (e UserMessageError) UserMessage() string {
|
||||
cause := Cause(e.error)
|
||||
switch cause := cause.(type) {
|
||||
case *json.UnmarshalTypeError:
|
||||
var sb strings.Builder
|
||||
curType := cause.Type
|
||||
for curType.Kind() == reflect.Slice || curType.Kind() == reflect.Array {
|
||||
sb.WriteString("array of ")
|
||||
curType = curType.Elem()
|
||||
}
|
||||
sb.WriteString(curType.Name())
|
||||
if curType != cause.Type {
|
||||
// it was an array
|
||||
sb.WriteString("s")
|
||||
}
|
||||
|
||||
return fmt.Sprintf("invalid value type at '%s': expected %s but got %s", cause.Field, sb.String(), cause.Value)
|
||||
|
||||
default:
|
||||
// there's no specific error type for the strict json mode
|
||||
// (DisallowUnknownFields), so resort to message-matching.
|
||||
if matches := rxJSONUnknownField.FindStringSubmatch(cause.Error()); matches != nil {
|
||||
return fmt.Sprintf("unsupported key provided: %q", matches[1])
|
||||
}
|
||||
return e.Error()
|
||||
}
|
||||
}
|
||||
|
||||
// StatusCode implements the kithttp.StatusCoder interface to return the status
|
||||
// code to use in HTTP API responses.
|
||||
func (e UserMessageError) StatusCode() int {
|
||||
if e.statusCode > 0 {
|
||||
return e.statusCode
|
||||
}
|
||||
return http.StatusUnprocessableEntity
|
||||
return platform_http.GetJSONUnknownField(err)
|
||||
}
|
||||
|
||||
// Cause returns the root error in err's chain.
|
||||
func Cause(err error) error {
|
||||
for {
|
||||
uerr := errors.Unwrap(err)
|
||||
if uerr == nil {
|
||||
return err
|
||||
}
|
||||
err = uerr
|
||||
}
|
||||
}
|
||||
var Cause = platform_http.Cause
|
||||
|
||||
// FleetdError is an error that can be reported by any of the fleetd
|
||||
// components.
|
||||
|
|
|
|||
|
|
@ -7,32 +7,34 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/archtest"
|
||||
)
|
||||
|
||||
const m = archtest.ModuleName
|
||||
|
||||
// TestAllAndroidPackageDependencies checks that android packages are not dependent on other Fleet packages
|
||||
// to maintain decoupling and modularity.
|
||||
// If coupling is necessary, it should be done in the main server/fleet, server/service, or other package.
|
||||
func TestAllAndroidPackageDependencies(t *testing.T) {
|
||||
t.Parallel()
|
||||
archtest.NewPackageTest(t, "github.com/fleetdm/fleet/v4/server/mdm/android...").
|
||||
archtest.NewPackageTest(t, m+"/server/mdm/android...").
|
||||
OnlyInclude(regexp.MustCompile(`^github\.com/fleetdm/`)).
|
||||
ShouldNotDependOn(
|
||||
"github.com/fleetdm/fleet/v4/server/service...",
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql...",
|
||||
m+"/server/service...",
|
||||
m+"/server/datastore/mysql...",
|
||||
).
|
||||
IgnoreRecursively(
|
||||
"github.com/fleetdm/fleet/v4/server/mdm/android/tests",
|
||||
m+"/server/mdm/android/tests",
|
||||
).
|
||||
IgnoreDeps(
|
||||
// Android packages
|
||||
"github.com/fleetdm/fleet/v4/server/mdm/android...",
|
||||
m+"/server/mdm/android...",
|
||||
// Other/infra packages
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql/common_mysql",
|
||||
"github.com/fleetdm/fleet/v4/server/service/externalsvc", // dependency on Jira and Zendesk
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/auth",
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/authzcheck",
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils",
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/log",
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/ratelimit",
|
||||
"github.com/fleetdm/fleet/v4/server/service/modules/activities",
|
||||
m+"/server/datastore/mysql/common_mysql",
|
||||
m+"/server/service/externalsvc", // dependency on Jira and Zendesk
|
||||
m+"/server/service/middleware/auth",
|
||||
m+"/server/service/middleware/authzcheck",
|
||||
m+"/server/service/middleware/endpoint_utils",
|
||||
m+"/server/service/middleware/log",
|
||||
m+"/server/service/middleware/ratelimit",
|
||||
m+"/server/service/modules/activities",
|
||||
).
|
||||
Check()
|
||||
}
|
||||
|
|
@ -42,7 +44,8 @@ func TestAllAndroidPackageDependencies(t *testing.T) {
|
|||
// If coupling is necessary, it should be done in the main server/fleet or another package.
|
||||
func TestAndroidPackageDependencies(t *testing.T) {
|
||||
t.Parallel()
|
||||
archtest.NewPackageTest(t, "github.com/fleetdm/fleet/v4/server/mdm/android").
|
||||
archtest.NewPackageTest(t, m+"/server/mdm/android").
|
||||
OnlyInclude(regexp.MustCompile(`^github\.com/fleetdm/`)).
|
||||
ShouldNotDependOn("github.com/fleetdm/fleet/v4/...")
|
||||
ShouldNotDependOn(m + "/...").
|
||||
Check()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/mdm/android"
|
||||
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/auth"
|
||||
eu "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
|
||||
"github.com/go-json-experiment/json"
|
||||
|
|
@ -21,13 +22,14 @@ func encodeResponse(ctx context.Context, w http.ResponseWriter, response interfa
|
|||
func(w http.ResponseWriter, response interface{}) error {
|
||||
return json.MarshalWrite(w, response, jsontext.WithIndent(" "))
|
||||
},
|
||||
nil, // no domain-specific error encoder
|
||||
)
|
||||
}
|
||||
|
||||
func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
|
||||
return eu.MakeDecoder(iface, func(body io.Reader, req any) error {
|
||||
return json.UnmarshalRead(body, req)
|
||||
}, nil, nil, nil)
|
||||
}, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// handlerFunc is the handler function type for Android service endpoints.
|
||||
|
|
@ -40,12 +42,12 @@ type endpointer struct {
|
|||
svc android.Service
|
||||
}
|
||||
|
||||
func (e *endpointer) CallHandlerFunc(f handlerFunc, ctx context.Context, request interface{},
|
||||
svc interface{}) (fleet.Errorer, error) {
|
||||
func (e *endpointer) CallHandlerFunc(f handlerFunc, ctx context.Context, request any,
|
||||
svc any) (platform_http.Errorer, error) {
|
||||
return f(ctx, request, svc.(android.Service)), nil
|
||||
}
|
||||
|
||||
func (e *endpointer) Service() interface{} {
|
||||
func (e *endpointer) Service() any {
|
||||
return e.svc
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -207,13 +207,18 @@ func (m *mockService) NewActivity(ctx context.Context, user *fleet.User, details
|
|||
}
|
||||
|
||||
func runServerForTests(t *testing.T, logger kitlog.Logger, fleetSvc fleet.Service, androidSvc android.Service) *httptest.Server {
|
||||
// androidErrorEncoder wraps EncodeError with nil domain encoder for android tests
|
||||
androidErrorEncoder := func(ctx context.Context, err error, w http.ResponseWriter) {
|
||||
endpoint_utils.EncodeError(ctx, err, w, nil)
|
||||
}
|
||||
|
||||
fleetAPIOptions := []kithttp.ServerOption{
|
||||
kithttp.ServerBefore(
|
||||
kithttp.PopulateRequestContext,
|
||||
auth.SetRequestsContexts(fleetSvc),
|
||||
),
|
||||
kithttp.ServerErrorHandler(&endpoint_utils.ErrorHandler{Logger: logger}),
|
||||
kithttp.ServerErrorEncoder(endpoint_utils.EncodeError),
|
||||
kithttp.ServerErrorEncoder(androidErrorEncoder),
|
||||
kithttp.ServerAfter(
|
||||
kithttp.SetContentType("application/json; charset=utf-8"),
|
||||
log.LogRequestEnd(logger),
|
||||
|
|
|
|||
44
server/platform/arch_test.go
Normal file
44
server/platform/arch_test.go
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
package platform_test
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/archtest"
|
||||
)
|
||||
|
||||
const m = archtest.ModuleName
|
||||
|
||||
// TestEndpointerPackageDependencies checks that endpointer package is not dependent on other Fleet domain packages
|
||||
// to maintain decoupling and modularity.
|
||||
func TestEndpointerPackageDependencies(t *testing.T) {
|
||||
t.Parallel()
|
||||
archtest.NewPackageTest(t, m+"/server/service/middleware/endpoint_utils").
|
||||
OnlyInclude(regexp.MustCompile(`^github\.com/fleetdm/`)).
|
||||
WithTests().
|
||||
ShouldNotDependOn(m+"/...").
|
||||
IgnoreDeps(
|
||||
// Platform packages
|
||||
m+"/server/platform...",
|
||||
// Other infra packages
|
||||
m+"/server/contexts/authz",
|
||||
m+"/server/contexts/ctxerr",
|
||||
m+"/server/contexts/license",
|
||||
m+"/server/contexts/logging",
|
||||
m+"/server/contexts/publicip",
|
||||
m+"/server/service/middleware/authzcheck",
|
||||
m+"/server/service/middleware/ratelimit",
|
||||
).
|
||||
Check()
|
||||
}
|
||||
|
||||
// TestPlatformPackageDependencies checks that platform packages are NOT dependent on ANY other Fleet packages
|
||||
// to maintain decoupling and modularity.
|
||||
func TestPlatformPackageDependencies(t *testing.T) {
|
||||
t.Parallel()
|
||||
archtest.NewPackageTest(t, m+"/server/platform...").
|
||||
OnlyInclude(regexp.MustCompile(`^github\.com/fleetdm/`)).
|
||||
WithTests().
|
||||
ShouldNotDependOn(m + "/...").
|
||||
Check()
|
||||
}
|
||||
357
server/platform/http/errors.go
Normal file
357
server/platform/http/errors.go
Normal file
|
|
@ -0,0 +1,357 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ErrWithInternal defines an interface for errors that have an internal message
|
||||
// that should only be logged, not returned to the client.
|
||||
type ErrWithInternal interface {
|
||||
error
|
||||
// Internal returns the error string that must only be logged internally,
|
||||
// not returned to the client.
|
||||
Internal() string
|
||||
}
|
||||
|
||||
// ErrWithLogFields defines an interface for errors that have additional log fields.
|
||||
type ErrWithLogFields interface {
|
||||
error
|
||||
// LogFields returns the additional log fields to add, which should come in
|
||||
// key, value pairs (as used in go-kit log).
|
||||
LogFields() []any
|
||||
}
|
||||
|
||||
// ErrorUUIDer defines an interface for errors that have a UUID for tracking.
|
||||
type ErrorUUIDer interface {
|
||||
// UUID returns the error's UUID.
|
||||
UUID() string
|
||||
}
|
||||
|
||||
// ErrorWithUUID can be embedded in error types to implement ErrorUUIDer.
|
||||
type ErrorWithUUID struct {
|
||||
uuid string
|
||||
}
|
||||
|
||||
var _ ErrorUUIDer = (*ErrorWithUUID)(nil)
|
||||
|
||||
// UUID implements the ErrorUUIDer interface.
|
||||
func (e *ErrorWithUUID) UUID() string {
|
||||
if e.uuid == "" {
|
||||
u, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
e.uuid = u.String()
|
||||
}
|
||||
return e.uuid
|
||||
}
|
||||
|
||||
// BadRequestError is the error returned when the request is invalid.
|
||||
type BadRequestError struct {
|
||||
Message string
|
||||
InternalErr error
|
||||
|
||||
ErrorWithUUID
|
||||
}
|
||||
|
||||
// Error returns the error message.
|
||||
func (e *BadRequestError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// BadRequestError implements the interface required by the server/service package logic
|
||||
// to determine the status code to return to the client.
|
||||
func (e *BadRequestError) BadRequestError() []map[string]string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Internal implements the ErrWithInternal interface.
|
||||
func (e BadRequestError) Internal() string {
|
||||
if e.InternalErr != nil {
|
||||
return e.InternalErr.Error()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// UserMessageError is an error that wraps another error with a user-friendly message.
|
||||
type UserMessageError struct {
|
||||
error
|
||||
statusCode int
|
||||
|
||||
ErrorWithUUID
|
||||
}
|
||||
|
||||
// NewUserMessageError creates a UserMessageError that will translate the
|
||||
// error message of err to a user-friendly form. If statusCode is > 0, it
|
||||
// will be used as the HTTP status code for the error, otherwise it defaults
|
||||
// to http.StatusUnprocessableEntity (422).
|
||||
func NewUserMessageError(err error, statusCode int) *UserMessageError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserMessageError{
|
||||
error: err,
|
||||
statusCode: statusCode,
|
||||
}
|
||||
}
|
||||
|
||||
// StatusCode returns the HTTP status code for this error.
|
||||
func (e UserMessageError) StatusCode() int {
|
||||
if e.statusCode > 0 {
|
||||
return e.statusCode
|
||||
}
|
||||
return http.StatusUnprocessableEntity
|
||||
}
|
||||
|
||||
var rxJSONUnknownField = regexp.MustCompile(`^json: unknown field "(.+)"$`)
|
||||
|
||||
// IsJSONUnknownFieldError returns true if err is a JSON unknown field error.
|
||||
// There is no exported type or value for this error, so we have to match the
|
||||
// error message.
|
||||
func IsJSONUnknownFieldError(err error) bool {
|
||||
return rxJSONUnknownField.MatchString(err.Error())
|
||||
}
|
||||
|
||||
// GetJSONUnknownField returns the unknown field name from a JSON unknown field error.
|
||||
func GetJSONUnknownField(err error) *string {
|
||||
errCause := Cause(err)
|
||||
if IsJSONUnknownFieldError(errCause) {
|
||||
substr := rxJSONUnknownField.FindStringSubmatch(errCause.Error())
|
||||
return &substr[1]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UserMessage implements the user-friendly translation of the error if its
|
||||
// root cause is one of the supported types, otherwise it returns the error
|
||||
// message.
|
||||
func (e UserMessageError) UserMessage() string {
|
||||
cause := Cause(e.error)
|
||||
switch cause := cause.(type) {
|
||||
case *json.UnmarshalTypeError:
|
||||
var sb strings.Builder
|
||||
curType := cause.Type
|
||||
for curType.Kind() == reflect.Slice || curType.Kind() == reflect.Array {
|
||||
sb.WriteString("array of ")
|
||||
curType = curType.Elem()
|
||||
}
|
||||
sb.WriteString(curType.Name())
|
||||
if curType != cause.Type {
|
||||
// it was an array
|
||||
sb.WriteString("s")
|
||||
}
|
||||
|
||||
return fmt.Sprintf("invalid value type at '%s': expected %s but got %s", cause.Field, sb.String(), cause.Value)
|
||||
|
||||
default:
|
||||
// there's no specific error type for the strict json mode
|
||||
// (DisallowUnknownFields), so resort to message-matching.
|
||||
if matches := rxJSONUnknownField.FindStringSubmatch(cause.Error()); matches != nil {
|
||||
return fmt.Sprintf("unsupported key provided: %q", matches[1])
|
||||
}
|
||||
return e.Error()
|
||||
}
|
||||
}
|
||||
|
||||
// Cause returns the root error in err's chain.
|
||||
func Cause(err error) error {
|
||||
for {
|
||||
uerr := errors.Unwrap(err)
|
||||
if uerr == nil {
|
||||
return err
|
||||
}
|
||||
err = uerr
|
||||
}
|
||||
}
|
||||
|
||||
// ErrWithRetryAfter is an interface for errors that should set a specific HTTP
|
||||
// Header Retry-After value (see
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After)
|
||||
type ErrWithRetryAfter interface {
|
||||
error
|
||||
// RetryAfter returns the number of seconds to wait before retry.
|
||||
RetryAfter() int
|
||||
}
|
||||
|
||||
// ForeignKeyError is an interface for errors caused by foreign key constraint violations.
|
||||
type ForeignKeyError interface {
|
||||
error
|
||||
IsForeignKey() bool
|
||||
}
|
||||
|
||||
// IsForeignKey returns true if err is a foreign key constraint violation.
|
||||
func IsForeignKey(err error) bool {
|
||||
var fke ForeignKeyError
|
||||
if errors.As(err, &fke) {
|
||||
return fke.IsForeignKey()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Error is a generic error type with a code and message.
|
||||
type Error struct {
|
||||
Code int `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
|
||||
ErrorWithUUID
|
||||
}
|
||||
|
||||
// Error returns the error message.
|
||||
func (e *Error) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// ErrWithIsClientError is an interface for errors that explicitly specify
|
||||
// whether they are client errors or not. By default, errors are treated as
|
||||
// server errors.
|
||||
type ErrWithIsClientError interface {
|
||||
error
|
||||
IsClientError() bool
|
||||
}
|
||||
|
||||
// AuthFailedError is returned when authentication fails.
|
||||
type AuthFailedError struct {
|
||||
// internal is the reason that should only be logged internally
|
||||
internal string
|
||||
|
||||
ErrorWithUUID
|
||||
}
|
||||
|
||||
// NewAuthFailedError creates a new AuthFailedError.
|
||||
func NewAuthFailedError(internal string) *AuthFailedError {
|
||||
return &AuthFailedError{internal: internal}
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e AuthFailedError) Error() string {
|
||||
return "Authentication failed"
|
||||
}
|
||||
|
||||
// Internal implements ErrWithInternal.
|
||||
func (e AuthFailedError) Internal() string {
|
||||
return e.internal
|
||||
}
|
||||
|
||||
// StatusCode implements kithttp.StatusCoder.
|
||||
func (e AuthFailedError) StatusCode() int {
|
||||
return http.StatusUnauthorized
|
||||
}
|
||||
|
||||
// AuthRequiredError is returned when authentication is required.
|
||||
type AuthRequiredError struct {
|
||||
// internal is the reason that should only be logged internally
|
||||
internal string
|
||||
|
||||
ErrorWithUUID
|
||||
}
|
||||
|
||||
// NewAuthRequiredError creates a new AuthRequiredError.
|
||||
func NewAuthRequiredError(internal string) *AuthRequiredError {
|
||||
return &AuthRequiredError{internal: internal}
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e AuthRequiredError) Error() string {
|
||||
return "Authentication required"
|
||||
}
|
||||
|
||||
// Internal implements ErrWithInternal.
|
||||
func (e AuthRequiredError) Internal() string {
|
||||
return e.internal
|
||||
}
|
||||
|
||||
// StatusCode implements kithttp.StatusCoder.
|
||||
func (e AuthRequiredError) StatusCode() int {
|
||||
return http.StatusUnauthorized
|
||||
}
|
||||
|
||||
// AuthHeaderRequiredError is returned when an authorization header is required.
|
||||
type AuthHeaderRequiredError struct {
|
||||
// internal is the reason that should only be logged internally
|
||||
internal string
|
||||
|
||||
ErrorWithUUID
|
||||
}
|
||||
|
||||
// NewAuthHeaderRequiredError creates a new AuthHeaderRequiredError.
|
||||
func NewAuthHeaderRequiredError(internal string) *AuthHeaderRequiredError {
|
||||
return &AuthHeaderRequiredError{
|
||||
internal: internal,
|
||||
}
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e AuthHeaderRequiredError) Error() string {
|
||||
return "Authorization header required"
|
||||
}
|
||||
|
||||
// Internal implements ErrWithInternal.
|
||||
func (e AuthHeaderRequiredError) Internal() string {
|
||||
return e.internal
|
||||
}
|
||||
|
||||
// StatusCode implements kithttp.StatusCoder.
|
||||
func (e AuthHeaderRequiredError) StatusCode() int {
|
||||
return http.StatusUnauthorized
|
||||
}
|
||||
|
||||
// ErrPasswordResetRequired is returned when a password reset is required.
|
||||
var ErrPasswordResetRequired = &passwordResetRequiredError{}
|
||||
|
||||
type passwordResetRequiredError struct {
|
||||
ErrorWithUUID
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e passwordResetRequiredError) Error() string {
|
||||
return "password reset required"
|
||||
}
|
||||
|
||||
// StatusCode implements kithttp.StatusCoder.
|
||||
func (e passwordResetRequiredError) StatusCode() int {
|
||||
return http.StatusUnauthorized
|
||||
}
|
||||
|
||||
// ForbiddenErrorMessage is the error message that should be returned to
|
||||
// clients when an action is forbidden. It is intentionally vague to prevent
|
||||
// disclosing information that a client should not have access to.
|
||||
const ForbiddenErrorMessage = "forbidden"
|
||||
|
||||
// CheckMissing is the error to return when no authorization check was performed
|
||||
// by the service.
|
||||
type CheckMissing struct {
|
||||
response any
|
||||
|
||||
ErrorWithUUID
|
||||
}
|
||||
|
||||
// CheckMissingWithResponse creates a new error indicating the authorization
|
||||
// check was missed, and including the response for further analysis by the error
|
||||
// encoder.
|
||||
func CheckMissingWithResponse(response any) *CheckMissing {
|
||||
return &CheckMissing{response: response}
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e *CheckMissing) Error() string {
|
||||
return ForbiddenErrorMessage
|
||||
}
|
||||
|
||||
// Internal implements the ErrWithInternal interface.
|
||||
func (e *CheckMissing) Internal() string {
|
||||
return "Missing authorization check"
|
||||
}
|
||||
|
||||
// Response returns the response that was generated before the authorization
|
||||
// check was found to be missing.
|
||||
func (e *CheckMissing) Response() any {
|
||||
return e.response
|
||||
}
|
||||
7
server/platform/http/response.go
Normal file
7
server/platform/http/response.go
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
// Package http provides HTTP types for bounded contexts.
|
||||
package http
|
||||
|
||||
// Errorer is implemented by response types that may contain errors.
|
||||
type Errorer interface {
|
||||
Error() error
|
||||
}
|
||||
|
|
@ -108,7 +108,7 @@ func getAppConfigEndpoint(ctx context.Context, request interface{}, svc fleet.Se
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
license, err := svc.License(ctx)
|
||||
lic, err := svc.License(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -183,7 +183,7 @@ func getAppConfigEndpoint(ctx context.Context, request interface{}, svc fleet.Se
|
|||
|
||||
transparencyURL := fleet.DefaultTransparencyURL
|
||||
// Fleet Premium license is required for custom transparency url
|
||||
if license.IsPremium() && appConfig.FleetDesktop.TransparencyURL != "" {
|
||||
if lic.IsPremium() && appConfig.FleetDesktop.TransparencyURL != "" {
|
||||
transparencyURL = appConfig.FleetDesktop.TransparencyURL
|
||||
}
|
||||
fleetDesktop := fleet.FleetDesktopSettings{TransparencyURL: transparencyURL}
|
||||
|
|
@ -218,7 +218,7 @@ func getAppConfigEndpoint(ctx context.Context, request interface{}, svc fleet.Se
|
|||
appConfigResponseFields: appConfigResponseFields{
|
||||
UpdateInterval: updateIntervalConfig,
|
||||
Vulnerabilities: vulnConfig,
|
||||
License: license,
|
||||
License: lic,
|
||||
Logging: loggingConfig,
|
||||
Email: emailConfig,
|
||||
SandboxEnabled: svc.SandboxEnabled(),
|
||||
|
|
@ -274,7 +274,8 @@ func modifyAppConfigEndpoint(ctx context.Context, request interface{}, svc fleet
|
|||
}
|
||||
|
||||
// We do not use svc.License(ctx) to allow roles (like GitOps) write but not read access to AppConfig.
|
||||
license, _ := license.FromContext(ctx)
|
||||
licChecker, _ := license.FromContext(ctx)
|
||||
lic, _ := licChecker.(*fleet.LicenseInfo)
|
||||
|
||||
loggingConfig, err := svc.LoggingConfig(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -283,14 +284,14 @@ func modifyAppConfigEndpoint(ctx context.Context, request interface{}, svc fleet
|
|||
response := appConfigResponse{
|
||||
AppConfig: *appConfig,
|
||||
appConfigResponseFields: appConfigResponseFields{
|
||||
License: license,
|
||||
License: lic,
|
||||
Logging: loggingConfig,
|
||||
},
|
||||
}
|
||||
|
||||
response.Obfuscate()
|
||||
|
||||
if (!license.IsPremium()) || response.FleetDesktop.TransparencyURL == "" {
|
||||
if lic == nil || (!lic.IsPremium()) || response.FleetDesktop.TransparencyURL == "" {
|
||||
response.FleetDesktop.TransparencyURL = fleet.DefaultTransparencyURL
|
||||
}
|
||||
|
||||
|
|
@ -319,7 +320,8 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
|
|||
oldAppConfig := appConfig.Copy()
|
||||
|
||||
// We do not use svc.License(ctx) to allow roles (like GitOps) write but not read access to AppConfig.
|
||||
license, _ := license.FromContext(ctx)
|
||||
licChecker, _ := license.FromContext(ctx)
|
||||
lic, _ := licChecker.(*fleet.LicenseInfo)
|
||||
|
||||
var oldSMTPSettings fleet.SMTPSettings
|
||||
if appConfig.SMTPSettings != nil {
|
||||
|
|
@ -359,7 +361,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
|
|||
|
||||
// default transparency URL is https://fleetdm.com/transparency so you are allowed to apply as long as it's not changing
|
||||
if newAppConfig.FleetDesktop.TransparencyURL != "" && newAppConfig.FleetDesktop.TransparencyURL != fleet.DefaultTransparencyURL {
|
||||
if !license.IsPremium() {
|
||||
if !lic.IsPremium() {
|
||||
invalid.Append("transparency_url", ErrMissingLicense.Error())
|
||||
return nil, ctxerr.Wrap(ctx, invalid)
|
||||
}
|
||||
|
|
@ -370,7 +372,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
|
|||
}
|
||||
|
||||
if newAppConfig.SSOSettings != nil {
|
||||
validateSSOSettings(newAppConfig, appConfig, invalid, license)
|
||||
validateSSOSettings(newAppConfig, appConfig, invalid, lic)
|
||||
if invalid.HasErrors() {
|
||||
return nil, ctxerr.Wrap(ctx, invalid)
|
||||
}
|
||||
|
|
@ -437,7 +439,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
|
|||
appConfig.MDM.MacOSSetup.EnableReleaseDeviceManually = oldAppConfig.MDM.MacOSSetup.EnableReleaseDeviceManually
|
||||
}
|
||||
if appConfig.MDM.MacOSSetup.ManualAgentInstall.Valid && appConfig.MDM.MacOSSetup.ManualAgentInstall.Value {
|
||||
if !license.IsPremium() {
|
||||
if !lic.IsPremium() {
|
||||
invalid.Append("macos_setup.manual_agent_install", ErrMissingLicense.Error())
|
||||
return nil, ctxerr.Wrap(ctx, invalid)
|
||||
}
|
||||
|
|
@ -477,7 +479,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
|
|||
if newAppConfig.AgentOptions != nil {
|
||||
// if there were Agent Options in the new app config, then it replaced the
|
||||
// agent options in the resulting app config, so validate those.
|
||||
if err := fleet.ValidateJSONAgentOptions(ctx, svc.ds, *appConfig.AgentOptions, license.IsPremium(), 0); err != nil {
|
||||
if err := fleet.ValidateJSONAgentOptions(ctx, svc.ds, *appConfig.AgentOptions, lic.IsPremium(), 0); err != nil {
|
||||
err = fleet.SuggestAgentOptionsCorrection(err)
|
||||
err = fleet.NewUserMessageError(err, http.StatusBadRequest)
|
||||
if applyOpts.Force && !applyOpts.DryRun {
|
||||
|
|
@ -490,7 +492,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
|
|||
}
|
||||
|
||||
// If the license is Premium, we should always send usage statisics.
|
||||
if !license.IsAllowDisableTelemetry() {
|
||||
if !lic.IsAllowDisableTelemetry() {
|
||||
appConfig.ServerSettings.EnableAnalytics = true
|
||||
}
|
||||
|
||||
|
|
@ -532,7 +534,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
|
|||
isNonEmpty(newAppConfig.ConditionalAccess.OktaAudienceURI) ||
|
||||
isNonEmpty(newAppConfig.ConditionalAccess.OktaCertificate)
|
||||
|
||||
if oktaFieldsBeingSet && !license.IsPremium() {
|
||||
if oktaFieldsBeingSet && !lic.IsPremium() {
|
||||
invalid.Append("conditional_access", ErrMissingLicense.Error())
|
||||
return nil, ctxerr.Wrap(ctx, invalid)
|
||||
}
|
||||
|
|
@ -626,11 +628,11 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
|
|||
appConfig.Integrations.ConditionalAccessEnabled = newAppConfig.Integrations.ConditionalAccessEnabled
|
||||
}
|
||||
|
||||
if err := svc.validateMDM(ctx, license, &oldAppConfig.MDM, &appConfig.MDM, invalid); err != nil {
|
||||
if err := svc.validateMDM(ctx, lic, &oldAppConfig.MDM, &appConfig.MDM, invalid); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "validating MDM config")
|
||||
}
|
||||
|
||||
abmAssignments, err := svc.validateABMAssignments(ctx, &newAppConfig.MDM, &oldAppConfig.MDM, invalid, license)
|
||||
abmAssignments, err := svc.validateABMAssignments(ctx, &newAppConfig.MDM, &oldAppConfig.MDM, invalid, lic)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "validating ABM token assignments")
|
||||
}
|
||||
|
|
@ -638,7 +640,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
|
|||
var vppAssignments map[uint][]uint
|
||||
vppAssignmentsDefined := newAppConfig.MDM.VolumePurchasingProgram.Set && newAppConfig.MDM.VolumePurchasingProgram.Valid
|
||||
if vppAssignmentsDefined {
|
||||
vppAssignments, err = svc.validateVPPAssignments(ctx, newAppConfig.MDM.VolumePurchasingProgram.Value, invalid, license)
|
||||
vppAssignments, err = svc.validateVPPAssignments(ctx, newAppConfig.MDM.VolumePurchasingProgram.Value, invalid, lic)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "validating VPP token assignments")
|
||||
}
|
||||
|
|
@ -738,7 +740,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
|
|||
|
||||
gitopsModeEnabled, gitopsRepoURL := appConfig.UIGitOpsMode.GitopsModeEnabled, appConfig.UIGitOpsMode.RepositoryURL
|
||||
if gitopsModeEnabled {
|
||||
if !license.IsPremium() {
|
||||
if !lic.IsPremium() {
|
||||
return nil, fleet.NewInvalidArgumentError("UI GitOpsMode: ", ErrMissingLicense.Error())
|
||||
}
|
||||
if gitopsRepoURL == "" {
|
||||
|
|
@ -767,7 +769,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
|
|||
|
||||
}
|
||||
|
||||
if !license.IsPremium() {
|
||||
if !lic.IsPremium() {
|
||||
// reset transparency url to empty for downgraded licenses
|
||||
appConfig.FleetDesktop.TransparencyURL = ""
|
||||
}
|
||||
|
|
@ -909,19 +911,19 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
|
|||
//
|
||||
// Process OS updates config changes for Apple devices.
|
||||
//
|
||||
if err := svc.processAppleOSUpdateSettings(ctx, license, fleet.MacOS,
|
||||
if err := svc.processAppleOSUpdateSettings(ctx, lic, fleet.MacOS,
|
||||
oldAppConfig.MDM.MacOSUpdates,
|
||||
appConfig.MDM.MacOSUpdates,
|
||||
); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "process macOS OS updates config change")
|
||||
}
|
||||
if err := svc.processAppleOSUpdateSettings(ctx, license, fleet.IOS,
|
||||
if err := svc.processAppleOSUpdateSettings(ctx, lic, fleet.IOS,
|
||||
oldAppConfig.MDM.IOSUpdates,
|
||||
appConfig.MDM.IOSUpdates,
|
||||
); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "process iOS OS updates config change")
|
||||
}
|
||||
if err := svc.processAppleOSUpdateSettings(ctx, license, fleet.IPadOS,
|
||||
if err := svc.processAppleOSUpdateSettings(ctx, lic, fleet.IPadOS,
|
||||
oldAppConfig.MDM.IPadOSUpdates,
|
||||
appConfig.MDM.IPadOSUpdates,
|
||||
); err != nil {
|
||||
|
|
@ -1002,7 +1004,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
|
|||
appConfig.MDM.EndUserAuthentication.SSOProviderSettings
|
||||
serverURLChanged := oldAppConfig.ServerSettings.ServerURL != appConfig.ServerSettings.ServerURL
|
||||
appleMDMUrlChanged := oldAppConfig.MDMUrl() != appConfig.MDMUrl()
|
||||
if (mdmEnableEndUserAuthChanged || mdmSSOSettingsChanged || serverURLChanged || appleMDMUrlChanged) && license.IsPremium() {
|
||||
if (mdmEnableEndUserAuthChanged || mdmSSOSettingsChanged || serverURLChanged || appleMDMUrlChanged) && lic.IsPremium() {
|
||||
if err := svc.EnterpriseOverrides.MDMAppleSyncDEPProfiles(ctx); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "sync DEP profiles")
|
||||
}
|
||||
|
|
@ -1100,14 +1102,14 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
|
|||
// processAppleOSUpdateSettings updates the OS updates configuration if the minimum version+deadline are updated.
|
||||
func (svc *Service) processAppleOSUpdateSettings(
|
||||
ctx context.Context,
|
||||
license *fleet.LicenseInfo,
|
||||
lic *fleet.LicenseInfo,
|
||||
appleDevice fleet.AppleDevice,
|
||||
oldOSUpdateSettings fleet.AppleOSUpdateSettings,
|
||||
newOSUpdateSettings fleet.AppleOSUpdateSettings,
|
||||
) error {
|
||||
if oldOSUpdateSettings.MinimumVersion.Value != newOSUpdateSettings.MinimumVersion.Value ||
|
||||
oldOSUpdateSettings.Deadline.Value != newOSUpdateSettings.Deadline.Value {
|
||||
if license.IsPremium() {
|
||||
if lic.IsPremium() {
|
||||
if err := svc.EnterpriseOverrides.MDMAppleEditedAppleOSUpdates(ctx, nil, appleDevice, newOSUpdateSettings); err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "update DDM profile after Apple OS updates change")
|
||||
}
|
||||
|
|
@ -1174,33 +1176,33 @@ func (svc *Service) HasCustomSetupAssistantConfigurationWebURL(ctx context.Conte
|
|||
|
||||
func (svc *Service) validateMDM(
|
||||
ctx context.Context,
|
||||
license *fleet.LicenseInfo,
|
||||
lic *fleet.LicenseInfo,
|
||||
oldMdm *fleet.MDM,
|
||||
mdm *fleet.MDM,
|
||||
invalid *fleet.InvalidArgumentError,
|
||||
) error {
|
||||
if mdm.EnableDiskEncryption.Value && !license.IsPremium() {
|
||||
if mdm.EnableDiskEncryption.Value && !lic.IsPremium() {
|
||||
invalid.Append("macos_settings.enable_disk_encryption", ErrMissingLicense.Error())
|
||||
}
|
||||
if mdm.MacOSSetup.MacOSSetupAssistant.Value != "" && oldMdm.MacOSSetup.MacOSSetupAssistant.Value != mdm.MacOSSetup.MacOSSetupAssistant.Value && !license.IsPremium() {
|
||||
if mdm.MacOSSetup.MacOSSetupAssistant.Value != "" && oldMdm.MacOSSetup.MacOSSetupAssistant.Value != mdm.MacOSSetup.MacOSSetupAssistant.Value && !lic.IsPremium() {
|
||||
invalid.Append("macos_setup.macos_setup_assistant", ErrMissingLicense.Error())
|
||||
}
|
||||
if mdm.MacOSSetup.EnableReleaseDeviceManually.Value && oldMdm.MacOSSetup.EnableReleaseDeviceManually.Value != mdm.MacOSSetup.EnableReleaseDeviceManually.Value && !license.IsPremium() {
|
||||
if mdm.MacOSSetup.EnableReleaseDeviceManually.Value && oldMdm.MacOSSetup.EnableReleaseDeviceManually.Value != mdm.MacOSSetup.EnableReleaseDeviceManually.Value && !lic.IsPremium() {
|
||||
invalid.Append("macos_setup.enable_release_device_manually", ErrMissingLicense.Error())
|
||||
}
|
||||
if mdm.MacOSSetup.BootstrapPackage.Value != "" && oldMdm.MacOSSetup.BootstrapPackage.Value != mdm.MacOSSetup.BootstrapPackage.Value && !license.IsPremium() {
|
||||
if mdm.MacOSSetup.BootstrapPackage.Value != "" && oldMdm.MacOSSetup.BootstrapPackage.Value != mdm.MacOSSetup.BootstrapPackage.Value && !lic.IsPremium() {
|
||||
invalid.Append("macos_setup.bootstrap_package", ErrMissingLicense.Error())
|
||||
}
|
||||
if mdm.MacOSSetup.EnableEndUserAuthentication && oldMdm.MacOSSetup.EnableEndUserAuthentication != mdm.MacOSSetup.EnableEndUserAuthentication && !license.IsPremium() {
|
||||
if mdm.MacOSSetup.EnableEndUserAuthentication && oldMdm.MacOSSetup.EnableEndUserAuthentication != mdm.MacOSSetup.EnableEndUserAuthentication && !lic.IsPremium() {
|
||||
invalid.Append("macos_setup.enable_end_user_authentication", ErrMissingLicense.Error())
|
||||
}
|
||||
if mdm.MacOSSetup.ManualAgentInstall.Valid && oldMdm.MacOSSetup.ManualAgentInstall.Value != mdm.MacOSSetup.ManualAgentInstall.Value && !license.IsPremium() {
|
||||
if mdm.MacOSSetup.ManualAgentInstall.Valid && oldMdm.MacOSSetup.ManualAgentInstall.Value != mdm.MacOSSetup.ManualAgentInstall.Value && !lic.IsPremium() {
|
||||
invalid.Append("macos_setup.manual_agent_install", ErrMissingLicense.Error())
|
||||
}
|
||||
if mdm.WindowsMigrationEnabled && !license.IsPremium() {
|
||||
if mdm.WindowsMigrationEnabled && !lic.IsPremium() {
|
||||
invalid.Append("windows_migration_enabled", ErrMissingLicense.Error())
|
||||
}
|
||||
if mdm.EnableTurnOnWindowsMDMManually && !license.IsPremium() {
|
||||
if mdm.EnableTurnOnWindowsMDMManually && !lic.IsPremium() {
|
||||
invalid.Append("enable_turn_on_windows_mdm_manually", ErrMissingLicense.Error())
|
||||
}
|
||||
|
||||
|
|
@ -1298,7 +1300,7 @@ func (svc *Service) validateMDM(
|
|||
updatingIPadOSVersion || updatingIPadOSDeadline {
|
||||
// TODO: Should we validate MDM configured on here too?
|
||||
|
||||
if !license.IsPremium() {
|
||||
if !lic.IsPremium() {
|
||||
invalid.Append("macos_updates.minimum_version", ErrMissingLicense.Error())
|
||||
return nil
|
||||
}
|
||||
|
|
@ -1318,7 +1320,7 @@ func (svc *Service) validateMDM(
|
|||
if updatingWindowsUpdates {
|
||||
// TODO: Should we validate MDM configured on here too?
|
||||
|
||||
if !license.IsPremium() {
|
||||
if !lic.IsPremium() {
|
||||
invalid.Append("windows_updates.deadline_days", ErrMissingLicense.Error())
|
||||
return nil
|
||||
}
|
||||
|
|
@ -1330,7 +1332,7 @@ func (svc *Service) validateMDM(
|
|||
// EndUserAuthentication
|
||||
// only validate SSO settings if they changed
|
||||
if mdm.EndUserAuthentication.SSOProviderSettings != oldMdm.EndUserAuthentication.SSOProviderSettings {
|
||||
if !license.IsPremium() {
|
||||
if !lic.IsPremium() {
|
||||
invalid.Append("end_user_authentication", ErrMissingLicense.Error())
|
||||
return nil
|
||||
}
|
||||
|
|
@ -1366,7 +1368,7 @@ func (svc *Service) validateMDM(
|
|||
// TODO: Should we validate MDM configured on here too?
|
||||
|
||||
if mdm.MacOSMigration.Enable {
|
||||
if !license.IsPremium() {
|
||||
if !lic.IsPremium() {
|
||||
invalid.Append("macos_migration.enable", ErrMissingLicense.Error())
|
||||
return nil
|
||||
}
|
||||
|
|
@ -1423,7 +1425,7 @@ func (svc *Service) validateABMAssignments(
|
|||
ctx context.Context,
|
||||
mdm, oldMdm *fleet.MDM,
|
||||
invalid *fleet.InvalidArgumentError,
|
||||
license *fleet.LicenseInfo,
|
||||
lic *fleet.LicenseInfo,
|
||||
) ([]*fleet.ABMToken, error) {
|
||||
if mdm.DeprecatedAppleBMDefaultTeam != "" && mdm.AppleBusinessManager.Set && mdm.AppleBusinessManager.Valid {
|
||||
invalid.Append("mdm.apple_bm_default_team", fleet.AppleABMDefaultTeamDeprecatedMessage)
|
||||
|
|
@ -1431,7 +1433,7 @@ func (svc *Service) validateABMAssignments(
|
|||
}
|
||||
|
||||
if name := mdm.DeprecatedAppleBMDefaultTeam; name != "" && name != oldMdm.DeprecatedAppleBMDefaultTeam {
|
||||
if !license.IsPremium() {
|
||||
if !lic.IsPremium() {
|
||||
invalid.Append("mdm.apple_bm_default_team", ErrMissingLicense.Error())
|
||||
return nil, nil
|
||||
}
|
||||
|
|
@ -1464,7 +1466,7 @@ func (svc *Service) validateABMAssignments(
|
|||
}
|
||||
|
||||
if mdm.AppleBusinessManager.Set && len(mdm.AppleBusinessManager.Value) > 0 {
|
||||
if !license.IsPremium() {
|
||||
if !lic.IsPremium() {
|
||||
invalid.Append("mdm.apple_business_manager", ErrMissingLicense.Error())
|
||||
return nil, nil
|
||||
}
|
||||
|
|
@ -1525,14 +1527,14 @@ func (svc *Service) validateVPPAssignments(
|
|||
ctx context.Context,
|
||||
volumePurchasingProgramInfo []fleet.MDMAppleVolumePurchasingProgramInfo,
|
||||
invalid *fleet.InvalidArgumentError,
|
||||
license *fleet.LicenseInfo,
|
||||
lic *fleet.LicenseInfo,
|
||||
) (map[uint][]uint, error) {
|
||||
// Allow clearing VPP assignments in free and premium.
|
||||
if len(volumePurchasingProgramInfo) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if !license.IsPremium() {
|
||||
if !lic.IsPremium() {
|
||||
invalid.Append("mdm.volume_purchasing_program", ErrMissingLicense.Error())
|
||||
return nil, nil
|
||||
}
|
||||
|
|
@ -1622,7 +1624,7 @@ func validateSSOProviderSettings(incoming, existing fleet.SSOProviderSettings, i
|
|||
}
|
||||
}
|
||||
|
||||
func validateSSOSettings(p fleet.AppConfig, existing *fleet.AppConfig, invalid *fleet.InvalidArgumentError, license *fleet.LicenseInfo) {
|
||||
func validateSSOSettings(p fleet.AppConfig, existing *fleet.AppConfig, invalid *fleet.InvalidArgumentError, lic *fleet.LicenseInfo) {
|
||||
if p.SSOSettings != nil && p.SSOSettings.EnableSSO {
|
||||
|
||||
var existingSSOProviderSettings fleet.SSOProviderSettings
|
||||
|
|
@ -1631,7 +1633,7 @@ func validateSSOSettings(p fleet.AppConfig, existing *fleet.AppConfig, invalid *
|
|||
}
|
||||
validateSSOProviderSettings(p.SSOSettings.SSOProviderSettings, existingSSOProviderSettings, invalid)
|
||||
|
||||
if !license.IsPremium() {
|
||||
if !lic.IsPremium() {
|
||||
if p.SSOSettings.EnableJITProvisioning {
|
||||
invalid.Append("enable_jit_provisioning", ErrMissingLicense.Error())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1793,7 +1793,7 @@ func (r mdmAppleEnrollResponse) HijackRender(ctx context.Context, w http.Respons
|
|||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
if err := json.NewEncoder(w).Encode(r.SoftwareUpdateRequired); err != nil {
|
||||
endpoint_utils.EncodeError(ctx, ctxerr.New(ctx, "failed to encode software update required"), w)
|
||||
encodeError(ctx, ctxerr.New(ctx, "failed to encode software update required"), w)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/certserial"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/logging"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
middleware_log "github.com/fleetdm/fleet/v4/server/service/middleware/log"
|
||||
|
|
@ -104,6 +105,10 @@ func authenticatedDevice(svc fleet.Service, logger log.Logger, next endpoint.End
|
|||
}
|
||||
|
||||
ctx = hostctx.NewContext(ctx, host)
|
||||
// Register host as error context provider for ctxerr enrichment
|
||||
hostProvider := &hostctx.HostAttributeProvider{Host: host}
|
||||
ctx = ctxerr.AddErrorContextProvider(ctx, hostProvider)
|
||||
|
||||
instrumentHostLogger(ctx, host.ID)
|
||||
if ac, ok := authz_ctx.FromContext(ctx); ok {
|
||||
ac.SetAuthnMethod(authnMethod)
|
||||
|
|
@ -151,6 +156,10 @@ func authenticatedHost(svc fleet.Service, logger log.Logger, next endpoint.Endpo
|
|||
}
|
||||
|
||||
ctx = hostctx.NewContext(ctx, host)
|
||||
// Register host as error context provider for ctxerr enrichment
|
||||
hostProvider := &hostctx.HostAttributeProvider{Host: host}
|
||||
ctx = ctxerr.AddErrorContextProvider(ctx, hostProvider)
|
||||
|
||||
instrumentHostLogger(ctx, host.ID)
|
||||
if ac, ok := authz_ctx.FromContext(ctx); ok {
|
||||
ac.SetAuthnMethod(authz_ctx.AuthnHostToken)
|
||||
|
|
@ -193,6 +202,10 @@ func authenticatedOrbitHost(
|
|||
}
|
||||
|
||||
ctx = hostctx.NewContext(ctx, host)
|
||||
// Register host as error context provider for ctxerr enrichment
|
||||
hostProvider := &hostctx.HostAttributeProvider{Host: host}
|
||||
ctx = ctxerr.AddErrorContextProvider(ctx, hostProvider)
|
||||
|
||||
instrumentHostLogger(ctx, host.ID)
|
||||
if ac, ok := authz_ctx.FromContext(ctx); ok {
|
||||
ac.SetAuthnMethod(authz_ctx.AuthnOrbitToken)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/mock"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/auth"
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
|
||||
kitlog "github.com/go-kit/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
|
@ -188,7 +187,7 @@ func TestAuthenticatedHost(t *testing.T) {
|
|||
r := &testNodeKeyRequest{NodeKey: tt.nodeKey}
|
||||
_, err := endpoint(ctx, r)
|
||||
if tt.shouldErr {
|
||||
assert.IsType(t, &endpoint_utils.OsqueryError{}, err)
|
||||
assert.IsType(t, &OsqueryError{}, err)
|
||||
} else {
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import (
|
|||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/capabilities"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/auth"
|
||||
eu "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
|
|
@ -21,7 +22,31 @@ import (
|
|||
)
|
||||
|
||||
func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
|
||||
return eu.MakeDecoder(iface, jsonDecode, parseCustomTags, isBodyDecoder, decodeBody)
|
||||
return eu.MakeDecoder(iface, jsonDecode, parseCustomTags, isBodyDecoder, decodeBody, fleetQueryDecoder)
|
||||
}
|
||||
|
||||
// fleetQueryDecoder handles fleet-specific query parameter decoding, such as
|
||||
// converting the order_direction string to the fleet.OrderDirection int type.
|
||||
func fleetQueryDecoder(queryTagName, queryVal string, field reflect.Value) (bool, error) {
|
||||
// Only handle int fields for order_direction
|
||||
if field.Kind() != reflect.Int {
|
||||
return false, nil
|
||||
}
|
||||
switch queryTagName {
|
||||
case "order_direction", "inherited_order_direction":
|
||||
var direction int
|
||||
switch queryVal {
|
||||
case "desc":
|
||||
direction = int(fleet.OrderDescending)
|
||||
case "asc":
|
||||
direction = int(fleet.OrderAscending)
|
||||
default:
|
||||
return false, &fleet.BadRequestError{Message: "unknown order_direction: " + queryVal}
|
||||
}
|
||||
field.SetInt(int64(direction))
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// A value that implements bodyDecoder takes control of decoding the request body.
|
||||
|
|
@ -100,7 +125,7 @@ type endpointer struct {
|
|||
|
||||
func (e *endpointer) CallHandlerFunc(f handlerFunc, ctx context.Context, request interface{},
|
||||
svc interface{},
|
||||
) (fleet.Errorer, error) {
|
||||
) (platform_http.Errorer, error) {
|
||||
return f(ctx, request, svc.(fleet.Service))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -290,7 +290,7 @@ func TestEndpointer(t *testing.T) {
|
|||
auth.SetRequestsContexts(svc),
|
||||
),
|
||||
kithttp.ServerErrorHandler(&endpoint_utils.ErrorHandler{Logger: kitlog.NewNopLogger()}),
|
||||
kithttp.ServerErrorEncoder(endpoint_utils.EncodeError),
|
||||
kithttp.ServerErrorEncoder(fleetErrorEncoder),
|
||||
kithttp.ServerAfter(
|
||||
kithttp.SetContentType("application/json; charset=utf-8"),
|
||||
log.LogRequestEnd(kitlog.NewNopLogger()),
|
||||
|
|
@ -410,7 +410,7 @@ func TestEndpointerCustomMiddleware(t *testing.T) {
|
|||
auth.SetRequestsContexts(svc),
|
||||
),
|
||||
kithttp.ServerErrorHandler(&endpoint_utils.ErrorHandler{Logger: kitlog.NewNopLogger()}),
|
||||
kithttp.ServerErrorEncoder(endpoint_utils.EncodeError),
|
||||
kithttp.ServerErrorEncoder(fleetErrorEncoder),
|
||||
kithttp.ServerAfter(
|
||||
kithttp.SetContentType("application/json; charset=utf-8"),
|
||||
log.LogRequestEnd(kitlog.NewNopLogger()),
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ func MakeHandler(
|
|||
auth.SetRequestsContexts(svc),
|
||||
),
|
||||
kithttp.ServerErrorHandler(&endpoint_utils.ErrorHandler{Logger: logger}),
|
||||
kithttp.ServerErrorEncoder(endpoint_utils.EncodeError),
|
||||
kithttp.ServerErrorEncoder(fleetErrorEncoder),
|
||||
kithttp.ServerAfter(
|
||||
kithttp.SetContentType("application/json; charset=utf-8"),
|
||||
log.LogRequestEnd(logger),
|
||||
|
|
|
|||
|
|
@ -30,7 +30,6 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/mdm/assets"
|
||||
mdmlifecycle "github.com/fleetdm/fleet/v4/server/mdm/lifecycle"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
|
||||
"github.com/fleetdm/fleet/v4/server/worker"
|
||||
"github.com/go-kit/log/level"
|
||||
"github.com/gocarina/gocsv"
|
||||
|
|
@ -2319,7 +2318,7 @@ func (r hostsReportResponse) HijackRender(ctx context.Context, w http.ResponseWr
|
|||
var buf bytes.Buffer
|
||||
if err := gocsv.Marshal(r.Hosts, &buf); err != nil {
|
||||
logging.WithErr(ctx, err)
|
||||
endpoint_utils.EncodeError(ctx, ctxerr.New(ctx, "failed to generate CSV file"), w)
|
||||
encodeError(ctx, ctxerr.New(ctx, "failed to generate CSV file"), w)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -2331,7 +2330,7 @@ func (r hostsReportResponse) HijackRender(ctx context.Context, w http.ResponseWr
|
|||
recs, err := csv.NewReader(&buf).ReadAll()
|
||||
if err != nil {
|
||||
logging.WithErr(ctx, err)
|
||||
endpoint_utils.EncodeError(ctx, ctxerr.New(ctx, "failed to generate CSV file"), w)
|
||||
encodeError(ctx, ctxerr.New(ctx, "failed to generate CSV file"), w)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -2352,7 +2351,7 @@ func (r hostsReportResponse) HijackRender(ctx context.Context, w http.ResponseWr
|
|||
// duplicating the list of columns from the Host's struct tags to a
|
||||
// map and keep this in sync, for what is essentially a programmer
|
||||
// mistake that should be caught and corrected early.
|
||||
endpoint_utils.EncodeError(ctx, &fleet.BadRequestError{Message: fmt.Sprintf("invalid column name: %q", col)}, w)
|
||||
encodeError(ctx, &fleet.BadRequestError{Message: fmt.Sprintf("invalid column name: %q", col)}, w)
|
||||
return
|
||||
}
|
||||
outRows[i] = append(outRows[i], rec[colIx])
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ import (
|
|||
|
||||
"github.com/fleetdm/fleet/v4/ee/server/service/hostidentity/httpsig"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/authz"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/logging"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/token"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
|
|
@ -76,6 +78,10 @@ func AuthenticatedUser(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpo
|
|||
}
|
||||
|
||||
ctx = viewer.NewContext(ctx, *v)
|
||||
// Register viewer as error context provider for ctxerr enrichment
|
||||
ctx = ctxerr.AddErrorContextProvider(ctx, v)
|
||||
// Register viewer as user emailer for logging
|
||||
ctx = logging.WithUserEmailer(ctx, v)
|
||||
if ac, ok := authz.FromContext(ctx); ok {
|
||||
ac.SetAuthnMethod(authz.AuthnUserToken)
|
||||
}
|
||||
|
|
@ -123,6 +129,10 @@ func AuthenticatedUserMiddleware(svc fleet.Service, errHandler errorHandler, nex
|
|||
}
|
||||
|
||||
ctx := viewer.NewContext(r.Context(), *v)
|
||||
// Register viewer as error context provider for ctxerr enrichment
|
||||
ctx = ctxerr.AddErrorContextProvider(ctx, v)
|
||||
// Register viewer as user emailer for logging
|
||||
ctx = logging.WithUserEmailer(ctx, v)
|
||||
if ac, ok := authz.FromContext(r.Context()); ok {
|
||||
ac.SetAuthnMethod(authz.AuthnUserToken)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/logging"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/token"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
|
|
@ -20,6 +21,10 @@ func SetRequestsContexts(svc fleet.Service) kithttp.RequestFunc {
|
|||
v, err := AuthViewer(ctx, string(bearer), svc)
|
||||
if err == nil {
|
||||
ctx = viewer.NewContext(ctx, *v)
|
||||
// Register viewer as error context provider for ctxerr enrichment
|
||||
ctx = ctxerr.AddErrorContextProvider(ctx, v)
|
||||
// Register viewer as user emailer for logging
|
||||
ctx = logging.WithUserEmailer(ctx, v)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,9 +8,8 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/authz"
|
||||
authz_ctx "github.com/fleetdm/fleet/v4/server/contexts/authz"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
)
|
||||
|
||||
|
|
@ -32,13 +31,13 @@ func (m *Middleware) AuthzCheck() endpoint.Middleware {
|
|||
|
||||
// If authentication check failed, return that error (so that we log
|
||||
// appropriately).
|
||||
var authFailedError *fleet.AuthFailedError
|
||||
var authRequiredError *fleet.AuthRequiredError
|
||||
var authHeaderRequiredError *fleet.AuthHeaderRequiredError
|
||||
var authFailedError *platform_http.AuthFailedError
|
||||
var authRequiredError *platform_http.AuthRequiredError
|
||||
var authHeaderRequiredError *platform_http.AuthHeaderRequiredError
|
||||
if errors.As(err, &authFailedError) ||
|
||||
errors.As(err, &authRequiredError) ||
|
||||
errors.As(err, &authHeaderRequiredError) ||
|
||||
errors.Is(err, fleet.ErrPasswordResetRequired) {
|
||||
errors.Is(err, platform_http.ErrPasswordResetRequired) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
@ -52,7 +51,7 @@ func (m *Middleware) AuthzCheck() endpoint.Middleware {
|
|||
// marshal to a generic error and log that the check was missed.
|
||||
if !authzctx.Checked() {
|
||||
// Getting to here means there is an authorization-related bug in our code.
|
||||
return nil, authz.CheckMissingWithResponse(response)
|
||||
return nil, platform_http.CheckMissingWithResponse(response)
|
||||
}
|
||||
|
||||
return response, err
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/authz"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
|
@ -33,7 +33,7 @@ func TestAuthzCheckAuthFailed(t *testing.T) {
|
|||
checker := NewMiddleware()
|
||||
|
||||
check := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, fleet.NewAuthFailedError("failed")
|
||||
return nil, platform_http.NewAuthFailedError("failed")
|
||||
}
|
||||
check = checker.AuthzCheck()(check)
|
||||
|
||||
|
|
@ -48,7 +48,7 @@ func TestAuthzCheckAuthRequired(t *testing.T) {
|
|||
checker := NewMiddleware()
|
||||
|
||||
check := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, fleet.NewAuthRequiredError("required")
|
||||
return nil, platform_http.NewAuthRequiredError("required")
|
||||
}
|
||||
check = checker.AuthzCheck()(check)
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/license"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/logging"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/authzcheck"
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/ratelimit"
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
|
|
@ -86,7 +86,7 @@ func BadRequestErr(publicMsg string, internalErr error) error {
|
|||
if errors.As(internalErr, &opErr) {
|
||||
return fmt.Errorf(publicMsg+", internal: %w", internalErr)
|
||||
}
|
||||
return &fleet.BadRequestError{
|
||||
return &platform_http.BadRequestError{
|
||||
Message: publicMsg,
|
||||
InternalErr: internalErr,
|
||||
}
|
||||
|
|
@ -169,7 +169,11 @@ func DecodeURLTagValue(r *http.Request, field reflect.Value, urlTagValue string,
|
|||
return nil
|
||||
}
|
||||
|
||||
func DecodeQueryTagValue(r *http.Request, fp fieldPair) error {
|
||||
// DomainQueryFieldDecoder decodes a query parameter value into the target field.
|
||||
// It returns true if it handled the field, false if default handling should be used.
|
||||
type DomainQueryFieldDecoder func(queryTagName, queryVal string, field reflect.Value) (handled bool, err error)
|
||||
|
||||
func DecodeQueryTagValue(r *http.Request, fp fieldPair, customDecoder DomainQueryFieldDecoder) error {
|
||||
queryTagValue, ok := fp.Sf.Tag.Lookup("query")
|
||||
|
||||
if ok {
|
||||
|
|
@ -185,7 +189,7 @@ func DecodeQueryTagValue(r *http.Request, fp fieldPair) error {
|
|||
if optional {
|
||||
return nil
|
||||
}
|
||||
return &fleet.BadRequestError{Message: fmt.Sprintf("Param %s is required", queryTagValue)}
|
||||
return &platform_http.BadRequestError{Message: fmt.Sprintf("Param %s is required", queryTagValue)}
|
||||
}
|
||||
field := fp.V
|
||||
if field.Kind() == reflect.Ptr {
|
||||
|
|
@ -193,6 +197,18 @@ func DecodeQueryTagValue(r *http.Request, fp fieldPair) error {
|
|||
field.Set(reflect.New(field.Type().Elem()))
|
||||
field = field.Elem()
|
||||
}
|
||||
|
||||
// Try custom decoder first if provided
|
||||
if customDecoder != nil {
|
||||
handled, err := customDecoder(queryTagValue, queryVal, field)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if handled {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
switch field.Kind() {
|
||||
case reflect.String:
|
||||
field.SetString(queryVal)
|
||||
|
|
@ -211,22 +227,9 @@ func DecodeQueryTagValue(r *http.Request, fp fieldPair) error {
|
|||
case reflect.Bool:
|
||||
field.SetBool(queryVal == "1" || queryVal == "true")
|
||||
case reflect.Int:
|
||||
queryValInt := 0
|
||||
switch queryTagValue {
|
||||
case "order_direction", "inherited_order_direction":
|
||||
switch queryVal {
|
||||
case "desc":
|
||||
queryValInt = int(fleet.OrderDescending)
|
||||
case "asc":
|
||||
queryValInt = int(fleet.OrderAscending)
|
||||
default:
|
||||
return &fleet.BadRequestError{Message: "unknown order_direction: " + queryVal}
|
||||
}
|
||||
default:
|
||||
queryValInt, err = strconv.Atoi(queryVal)
|
||||
if err != nil {
|
||||
return BadRequestErr("parsing int from query", err)
|
||||
}
|
||||
queryValInt, err := strconv.Atoi(queryVal)
|
||||
if err != nil {
|
||||
return BadRequestErr("parsing int from query", err)
|
||||
}
|
||||
field.SetInt(int64(queryValInt))
|
||||
default:
|
||||
|
|
@ -297,17 +300,17 @@ func (h *ErrorHandler) Handle(ctx context.Context, err error) {
|
|||
logger = log.With(logger, "took", time.Since(startTime))
|
||||
}
|
||||
|
||||
var ewi fleet.ErrWithInternal
|
||||
var ewi platform_http.ErrWithInternal
|
||||
if errors.As(err, &ewi) {
|
||||
logger = log.With(logger, "internal", ewi.Internal())
|
||||
}
|
||||
|
||||
var ewlf fleet.ErrWithLogFields
|
||||
var ewlf platform_http.ErrWithLogFields
|
||||
if errors.As(err, &ewlf) {
|
||||
logger = log.With(logger, ewlf.LogFields()...)
|
||||
}
|
||||
|
||||
var uuider fleet.ErrorUUIDer
|
||||
var uuider platform_http.ErrorUUIDer
|
||||
if errors.As(err, &uuider) {
|
||||
logger = log.With(logger, "uuid", uuider.UUID())
|
||||
}
|
||||
|
|
@ -338,17 +341,14 @@ type requestValidator interface {
|
|||
}
|
||||
|
||||
// MakeDecoder creates a decoder for the type for the struct passed on. If the
|
||||
// struct has at least 1 json tag it'll unmarshall the body. If the struct has
|
||||
// a `url` tag with value list_options it'll gather fleet.ListOptions from the
|
||||
// URL (similarly for host_options, carve_options, user_options that derive
|
||||
// from the common list_options). Note that these behaviors do not work for embedded structs.
|
||||
// struct has at least 1 json tag it'll unmarshall the body. Custom `url` tag
|
||||
// values can be handled by providing a parseCustomTags function. Note that
|
||||
// these behaviors do not work for embedded structs.
|
||||
//
|
||||
// Finally, any other `url` tag will be treated as a path variable (of the form
|
||||
// Any other `url` tag will be treated as a path variable (of the form
|
||||
// /path/{name} in the route's path) from the URL path pattern, and it'll be
|
||||
// decoded and set accordingly. Variables can be optional by setting the tag as
|
||||
// follows: `url:"some-id,optional"`.
|
||||
// The "list_options" are optional by default and it'll ignore the optional
|
||||
// portion of the tag.
|
||||
//
|
||||
// If iface implements the RequestDecoder interface, it returns a function that
|
||||
// calls iface.DecodeRequest(ctx, r) - i.e. the value itself fully controls its
|
||||
|
|
@ -357,12 +357,16 @@ type requestValidator interface {
|
|||
// If iface implements the bodyDecoder interface, it calls iface.DecodeBody
|
||||
// after having decoded any non-body fields (such as url and query parameters)
|
||||
// into the struct.
|
||||
//
|
||||
// The customQueryDecoder parameter allows services to inject domain-specific
|
||||
// query parameter decoding logic.
|
||||
func MakeDecoder(
|
||||
iface interface{},
|
||||
jsonUnmarshal func(body io.Reader, req any) error,
|
||||
parseCustomTags func(urlTagValue string, r *http.Request, field reflect.Value) (bool, error),
|
||||
isBodyDecoder func(reflect.Value) bool,
|
||||
decodeBody func(ctx context.Context, r *http.Request, v reflect.Value, body io.Reader) error,
|
||||
customQueryDecoder DomainQueryFieldDecoder,
|
||||
) kithttp.DecodeRequestFunc {
|
||||
if iface == nil {
|
||||
return func(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
|
|
@ -441,16 +445,16 @@ func MakeDecoder(
|
|||
|
||||
_, jsonExpected := fp.Sf.Tag.Lookup("json")
|
||||
if jsonExpected && nilBody {
|
||||
return nil, &fleet.BadRequestError{Message: "Expected JSON Body"}
|
||||
return nil, &platform_http.BadRequestError{Message: "Expected JSON Body"}
|
||||
}
|
||||
|
||||
isContentJson := r.Header.Get("Content-Type") == "application/json"
|
||||
isCrossSite := r.Header.Get("Origin") != "" || r.Header.Get("Referer") != ""
|
||||
if jsonExpected && isCrossSite && !isContentJson {
|
||||
return nil, fleet.NewUserMessageError(errors.New("Expected Content-Type \"application/json\""), http.StatusUnsupportedMediaType)
|
||||
return nil, platform_http.NewUserMessageError(errors.New("Expected Content-Type \"application/json\""), http.StatusUnsupportedMediaType)
|
||||
}
|
||||
|
||||
err = DecodeQueryTagValue(r, fp)
|
||||
err = DecodeQueryTagValue(r, fp, customQueryDecoder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -471,7 +475,7 @@ func MakeDecoder(
|
|||
return nil, err
|
||||
}
|
||||
if val && !fp.V.IsZero() {
|
||||
return nil, &fleet.BadRequestError{Message: fmt.Sprintf(
|
||||
return nil, &platform_http.BadRequestError{Message: fmt.Sprintf(
|
||||
"option %s requires a premium license",
|
||||
fp.Sf.Name,
|
||||
)}
|
||||
|
|
@ -509,17 +513,19 @@ func WriteBrowserSecurityHeaders(w http.ResponseWriter) {
|
|||
}
|
||||
|
||||
type CommonEndpointer[H any] struct {
|
||||
EP Endpointer[H]
|
||||
MakeDecoderFn func(iface interface{}) kithttp.DecodeRequestFunc
|
||||
EncodeFn kithttp.EncodeResponseFunc
|
||||
Opts []kithttp.ServerOption
|
||||
AuthMiddleware endpoint.Middleware
|
||||
Router *mux.Router
|
||||
Versions []string
|
||||
EP Endpointer[H]
|
||||
MakeDecoderFn func(iface any) kithttp.DecodeRequestFunc
|
||||
EncodeFn kithttp.EncodeResponseFunc
|
||||
Opts []kithttp.ServerOption
|
||||
Router *mux.Router
|
||||
Versions []string
|
||||
|
||||
// CustomMiddleware are middlewares that run before AuthMiddleware.
|
||||
// AuthMiddleware is a pre-built authentication middleware.
|
||||
AuthMiddleware endpoint.Middleware
|
||||
|
||||
// CustomMiddleware are middlewares that run before authentication.
|
||||
CustomMiddleware []endpoint.Middleware
|
||||
// CustomMiddlewareAfterAuth are middlewares that run after AuthMiddleware.
|
||||
// CustomMiddlewareAfterAuth are middlewares that run after authentication.
|
||||
CustomMiddlewareAfterAuth []endpoint.Middleware
|
||||
|
||||
startingAtVersion string
|
||||
|
|
@ -529,8 +535,8 @@ type CommonEndpointer[H any] struct {
|
|||
}
|
||||
|
||||
type Endpointer[H any] interface {
|
||||
CallHandlerFunc(f H, ctx context.Context, request interface{}, svc interface{}) (fleet.Errorer, error)
|
||||
Service() interface{}
|
||||
CallHandlerFunc(f H, ctx context.Context, request any, svc any) (platform_http.Errorer, error)
|
||||
Service() any
|
||||
}
|
||||
|
||||
func (e *CommonEndpointer[H]) POST(path string, f H, v interface{}) {
|
||||
|
|
@ -730,6 +736,7 @@ func EncodeCommonResponse(
|
|||
w http.ResponseWriter,
|
||||
response interface{},
|
||||
jsonMarshal func(w http.ResponseWriter, response interface{}) error,
|
||||
domainErrorEncoder DomainErrorEncoder,
|
||||
) error {
|
||||
if cs, ok := response.(cookieSetter); ok {
|
||||
cs.SetCookies(ctx, w)
|
||||
|
|
@ -747,8 +754,8 @@ func EncodeCommonResponse(
|
|||
return err
|
||||
}
|
||||
|
||||
if e, ok := response.(fleet.Errorer); ok && e.Error() != nil {
|
||||
EncodeError(ctx, e.Error(), w)
|
||||
if e, ok := response.(platform_http.Errorer); ok && e.Error() != nil {
|
||||
EncodeError(ctx, e.Error(), w, domainErrorEncoder)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import (
|
|||
"testing"
|
||||
|
||||
authz_ctx "github.com/fleetdm/fleet/v4/server/contexts/authz"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
kithttp "github.com/go-kit/kit/transport/http"
|
||||
"github.com/gorilla/mux"
|
||||
|
|
@ -16,7 +16,7 @@ import (
|
|||
)
|
||||
|
||||
// testHandlerFunc is a handler function type used for testing.
|
||||
type testHandlerFunc func(ctx context.Context, request any, svc any) (fleet.Errorer, error)
|
||||
type testHandlerFunc func(ctx context.Context, request any) (platform_http.Errorer, error)
|
||||
|
||||
func TestCustomMiddlewareAfterAuth(t *testing.T) {
|
||||
var (
|
||||
|
|
@ -82,7 +82,7 @@ func TestCustomMiddlewareAfterAuth(t *testing.T) {
|
|||
},
|
||||
Router: r,
|
||||
}
|
||||
ce.handleEndpoint("/", func(ctx context.Context, request interface{}, svc any) (fleet.Errorer, error) {
|
||||
ce.handleEndpoint("/", func(ctx context.Context, request any) (platform_http.Errorer, error) {
|
||||
fmt.Printf("handler\n")
|
||||
return nopResponse{}, nil
|
||||
}, nil, "GET")
|
||||
|
|
@ -116,10 +116,10 @@ func (n nopResponse) Error() error {
|
|||
|
||||
type nopEP struct{}
|
||||
|
||||
func (n nopEP) CallHandlerFunc(_ testHandlerFunc, _ context.Context, _ any, _ any) (fleet.Errorer, error) {
|
||||
return nopResponse{}, nil
|
||||
func (n nopEP) CallHandlerFunc(f testHandlerFunc, ctx context.Context, request any, svc any) (platform_http.Errorer, error) {
|
||||
return f(ctx, request)
|
||||
}
|
||||
|
||||
func (n nopEP) Service() interface{} {
|
||||
func (n nopEP) Service() any {
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,13 +4,12 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
|
||||
kithttp "github.com/go-kit/kit/transport/http"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
|
@ -18,6 +17,11 @@ import (
|
|||
// ErrBadRoute is used for mux errors
|
||||
var ErrBadRoute = errors.New("bad route")
|
||||
|
||||
// DomainErrorEncoder handles domain-specific error encoding.
|
||||
// It returns true if it handled the error, false if default handling should be used.
|
||||
// The encoder should write the appropriate status code and response body.
|
||||
type DomainErrorEncoder func(ctx context.Context, err error, w http.ResponseWriter, enc *json.Encoder, jsonErr *JsonError) (handled bool)
|
||||
|
||||
type JsonError struct {
|
||||
Message string `json:"message"`
|
||||
Code int `json:"code,omitempty"`
|
||||
|
|
@ -67,8 +71,10 @@ type conflictErrorInterface interface {
|
|||
IsConflict() bool
|
||||
}
|
||||
|
||||
// EncodeError encodes error and status header to the client
|
||||
func EncodeError(ctx context.Context, err error, w http.ResponseWriter) {
|
||||
// EncodeError encodes error and status header to the client.
|
||||
// The domainEncoder parameter allows services to inject domain-specific error
|
||||
// handling. If nil, only generic error handling is performed.
|
||||
func EncodeError(ctx context.Context, err error, w http.ResponseWriter, domainEncoder DomainErrorEncoder) {
|
||||
ctxerr.Handle(ctx, err)
|
||||
origErr := err
|
||||
|
||||
|
|
@ -78,7 +84,7 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) {
|
|||
err = ctxerr.Cause(err)
|
||||
|
||||
var uuid string
|
||||
if uuidErr, ok := err.(fleet.ErrorUUIDer); ok {
|
||||
if uuidErr, ok := err.(platform_http.ErrorUUIDer); ok {
|
||||
uuid = uuidErr.UUID()
|
||||
}
|
||||
|
||||
|
|
@ -86,6 +92,13 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) {
|
|||
UUID: uuid,
|
||||
}
|
||||
|
||||
// Try domain-specific error encoder first
|
||||
if domainEncoder != nil {
|
||||
if handled := domainEncoder(ctx, err, w, enc, &jsonErr); handled {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
switch e := err.(type) {
|
||||
case validationErrorInterface:
|
||||
if statusErr, ok := e.(interface{ Status() int }); ok {
|
||||
|
|
@ -99,36 +112,6 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) {
|
|||
jsonErr.Message = "Permission Denied"
|
||||
jsonErr.Errors = e.PermissionError()
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
case MailError:
|
||||
jsonErr.Message = "Mail Error"
|
||||
jsonErr.Errors = e.MailError()
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
case *OsqueryError:
|
||||
// osquery expects to receive the node_invalid key when a TLS
|
||||
// request provides an invalid node_key for authentication. It
|
||||
// doesn't use the error message provided, but we provide this
|
||||
// for debugging purposes (and perhaps osquery will use this
|
||||
// error message in the future).
|
||||
|
||||
errMap := map[string]interface{}{
|
||||
"error": e.Error(),
|
||||
"uuid": uuid,
|
||||
}
|
||||
if e.NodeInvalid() { //nolint:gocritic // ignore ifElseChain
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
errMap["node_invalid"] = true
|
||||
} else if e.Status() != 0 {
|
||||
w.WriteHeader(e.Status())
|
||||
} else {
|
||||
// TODO: osqueryError is not always the result of an internal error on
|
||||
// our side, it is also used to represent a client error (invalid data,
|
||||
// e.g. malformed json, carve too large, etc., so 4xx), are we returning
|
||||
// a 500 because of some osquery-specific requirement?
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
enc.Encode(errMap) //nolint:errcheck
|
||||
return
|
||||
case NotFoundErrorInterface:
|
||||
jsonErr.Message = "Resource Not Found"
|
||||
jsonErr.Errors = baseError(e.Error())
|
||||
|
|
@ -153,7 +136,7 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) {
|
|||
statusCode = http.StatusConflict
|
||||
}
|
||||
w.WriteHeader(statusCode)
|
||||
case *fleet.Error:
|
||||
case *platform_http.Error:
|
||||
jsonErr.Message = e.Error()
|
||||
jsonErr.Code = e.Code
|
||||
w.WriteHeader(http.StatusUnprocessableEntity)
|
||||
|
|
@ -168,7 +151,7 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) {
|
|||
enc.Encode(jsonErr) //nolint:errcheck
|
||||
return
|
||||
}
|
||||
if fleet.IsForeignKey(err) {
|
||||
if platform_http.IsForeignKey(err) {
|
||||
jsonErr.Message = "Validation Failed"
|
||||
jsonErr.Errors = baseError(err.Error())
|
||||
w.WriteHeader(http.StatusUnprocessableEntity)
|
||||
|
|
@ -186,14 +169,14 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) {
|
|||
|
||||
// See header documentation
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After)
|
||||
var ewra fleet.ErrWithRetryAfter
|
||||
var ewra platform_http.ErrWithRetryAfter
|
||||
if errors.As(err, &ewra) {
|
||||
w.Header().Add("Retry-After", strconv.Itoa(ewra.RetryAfter()))
|
||||
}
|
||||
|
||||
msg := err.Error()
|
||||
reason := err.Error()
|
||||
var ume *fleet.UserMessageError
|
||||
var ume *platform_http.UserMessageError
|
||||
if errors.As(err, &ume) {
|
||||
if text := http.StatusText(status); text != "" {
|
||||
msg = text
|
||||
|
|
@ -208,53 +191,3 @@ func EncodeError(ctx context.Context, err error, w http.ResponseWriter) {
|
|||
|
||||
enc.Encode(jsonErr) //nolint:errcheck
|
||||
}
|
||||
|
||||
// MailError is set when an error performing mail operations
|
||||
type MailError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e MailError) Error() string {
|
||||
return fmt.Sprintf("a mail error occurred: %s", e.Message)
|
||||
}
|
||||
|
||||
func (e MailError) MailError() []map[string]string {
|
||||
return []map[string]string{
|
||||
{
|
||||
"name": "base",
|
||||
"reason": e.Message,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// OsqueryError is the error returned to osquery agents.
|
||||
type OsqueryError struct {
|
||||
message string
|
||||
nodeInvalid bool
|
||||
StatusCode int
|
||||
fleet.ErrorWithUUID
|
||||
}
|
||||
|
||||
var _ fleet.ErrorUUIDer = (*OsqueryError)(nil)
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e *OsqueryError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
// NodeInvalid returns whether the error returned to osquery
|
||||
// should contain the node_invalid property.
|
||||
func (e *OsqueryError) NodeInvalid() bool {
|
||||
return e.nodeInvalid
|
||||
}
|
||||
|
||||
func (e *OsqueryError) Status() int {
|
||||
return e.StatusCode
|
||||
}
|
||||
|
||||
func NewOsqueryError(message string, nodeInvalid bool) *OsqueryError {
|
||||
return &OsqueryError{
|
||||
message: message,
|
||||
nodeInvalid: nodeInvalid,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import (
|
|||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
|
@ -25,7 +25,7 @@ type newAndExciting struct{}
|
|||
func (newAndExciting) Error() string { return "" }
|
||||
|
||||
type notFoundError struct {
|
||||
fleet.ErrorWithUUID
|
||||
platform_http.ErrorWithUUID
|
||||
}
|
||||
|
||||
func (e *notFoundError) Error() string {
|
||||
|
|
@ -36,6 +36,32 @@ func (e *notFoundError) IsNotFound() bool {
|
|||
return true
|
||||
}
|
||||
|
||||
// validationError is a test implementation of validationErrorInterface.
|
||||
type validationError struct {
|
||||
errors []map[string]string
|
||||
}
|
||||
|
||||
func (e validationError) Error() string {
|
||||
return "validation failed"
|
||||
}
|
||||
|
||||
func (e validationError) Invalid() []map[string]string {
|
||||
return e.errors
|
||||
}
|
||||
|
||||
// permissionError is a test implementation of permissionErrorInterface.
|
||||
type permissionError struct {
|
||||
message string
|
||||
}
|
||||
|
||||
func (e permissionError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
func (e permissionError) PermissionError() []map[string]string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestHandlesErrorsCode(t *testing.T) {
|
||||
errorTests := []struct {
|
||||
name string
|
||||
|
|
@ -44,12 +70,12 @@ func TestHandlesErrorsCode(t *testing.T) {
|
|||
}{
|
||||
{
|
||||
"validation",
|
||||
fleet.NewInvalidArgumentError("a", "b"),
|
||||
validationError{errors: []map[string]string{{"name": "a", "reason": "b"}}},
|
||||
http.StatusUnprocessableEntity,
|
||||
},
|
||||
{
|
||||
"permission",
|
||||
fleet.NewPermissionError("a"),
|
||||
permissionError{message: "a"},
|
||||
http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
|
|
@ -57,21 +83,6 @@ func TestHandlesErrorsCode(t *testing.T) {
|
|||
foreignKeyError{},
|
||||
http.StatusUnprocessableEntity,
|
||||
},
|
||||
{
|
||||
"mail error",
|
||||
MailError{},
|
||||
http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
"osquery error - invalid node",
|
||||
&OsqueryError{nodeInvalid: true},
|
||||
http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
"osquery error - valid node",
|
||||
&OsqueryError{},
|
||||
http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
"data not found",
|
||||
¬FoundError{},
|
||||
|
|
@ -84,7 +95,7 @@ func TestHandlesErrorsCode(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"status coder",
|
||||
fleet.NewAuthFailedError(""),
|
||||
platform_http.NewAuthFailedError(""),
|
||||
http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
|
|
@ -97,7 +108,7 @@ func TestHandlesErrorsCode(t *testing.T) {
|
|||
for _, tt := range errorTests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
EncodeError(context.Background(), tt.err, recorder)
|
||||
EncodeError(context.Background(), tt.err, recorder, nil)
|
||||
assert.Equal(t, recorder.Code, tt.code)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ import (
|
|||
microsoft_mdm "github.com/fleetdm/fleet/v4/server/mdm/microsoft"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
"github.com/fleetdm/fleet/v4/server/service/contract"
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
|
||||
"github.com/fleetdm/fleet/v4/server/worker"
|
||||
"github.com/go-kit/log/level"
|
||||
)
|
||||
|
|
@ -65,7 +64,7 @@ func (r EnrollOrbitResponse) HijackRender(ctx context.Context, w http.ResponseWr
|
|||
enc.SetIndent("", " ")
|
||||
|
||||
if err := enc.Encode(r); err != nil {
|
||||
endpoint_utils.EncodeError(ctx, newOsqueryError(fmt.Sprintf("orbit enroll failed: %s", err)), w)
|
||||
encodeError(ctx, newOsqueryError(fmt.Sprintf("orbit enroll failed: %s", err)), w)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,6 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/pubsub"
|
||||
"github.com/fleetdm/fleet/v4/server/service/conditional_access_microsoft_proxy"
|
||||
"github.com/fleetdm/fleet/v4/server/service/contract"
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
|
||||
"github.com/fleetdm/fleet/v4/server/service/osquery_utils"
|
||||
kithttp "github.com/go-kit/kit/transport/http"
|
||||
"github.com/go-kit/log"
|
||||
|
|
@ -34,12 +33,12 @@ import (
|
|||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
func newOsqueryErrorWithInvalidNode(msg string) *endpoint_utils.OsqueryError {
|
||||
return endpoint_utils.NewOsqueryError(msg, true)
|
||||
func newOsqueryErrorWithInvalidNode(msg string) *OsqueryError {
|
||||
return NewOsqueryError(msg, true)
|
||||
}
|
||||
|
||||
func newOsqueryError(msg string) *endpoint_utils.OsqueryError {
|
||||
return endpoint_utils.NewOsqueryError(msg, false)
|
||||
func newOsqueryError(msg string) *OsqueryError {
|
||||
return NewOsqueryError(msg, false)
|
||||
}
|
||||
|
||||
func (svc *Service) AuthenticateHost(ctx context.Context, nodeKey string) (*fleet.Host, bool, error) {
|
||||
|
|
@ -138,7 +137,7 @@ func (svc *Service) EnrollOsquery(ctx context.Context, enrollSecret, hostIdentif
|
|||
if !canEnroll {
|
||||
deviceCount := "unknown"
|
||||
if lic, _ := license.FromContext(ctx); lic != nil {
|
||||
deviceCount = strconv.Itoa(lic.DeviceCount)
|
||||
deviceCount = strconv.Itoa(lic.GetDeviceCount())
|
||||
}
|
||||
return "", newOsqueryErrorWithInvalidNode(fmt.Sprintf("enroll host failed: maximum number of hosts reached: %s", deviceCount))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
"github.com/fleetdm/fleet/v4/server/pubsub"
|
||||
"github.com/fleetdm/fleet/v4/server/service/async"
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
|
||||
"github.com/fleetdm/fleet/v4/server/service/osquery_utils"
|
||||
"github.com/fleetdm/fleet/v4/server/service/redis_policy_set"
|
||||
"github.com/go-kit/log"
|
||||
|
|
@ -992,7 +991,7 @@ func TestSubmitResultLogsFail(t *testing.T) {
|
|||
// Expect an error when unable to write to logging destination.
|
||||
err = svc.SubmitResultLogs(ctx, results)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, http.StatusRequestEntityTooLarge, err.(*endpoint_utils.OsqueryError).Status())
|
||||
assert.Equal(t, http.StatusRequestEntityTooLarge, err.(*OsqueryError).Status())
|
||||
}
|
||||
|
||||
func TestGetQueryNameAndTeamIDFromResult(t *testing.T) {
|
||||
|
|
@ -2908,7 +2907,7 @@ func TestAuthenticationErrors(t *testing.T) {
|
|||
|
||||
_, _, err := svc.AuthenticateHost(ctx, "")
|
||||
require.Error(t, err)
|
||||
require.True(t, err.(*endpoint_utils.OsqueryError).NodeInvalid())
|
||||
require.True(t, err.(*OsqueryError).NodeInvalid())
|
||||
|
||||
ms.LoadHostByNodeKeyFunc = func(ctx context.Context, nodeKey string) (*fleet.Host, error) {
|
||||
return &fleet.Host{ID: 1, HasHostIdentityCert: ptr.Bool(false)}, nil
|
||||
|
|
@ -2929,7 +2928,7 @@ func TestAuthenticationErrors(t *testing.T) {
|
|||
|
||||
_, _, err = svc.AuthenticateHost(ctx, "foo")
|
||||
require.Error(t, err)
|
||||
require.True(t, err.(*endpoint_utils.OsqueryError).NodeInvalid())
|
||||
require.True(t, err.(*OsqueryError).NodeInvalid())
|
||||
|
||||
// return other error
|
||||
ms.LoadHostByNodeKeyFunc = func(ctx context.Context, nodeKey string) (*fleet.Host, error) {
|
||||
|
|
@ -2938,7 +2937,7 @@ func TestAuthenticationErrors(t *testing.T) {
|
|||
|
||||
_, _, err = svc.AuthenticateHost(ctx, "foo")
|
||||
require.NotNil(t, err)
|
||||
require.False(t, err.(*endpoint_utils.OsqueryError).NodeInvalid())
|
||||
require.False(t, err.(*OsqueryError).NodeInvalid())
|
||||
}
|
||||
|
||||
func TestGetHostIdentifier(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/mail"
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
|
||||
)
|
||||
|
||||
func (svc *Service) NewAppConfig(ctx context.Context, p fleet.AppConfig) (*fleet.AppConfig, error) {
|
||||
|
|
@ -65,7 +64,7 @@ func (svc *Service) sendTestEmail(ctx context.Context, config *fleet.AppConfig)
|
|||
}
|
||||
|
||||
if err := mail.Test(svc.mailService, testMail); err != nil {
|
||||
return endpoint_utils.MailError{Message: err.Error()}
|
||||
return MailError{Message: err.Error()}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -83,12 +82,14 @@ func (svc *Service) License(ctx context.Context) (*fleet.LicenseInfo, error) {
|
|||
}
|
||||
}
|
||||
|
||||
lic, _ := license.FromContext(ctx)
|
||||
licChecker, _ := license.FromContext(ctx)
|
||||
// Type assert to get the concrete type for modification and return
|
||||
lic, _ := licChecker.(*fleet.LicenseInfo)
|
||||
|
||||
// Currently we use the presence of Microsoft Compliance Partner settings
|
||||
// (only configured in cloud instances) to determine if a Fleet instance
|
||||
// is a cloud managed instance.
|
||||
if svc.config.MicrosoftCompliancePartner.IsSet() {
|
||||
if lic != nil && svc.config.MicrosoftCompliancePartner.IsSet() {
|
||||
lic.ManagedCloud = true
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -31,14 +31,15 @@ func (svc *Service) CreateInitialUser(ctx context.Context, p fleet.UserPayload)
|
|||
}
|
||||
|
||||
func (svc *Service) NewUser(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) {
|
||||
license, _ := license.FromContext(ctx)
|
||||
if license == nil {
|
||||
licChecker, _ := license.FromContext(ctx)
|
||||
lic, _ := licChecker.(*fleet.LicenseInfo)
|
||||
if lic == nil {
|
||||
return nil, ctxerr.New(ctx, "license not found")
|
||||
}
|
||||
if err := fleet.ValidateUserRoles(true, p, *license); err != nil {
|
||||
if err := fleet.ValidateUserRoles(true, p, *lic); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "validate role")
|
||||
}
|
||||
if !license.IsPremium() {
|
||||
if !lic.IsPremium() {
|
||||
p.MFAEnabled = ptr.Bool(false)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import (
|
|||
)
|
||||
|
||||
func encodeResponse(ctx context.Context, w http.ResponseWriter, response interface{}) error {
|
||||
return endpoint_utils.EncodeCommonResponse(ctx, w, response, jsonMarshal)
|
||||
return endpoint_utils.EncodeCommonResponse(ctx, w, response, jsonMarshal, FleetErrorEncoder)
|
||||
}
|
||||
|
||||
func jsonMarshal(w http.ResponseWriter, response interface{}) error {
|
||||
|
|
|
|||
110
server/service/transport_error.go
Normal file
110
server/service/transport_error.go
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
|
||||
"github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils"
|
||||
)
|
||||
|
||||
// FleetErrorEncoder handles fleet-specific error encoding for MailError
|
||||
// and OsqueryError.
|
||||
func FleetErrorEncoder(ctx context.Context, err error, w http.ResponseWriter, enc *json.Encoder, jsonErr *endpoint_utils.JsonError) bool {
|
||||
switch e := err.(type) {
|
||||
case MailError:
|
||||
jsonErr.Message = "Mail Error"
|
||||
jsonErr.Errors = []map[string]string{
|
||||
{
|
||||
"name": "base",
|
||||
"reason": e.Message,
|
||||
},
|
||||
}
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
enc.Encode(jsonErr) //nolint:errcheck
|
||||
return true
|
||||
|
||||
case *OsqueryError:
|
||||
// osquery expects to receive the node_invalid key when a TLS
|
||||
// request provides an invalid node_key for authentication. It
|
||||
// doesn't use the error message provided, but we provide this
|
||||
// for debugging purposes (and perhaps osquery will use this
|
||||
// error message in the future).
|
||||
|
||||
errMap := map[string]any{
|
||||
"error": e.Error(),
|
||||
"uuid": jsonErr.UUID,
|
||||
}
|
||||
if e.NodeInvalid() { //nolint:gocritic // ignore ifElseChain
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
errMap["node_invalid"] = true
|
||||
} else if e.Status() != 0 {
|
||||
w.WriteHeader(e.Status())
|
||||
} else {
|
||||
// TODO: osqueryError is not always the result of an internal error on
|
||||
// our side, it is also used to represent a client error (invalid data,
|
||||
// e.g. malformed json, carve too large, etc., so 4xx), are we returning
|
||||
// a 500 because of some osquery-specific requirement?
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
enc.Encode(errMap) //nolint:errcheck
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// MailError is set when an error performing mail operations
|
||||
type MailError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e MailError) Error() string {
|
||||
return fmt.Sprintf("a mail error occurred: %s", e.Message)
|
||||
}
|
||||
|
||||
// OsqueryError is the error returned to osquery agents.
|
||||
type OsqueryError struct {
|
||||
message string
|
||||
nodeInvalid bool
|
||||
StatusCode int
|
||||
platform_http.ErrorWithUUID
|
||||
}
|
||||
|
||||
var _ platform_http.ErrorUUIDer = (*OsqueryError)(nil)
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e *OsqueryError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
// NodeInvalid returns whether the error returned to osquery
|
||||
// should contain the node_invalid property.
|
||||
func (e *OsqueryError) NodeInvalid() bool {
|
||||
return e.nodeInvalid
|
||||
}
|
||||
|
||||
func (e *OsqueryError) Status() int {
|
||||
return e.StatusCode
|
||||
}
|
||||
|
||||
func NewOsqueryError(message string, nodeInvalid bool) *OsqueryError {
|
||||
return &OsqueryError{
|
||||
message: message,
|
||||
nodeInvalid: nodeInvalid,
|
||||
}
|
||||
}
|
||||
|
||||
// encodeError is a convenience function that calls endpoint_utils.EncodeError
|
||||
// with the FleetErrorEncoder. Use this for direct error encoding in handlers.
|
||||
func encodeError(ctx context.Context, err error, w http.ResponseWriter) {
|
||||
endpoint_utils.EncodeError(ctx, err, w, FleetErrorEncoder)
|
||||
}
|
||||
|
||||
// fleetErrorEncoder is an adapter that wraps endpoint_utils.EncodeError with
|
||||
// FleetErrorEncoder for use as a kithttp.ErrorEncoder.
|
||||
func fleetErrorEncoder(ctx context.Context, err error, w http.ResponseWriter) {
|
||||
endpoint_utils.EncodeError(ctx, err, w, FleetErrorEncoder)
|
||||
}
|
||||
|
|
@ -433,11 +433,12 @@ func (svc *Service) ModifyUser(ctx context.Context, userID uint, p fleet.UserPay
|
|||
if err := svc.authz.Authorize(ctx, user, fleet.ActionWriteRole); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
license, _ := license.FromContext(ctx)
|
||||
if license == nil {
|
||||
licChecker, _ := license.FromContext(ctx)
|
||||
lic, _ := licChecker.(*fleet.LicenseInfo)
|
||||
if lic == nil {
|
||||
return nil, ctxerr.New(ctx, "license not found")
|
||||
}
|
||||
if err := fleet.ValidateUserRoles(false, p, *license); err != nil {
|
||||
if err := fleet.ValidateUserRoles(false, p, *lic); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "validate role")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue