From 92bc1c650ec6470ad7984a045e7c883d5b314caa Mon Sep 17 00:00:00 2001 From: Victor Lyuboslavsky <2685025+getvictor@users.noreply.github.com> Date: Thu, 26 Feb 2026 17:39:10 -0600 Subject: [PATCH] Move PostJSONWithTimeout to platform/http package and activity cleanup (#40561) **Related issue:** Resolves #38536 - Moved PostJSONWithTimeout to platform/http - Created platform/errors package with only types needed by ctxerr. This way, ctxerr did not need to import fleethttp. - Made activity bounded context use PostJSONWithTimeout directly - Removed some activity types from legacy code that were no longer needed # Checklist for submitter - [ ] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`. - Changes file `38536-new-activity-bc` already present, and this is just cleanup from that work. ## Testing - [x] Added/updated automated tests - [x] QA'd all new/changed functionality manually ## Summary by CodeRabbit ## Release Notes * **Refactor** * Reorganized error handling utilities for improved clarity and decoupling. * Consolidated HTTP utilities to centralize JSON posting functionality with timeout support. * Simplified activity service initialization by removing unused internal parameters. * Cleaned up test utilities and removed webhook-related test scaffolding. --- cmd/fleet/serve.go | 3 - server/activity/arch_test.go | 3 +- server/activity/bootstrap/bootstrap.go | 3 +- server/activity/bootstrap/testing.go | 48 ---- .../activity/internal/service/new_activity.go | 4 +- .../internal/service/new_activity_test.go | 185 ++++++++++++- server/activity/internal/service/service.go | 19 +- .../activity/internal/service/service_test.go | 6 +- server/activity/internal/tests/suite_test.go | 4 +- server/activity/providers.go | 3 - server/contexts/ctxerr/ctxerr.go | 6 +- server/contexts/logging/logging.go | 3 +- server/datastore/mysql/in_house_apps_test.go | 2 +- server/datastore/mysql/testing_utils.go | 5 +- server/fleet/activities.go | 46 ---- server/fleet/datastore.go | 9 +- server/fleet/errors.go | 7 +- server/platform/arch_test.go | 16 +- server/platform/errors/errors.go | 41 +++ server/platform/http/errors.go | 39 +-- server/platform/http/post_json.go | 68 +++++ server/platform/mysql/errors.go | 4 +- server/service/activities_test.go | 260 ------------------ server/service/testing_utils.go | 4 - server/utils.go | 68 +---- 25 files changed, 346 insertions(+), 510 deletions(-) delete mode 100644 server/activity/bootstrap/testing.go create mode 100644 server/platform/errors/errors.go create mode 100644 server/platform/http/post_json.go diff --git a/cmd/fleet/serve.go b/cmd/fleet/serve.go index bff66bca15..f430a447ff 100644 --- a/cmd/fleet/serve.go +++ b/cmd/fleet/serve.go @@ -1813,9 +1813,6 @@ func createActivityBoundedContext(svc fleet.Service, dbConns *common_mysql.DBCon dbConns, activityAuthorizer, activityACLAdapter, - func(ctx context.Context, url string, payload any) error { - return server.PostJSONWithTimeout(ctx, url, payload, logger) - }, logger, ) // Create auth middleware for activity bounded context diff --git a/server/activity/arch_test.go b/server/activity/arch_test.go index 4ab2c14040..1117b8de86 100644 --- a/server/activity/arch_test.go +++ b/server/activity/arch_test.go @@ -28,6 +28,7 @@ var ( m + "/server/contexts/logging", m + "/server/contexts/authz", m + "/server/contexts/publicip", + m + "/pkg/fleethttp", } ) @@ -67,7 +68,7 @@ func TestActivityPackageDependencies(t *testing.T) { m + "/server/activity/api", m + "/server/activity/internal/types", m + "/server/activity/internal/testutils", - m + "/server/platform/http", + m + "/server/platform/errors", m + "/server/platform/logging", m + "/server/platform/mysql", m + "/server/platform/mysql/testing_utils", diff --git a/server/activity/bootstrap/bootstrap.go b/server/activity/bootstrap/bootstrap.go index 084a426d14..176a2da7ab 100644 --- a/server/activity/bootstrap/bootstrap.go +++ b/server/activity/bootstrap/bootstrap.go @@ -20,11 +20,10 @@ func New( dbConns *platform_mysql.DBConnections, authorizer platform_authz.Authorizer, providers activity.DataProviders, - webhookSendFn activity.WebhookSendFunc, logger *slog.Logger, ) (api.Service, func(authMiddleware endpoint.Middleware) eu.HandlerRoutesFunc) { ds := mysql.NewDatastore(dbConns, logger) - svc := service.NewService(authorizer, ds, providers, webhookSendFn, logger) + svc := service.NewService(authorizer, ds, providers, logger) routesFn := func(authMiddleware endpoint.Middleware) eu.HandlerRoutesFunc { return service.GetRoutes(svc, authMiddleware) diff --git a/server/activity/bootstrap/testing.go b/server/activity/bootstrap/testing.go deleted file mode 100644 index 82333f7063..0000000000 --- a/server/activity/bootstrap/testing.go +++ /dev/null @@ -1,48 +0,0 @@ -package bootstrap - -import ( - "context" - "log/slog" - "time" - - "github.com/fleetdm/fleet/v4/server/activity" - "github.com/fleetdm/fleet/v4/server/activity/api" - "github.com/fleetdm/fleet/v4/server/activity/internal/service" - "github.com/fleetdm/fleet/v4/server/activity/internal/types" - platform_authz "github.com/fleetdm/fleet/v4/server/platform/authz" -) - -// NewForUnitTests creates an activity NewActivityService backed by a noop store (no database required). -func NewForUnitTests( - providers activity.DataProviders, - webhookSendFn activity.WebhookSendFunc, - logger *slog.Logger, -) api.NewActivityService { - return service.NewService(&noopAuthorizer{}, &noopStore{}, providers, webhookSendFn, logger) -} - -// noopAuthorizer allows all actions (appropriate for unit tests). -type noopAuthorizer struct{} - -func (a *noopAuthorizer) Authorize(_ context.Context, _ platform_authz.AuthzTyper, _ platform_authz.Action) error { - return nil -} - -// noopStore is a datastore that does nothing (appropriate for unit tests that only need webhook behavior). -type noopStore struct{} - -func (s *noopStore) ListActivities(_ context.Context, _ types.ListOptions) ([]*api.Activity, *api.PaginationMetadata, error) { - return nil, nil, nil -} - -func (s *noopStore) ListHostPastActivities(_ context.Context, _ uint, _ types.ListOptions) ([]*api.Activity, *api.PaginationMetadata, error) { - return nil, nil, nil -} - -func (s *noopStore) MarkActivitiesAsStreamed(_ context.Context, _ []uint) error { - return nil -} - -func (s *noopStore) NewActivity(_ context.Context, _ *api.User, _ api.ActivityDetails, _ []byte, _ time.Time) error { - return nil -} diff --git a/server/activity/internal/service/new_activity.go b/server/activity/internal/service/new_activity.go index ac8edc18fd..e01ccaa7df 100644 --- a/server/activity/internal/service/new_activity.go +++ b/server/activity/internal/service/new_activity.go @@ -118,7 +118,7 @@ func (s *Service) fireActivityWebhook( retryStrategy.MaxElapsedTime = 30 * time.Minute err := backoff.Retry( func() error { - if err := s.webhookSendFn( + if err := platformhttp.PostJSONWithTimeout( spanCtx, webhookURL, &webhookPayload{ Timestamp: timestamp, ActorFullName: userName, @@ -126,7 +126,7 @@ func (s *Service) fireActivityWebhook( ActorEmail: userEmail, Type: activityType, Details: (*json.RawMessage)(&detailsBytes), - }, + }, s.logger, ); err != nil { var statusCoder kithttp.StatusCoder if errors.As(err, &statusCoder) && statusCoder.StatusCode() == http.StatusTooManyRequests { diff --git a/server/activity/internal/service/new_activity_test.go b/server/activity/internal/service/new_activity_test.go index ba2c48f308..765ae790b3 100644 --- a/server/activity/internal/service/new_activity_test.go +++ b/server/activity/internal/service/new_activity_test.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "log/slog" + "net/http" + "net/http/httptest" "testing" "time" @@ -78,8 +80,7 @@ func (a activatorActivity) ActivateNextUpcomingActivityArgs() (uint, string) { } func newTestService(ds types.Datastore, providers activity.DataProviders) *Service { - noopWebhookSend := func(_ context.Context, _ string, _ any) error { return nil } - return NewService(&mockAuthorizer{}, ds, providers, noopWebhookSend, slog.New(slog.DiscardHandler)) + return NewService(&mockAuthorizer{}, ds, providers, slog.New(slog.DiscardHandler)) } func TestNewActivityStoresWithWebhookContextKey(t *testing.T) { @@ -213,3 +214,183 @@ func TestNewActivityNilUser(t *testing.T) { require.True(t, ds.newActivityCalled) assert.Nil(t, ds.lastUser) } + +// newTestServiceWithWebhook creates a service configured for webhook delivery tests. +func newTestServiceWithWebhook(ds types.Datastore, providers activity.DataProviders) *Service { + return NewService(&mockAuthorizer{}, ds, providers, slog.New(slog.DiscardHandler)) +} + +func TestNewActivityWebhook(t *testing.T) { + t.Parallel() + + webhookChannel := make(chan struct{}, 1) + var webhookBody webhookPayload + fail429 := false + + startMockServer := func(t *testing.T) string { + srv := httptest.NewServer( + http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + webhookBody = webhookPayload{} + if r.Method != "POST" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + switch r.URL.Path { + case "/ok": + err := json.NewDecoder(r.Body).Decode(&webhookBody) + if err != nil { + t.Log(err) + w.WriteHeader(http.StatusBadRequest) + } + case "/error": + webhookBody.Type = "error" + w.WriteHeader(http.StatusTeapot) + case "/429": + fail429 = !fail429 + if fail429 { + w.WriteHeader(http.StatusTooManyRequests) + return + } + err := json.NewDecoder(r.Body).Decode(&webhookBody) + if err != nil { + t.Log(err) + w.WriteHeader(http.StatusBadRequest) + } + default: + w.WriteHeader(http.StatusNotFound) + return + } + webhookChannel <- struct{}{} + }, + ), + ) + t.Cleanup(srv.Close) + return srv.URL + } + + mockURL := startMockServer(t) + testURL := mockURL + + ds := &newActivityMockDatastore{} + providers := &newActivityMockProviders{ + mockDataProviders: mockDataProviders{ + mockUserProvider: &mockUserProvider{}, + mockHostProvider: &mockHostProvider{}, + webhookConfig: &activity.ActivitiesWebhookSettings{ + Enable: true, + DestinationURL: testURL, + }, + }, + } + + svc := newTestServiceWithWebhook(ds, providers) + + tests := []struct { + name string + user *api.User + url string + doError bool + }{ + { + name: "nil user", + url: mockURL + "/ok", + user: nil, + }, + { + name: "real user", + url: mockURL + "/ok", + user: &api.User{ + ID: 1, + Name: "testUser", + Email: "testUser@example.com", + }, + }, + { + name: "error", + url: mockURL + "/error", + doError: true, + }, + { + name: "429", + url: mockURL + "/429", + user: &api.User{ + ID: 2, + Name: "testUserRetry", + Email: "testUserRetry@example.com", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ds.newActivityCalled = false + providers.webhookConfig.DestinationURL = tt.url + startTime := time.Now() + act := simpleActivity{Name: tt.name} + err := svc.NewActivity(t.Context(), tt.user, act) + require.NoError(t, err) + select { + case <-time.After(3 * time.Second): + t.Error("timeout waiting for webhook") + case <-webhookChannel: + if tt.doError { + assert.Equal(t, "error", webhookBody.Type) + } else { + endTime := time.Now() + assert.False( + t, webhookBody.Timestamp.Before(startTime), "timestamp %s is before start time %s", + webhookBody.Timestamp.String(), startTime.String(), + ) + assert.False(t, webhookBody.Timestamp.After(endTime)) + if tt.user == nil { + assert.Nil(t, webhookBody.ActorFullName) + assert.Nil(t, webhookBody.ActorID) + assert.Nil(t, webhookBody.ActorEmail) + } else { + require.NotNil(t, webhookBody.ActorFullName) + assert.Equal(t, tt.user.Name, *webhookBody.ActorFullName) + require.NotNil(t, webhookBody.ActorID) + assert.Equal(t, tt.user.ID, *webhookBody.ActorID) + require.NotNil(t, webhookBody.ActorEmail) + assert.Equal(t, tt.user.Email, *webhookBody.ActorEmail) + } + assert.Equal(t, act.ActivityName(), webhookBody.Type) + var details map[string]string + require.NoError(t, json.Unmarshal(*webhookBody.Details, &details)) + assert.Len(t, details, 1) + assert.Equal(t, tt.name, details["name"]) + } + } + require.True(t, ds.newActivityCalled) + }) + } +} + +func TestNewActivityWebhookDisabled(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("webhook server should not be called when webhook is disabled") + }), + ) + t.Cleanup(srv.Close) + + ds := &newActivityMockDatastore{} + providers := &newActivityMockProviders{ + mockDataProviders: mockDataProviders{ + mockUserProvider: &mockUserProvider{}, + mockHostProvider: &mockHostProvider{}, + webhookConfig: &activity.ActivitiesWebhookSettings{ + Enable: false, + DestinationURL: srv.URL, + }, + }, + } + + svc := newTestServiceWithWebhook(ds, providers) + err := svc.NewActivity(t.Context(), &api.User{ID: 1}, simpleActivity{Name: "no webhook"}) + require.NoError(t, err) + require.True(t, ds.newActivityCalled) +} diff --git a/server/activity/internal/service/service.go b/server/activity/internal/service/service.go index 5a2d2860a5..0ed4d3546d 100644 --- a/server/activity/internal/service/service.go +++ b/server/activity/internal/service/service.go @@ -46,11 +46,10 @@ func applyListOptionsDefaults(opt *api.ListOptions, defaultOrderKey string) { // Service is the activity bounded context service implementation. type Service struct { - authz platform_authz.Authorizer - store types.Datastore - providers activity.DataProviders - webhookSendFn activity.WebhookSendFunc - logger *slog.Logger + authz platform_authz.Authorizer + store types.Datastore + providers activity.DataProviders + logger *slog.Logger } // NewService creates a new activity service. @@ -58,15 +57,13 @@ func NewService( authz platform_authz.Authorizer, store types.Datastore, providers activity.DataProviders, - webhookSendFn activity.WebhookSendFunc, logger *slog.Logger, ) *Service { return &Service{ - authz: authz, - store: store, - providers: providers, - webhookSendFn: webhookSendFn, - logger: logger, + authz: authz, + store: store, + providers: providers, + logger: logger, } } diff --git a/server/activity/internal/service/service_test.go b/server/activity/internal/service/service_test.go index b195fde7a2..97781f6225 100644 --- a/server/activity/internal/service/service_test.go +++ b/server/activity/internal/service/service_test.go @@ -130,8 +130,7 @@ func setupTest(opts ...func(*testSetup)) *testSetup { for _, opt := range opts { opt(ts) } - noopWebhookSend := func(_ context.Context, _ string, _ any) error { return nil } - ts.svc = NewService(ts.authz, ts.ds, ts.providers, noopWebhookSend, slog.New(slog.DiscardHandler)) + ts.svc = NewService(ts.authz, ts.ds, ts.providers, slog.New(slog.DiscardHandler)) return ts } @@ -532,9 +531,8 @@ func newTestActivity(id uint, actorName string, actorID uint, actType, details s func TestStreamActivities(t *testing.T) { t.Parallel() - noopWebhookSend := func(_ context.Context, _ string, _ any) error { return nil } newStreamingService := func(ds *mockStreamingDatastore) *Service { - return NewService(&mockAuthorizer{}, ds, &mockDataProviders{mockUserProvider: &mockUserProvider{}, mockHostProvider: &mockHostProvider{}}, noopWebhookSend, slog.New(slog.DiscardHandler)) + return NewService(&mockAuthorizer{}, ds, &mockDataProviders{mockUserProvider: &mockUserProvider{}, mockHostProvider: &mockHostProvider{}}, slog.New(slog.DiscardHandler)) } t.Run("basic streaming", func(t *testing.T) { diff --git a/server/activity/internal/tests/suite_test.go b/server/activity/internal/tests/suite_test.go index 6a4fae592b..832cc89044 100644 --- a/server/activity/internal/tests/suite_test.go +++ b/server/activity/internal/tests/suite_test.go @@ -1,7 +1,6 @@ package tests import ( - "context" "encoding/json" "net/http" "net/http/httptest" @@ -38,8 +37,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSuite { providers := newMockDataProviders() // Create service - noopWebhookSend := func(_ context.Context, _ string, _ any) error { return nil } - svc := service.NewService(authorizer, ds, providers, noopWebhookSend, tdb.Logger) + svc := service.NewService(authorizer, ds, providers, tdb.Logger) // Create router with routes router := mux.NewRouter() diff --git a/server/activity/providers.go b/server/activity/providers.go index 123f966027..9378264665 100644 --- a/server/activity/providers.go +++ b/server/activity/providers.go @@ -10,9 +10,6 @@ type UpcomingActivityActivator interface { ActivateNextUpcomingActivity(ctx context.Context, hostID uint, fromCompletedExecID string) error } -// WebhookSendFunc is the function signature for sending a JSON payload to a URL. -type WebhookSendFunc = func(ctx context.Context, url string, payload any) error - // DataProviders combines all external dependency interfaces for the activity // bounded context. The ACL adapter implements this single interface. type DataProviders interface { diff --git a/server/contexts/ctxerr/ctxerr.go b/server/contexts/ctxerr/ctxerr.go index c56dc175b8..9d7c69630d 100644 --- a/server/contexts/ctxerr/ctxerr.go +++ b/server/contexts/ctxerr/ctxerr.go @@ -21,7 +21,7 @@ import ( "strings" "time" - platform_http "github.com/fleetdm/fleet/v4/server/platform/http" + platform_errors "github.com/fleetdm/fleet/v4/server/platform/errors" "github.com/getsentry/sentry-go" "go.elastic.co/apm/v2" "go.opentelemetry.io/otel/attribute" @@ -202,7 +202,7 @@ func Wrapf(ctx context.Context, cause error, format string, args ...interface{}) // Cause returns the root error in err's chain. func Cause(err error) error { - return platform_http.Cause(err) + return platform_errors.Cause(err) } // FleetCause is similar to Cause, but returns the root-most @@ -407,7 +407,7 @@ func isClientError(err error) bool { // Check for explicit client error interface. All 4xx error types // (not found, already exists, conflict, validation, permission, // bad request, foreign key, etc.) should implement this interface. - var clientErr platform_http.ErrWithIsClientError + var clientErr platform_errors.ErrWithIsClientError if errors.As(err, &clientErr) { return clientErr.IsClientError() } diff --git a/server/contexts/logging/logging.go b/server/contexts/logging/logging.go index 6ea855cc4f..588e999a0d 100644 --- a/server/contexts/logging/logging.go +++ b/server/contexts/logging/logging.go @@ -8,6 +8,7 @@ import ( "sync" "time" + platform_errors "github.com/fleetdm/fleet/v4/server/platform/errors" platform_http "github.com/fleetdm/fleet/v4/server/platform/http" kithttp "github.com/go-kit/kit/transport/http" ) @@ -233,7 +234,7 @@ func (l *LoggingContext) setLevelError() bool { } if len(l.Errs) == 1 { - var ew platform_http.ErrWithIsClientError + var ew platform_errors.ErrWithIsClientError if errors.As(l.Errs[0], &ew) && ew.IsClientError() { return false } diff --git a/server/datastore/mysql/in_house_apps_test.go b/server/datastore/mysql/in_house_apps_test.go index 9c69f269a9..df17d1cbc7 100644 --- a/server/datastore/mysql/in_house_apps_test.go +++ b/server/datastore/mysql/in_house_apps_test.go @@ -1736,7 +1736,7 @@ func testSoftwareTitleDisplayNameInHouse(t *testing.T, ds *Datastore) { } func testInHouseAppsCancelledOnUnenroll(t *testing.T, ds *Datastore) { - ctx := context.WithValue(context.Background(), fleet.ActivityWebhookContextKey, true) + ctx := t.Context() test.CreateInsertGlobalVPPToken(t, ds) user := test.NewUser(t, ds, "Alice", "alice@example.com", true) diff --git a/server/datastore/mysql/testing_utils.go b/server/datastore/mysql/testing_utils.go index 19aecbe732..2ccb4a3460 100644 --- a/server/datastore/mysql/testing_utils.go +++ b/server/datastore/mysql/testing_utils.go @@ -26,7 +26,6 @@ import ( "time" "github.com/WatchBeam/clock" - "github.com/fleetdm/fleet/v4/server" "github.com/fleetdm/fleet/v4/server/acl/activityacl" activity_api "github.com/fleetdm/fleet/v4/server/activity/api" activity_bootstrap "github.com/fleetdm/fleet/v4/server/activity/bootstrap" @@ -1033,9 +1032,7 @@ func NewTestActivityService(t testing.TB, ds *Datastore) activity_api.Service { // Create service via bootstrap (the public API for creating the bounded context) discardLogger := slog.New(slog.DiscardHandler) - svc, _ := activity_bootstrap.New(dbConns, &testingAuthorizer{}, aclAdapter, func(ctx context.Context, url string, payload any) error { - return server.PostJSONWithTimeout(ctx, url, payload, discardLogger) - }, discardLogger) + svc, _ := activity_bootstrap.New(dbConns, &testingAuthorizer{}, aclAdapter, discardLogger) return svc } diff --git a/server/fleet/activities.go b/server/fleet/activities.go index 0902c31e1c..8939c2f451 100644 --- a/server/fleet/activities.go +++ b/server/fleet/activities.go @@ -6,23 +6,9 @@ import ( "time" ) -type ContextKey string - // NewActivityFunc is the function signature for creating a new activity. type NewActivityFunc func(ctx context.Context, user *User, activity ActivityDetails) error -type ActivityWebhookPayload struct { - Timestamp time.Time `json:"timestamp"` - ActorFullName *string `json:"actor_full_name"` - ActorID *uint `json:"actor_id"` - ActorEmail *string `json:"actor_email"` - Type string `json:"type"` - Details *json.RawMessage `json:"details"` -} - -// ActivityWebhookContextKey is the context key to indicate that the activity webhook has been processed before saving the activity. -const ActivityWebhookContextKey = ContextKey("ActivityWebhook") - type Activity struct { CreateTimestamp @@ -263,38 +249,6 @@ type ActivityDetails interface { ActivityName() string } -// ActivityHosts is the optional additional interface that can be implemented -// by activities that are related to hosts. -type ActivityHosts interface { - ActivityDetails - HostIDs() []uint -} - -// AutomatableActivity is the optional additional interface that can be implemented -// by activities that are sometimes the result of automation ("Fleet did X"), starting with -// install/script run policy automations -type AutomatableActivity interface { - ActivityDetails - WasFromAutomation() bool -} - -// ActivityHostOnly is the optional additional interface that can be implemented by activities that -// we want to exclude from the global activity feed, and only show on the Hosts details page -type ActivityHostOnly interface { - ActivityDetails - HostOnly() bool -} - -// ActivityActivator is the optional additional interface that can be implemented by activities that -// may require activating the next upcoming activity when it gets created. Most upcoming activities get -// activated when the result of the previous one completes (such as scripts and software installs), but -// some can only be activated when the activity gets recorded (such as VPP and in-house apps). -type ActivityActivator interface { - ActivityDetails - MustActivateNextUpcomingActivity() bool - ActivateNextUpcomingActivityArgs() (hostID uint, cmdUUID string) -} - type ActivityTypeEnabledActivityAutomations struct { WebhookUrl string `json:"webhook_url"` } diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index 12b5ad8a5f..844f86b98b 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -18,6 +18,7 @@ import ( "github.com/fleetdm/fleet/v4/server/mdm/nanodep/godep" "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm" "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/storage" + platform_errors "github.com/fleetdm/fleet/v4/server/platform/errors" platform_http "github.com/fleetdm/fleet/v4/server/platform/http" "github.com/jmoiron/sqlx" ) @@ -2916,11 +2917,11 @@ const ( // same in both (the other is currently NotFound), and ideally we'd just have // one of those interfaces. -// NotFoundError is an alias for platform_http.NotFoundError. -type NotFoundError = platform_http.NotFoundError +// NotFoundError is an alias for platform_errors.NotFoundError. +type NotFoundError = platform_errors.NotFoundError -// IsNotFound is an alias for platform_http.IsNotFound. -var IsNotFound = platform_http.IsNotFound +// IsNotFound is an alias for platform_errors.IsNotFound. +var IsNotFound = platform_errors.IsNotFound // AlreadyExistsError is an alias for platform_http.AlreadyExistsError. type AlreadyExistsError = platform_http.AlreadyExistsError diff --git a/server/fleet/errors.go b/server/fleet/errors.go index ed1e6689c1..dbf040e53c 100644 --- a/server/fleet/errors.go +++ b/server/fleet/errors.go @@ -6,6 +6,7 @@ import ( "net/http" "time" + platform_errors "github.com/fleetdm/fleet/v4/server/platform/errors" platform_http "github.com/fleetdm/fleet/v4/server/platform/http" "github.com/rs/zerolog" ) @@ -51,8 +52,8 @@ type ErrWithLogFields = platform_http.ErrWithLogFields // ErrWithRetryAfter is an alias for platform_http.ErrWithRetryAfter. type ErrWithRetryAfter = platform_http.ErrWithRetryAfter -// ErrWithIsClientError is an alias for platform_http.ErrWithIsClientError. -type ErrWithIsClientError = platform_http.ErrWithIsClientError +// ErrWithIsClientError is an alias for platform_errors.ErrWithIsClientError. +type ErrWithIsClientError = platform_errors.ErrWithIsClientError type invalidArgWithStatusError struct { InvalidArgumentError @@ -400,7 +401,7 @@ func GetJSONUnknownField(err error) *string { } // Cause returns the root error in err's chain. -var Cause = platform_http.Cause +var Cause = platform_errors.Cause // FleetdError is an error that can be reported by any of the fleetd // components. diff --git a/server/platform/arch_test.go b/server/platform/arch_test.go index bc7df8ac6d..06caf750c4 100644 --- a/server/platform/arch_test.go +++ b/server/platform/arch_test.go @@ -21,6 +21,7 @@ func TestPlatformPackageDependencies(t *testing.T) { // Platform packages can depend on each other m+"/server/platform...", // Infra packages + m+"/pkg/fleethttp", m+"/server/contexts/authz", m+"/server/contexts/ctxerr", m+"/server/contexts/license", @@ -41,7 +42,8 @@ func TestEndpointerPackageDependencies(t *testing.T) { IgnoreDeps( // Platform packages m+"/server/platform...", - // Other infra packages + // Infra packages + m+"/pkg/fleethttp", m+"/server/contexts/authz", m+"/server/contexts/ctxerr", m+"/server/contexts/license", @@ -56,7 +58,11 @@ func TestHTTPPackageDependencies(t *testing.T) { archtest.NewPackageTest(t, m+"/server/platform/http"). OnlyInclude(regexp.MustCompile(`^github\.com/fleetdm/`)). WithTests(). - ShouldNotDependOn(m + "/..."). + ShouldNotDependOn(m+"/..."). + IgnoreDeps( + m+"/pkg/fleethttp", + m+"/server/platform/errors", + ). Check() } @@ -65,10 +71,13 @@ func TestAuthzCheckPackageDependencies(t *testing.T) { archtest.NewPackageTest(t, m+"/server/platform/middleware/authzcheck"). OnlyInclude(regexp.MustCompile(`^github\.com/fleetdm/`)). WithTests(). + ShouldNotDependOn(m+"/..."). IgnoreDeps( // Platform packages + m+"/server/platform/errors", m+"/server/platform/http", // Other infra packages + m+"/pkg/fleethttp", m+"/server/contexts/authz", ). Check() @@ -82,6 +91,7 @@ func TestRatelimitPackageDependencies(t *testing.T) { ShouldNotDependOn(m+"/..."). IgnoreDeps( // Platform packages + m+"/server/platform/errors", m+"/server/platform/http", // Other infra packages m+"/server/contexts/authz", @@ -103,7 +113,7 @@ func TestMysqlPackageDependencies(t *testing.T) { // Ignore our own packages m+"/server/platform/mysql...", // Other infra packages - m+"/server/platform/http", + m+"/server/platform/errors", m+"/server/platform/logging", m+"/server/contexts/ctxerr", ). diff --git a/server/platform/errors/errors.go b/server/platform/errors/errors.go new file mode 100644 index 0000000000..6dba89d53d --- /dev/null +++ b/server/platform/errors/errors.go @@ -0,0 +1,41 @@ +// Package errors provides error classification primitives used across the +// codebase. These are intentionally kept free of HTTP or other transport +// dependencies so that low-level packages (datastores, context helpers) can +// use them without pulling in higher-level concerns. +package errors + +import "errors" + +// Cause returns the root error in err's chain. +func Cause(err error) error { + for { + uerr := errors.Unwrap(err) + if uerr == nil { + return err + } + err = uerr + } +} + +// ErrWithIsClientError is an interface for errors that explicitly specify +// whether they are client errors or not. By default, errors are treated as +// server errors. +type ErrWithIsClientError interface { + error + IsClientError() bool +} + +// NotFoundError is an interface for errors when a resource cannot be found. +type NotFoundError interface { + error + IsNotFound() bool +} + +// IsNotFound returns true if err is a not-found error. +func IsNotFound(err error) bool { + var nfe NotFoundError + if errors.As(err, &nfe) { + return nfe.IsNotFound() + } + return false +} diff --git a/server/platform/http/errors.go b/server/platform/http/errors.go index 01702c66f1..e3be307c0d 100644 --- a/server/platform/http/errors.go +++ b/server/platform/http/errors.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/docker/go-units" + platform_errors "github.com/fleetdm/fleet/v4/server/platform/errors" "github.com/google/uuid" ) @@ -174,7 +175,7 @@ func IsJSONUnknownFieldError(err error) bool { // GetJSONUnknownField returns the unknown field name from a JSON unknown field error. func GetJSONUnknownField(err error) *string { - errCause := Cause(err) + errCause := platform_errors.Cause(err) if IsJSONUnknownFieldError(errCause) { substr := rxJSONUnknownField.FindStringSubmatch(errCause.Error()) return &substr[1] @@ -186,7 +187,7 @@ func GetJSONUnknownField(err error) *string { // root cause is one of the supported types, otherwise it returns the error // message. func (e UserMessageError) UserMessage() string { - cause := Cause(e.error) + cause := platform_errors.Cause(e.error) switch cause := cause.(type) { case *json.UnmarshalTypeError: var sb strings.Builder @@ -213,17 +214,6 @@ func (e UserMessageError) UserMessage() string { } } -// Cause returns the root error in err's chain. -func Cause(err error) error { - for { - uerr := errors.Unwrap(err) - if uerr == nil { - return err - } - err = uerr - } -} - // ErrWithRetryAfter is an interface for errors that should set a specific HTTP // Header Retry-After value (see // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After) @@ -248,21 +238,6 @@ func IsForeignKey(err error) bool { return false } -// NotFoundError is an interface for errors when a resource cannot be found. -type NotFoundError interface { - error - IsNotFound() bool -} - -// IsNotFound returns true if err is a not-found error. -func IsNotFound(err error) bool { - var nfe NotFoundError - if errors.As(err, &nfe) { - return nfe.IsNotFound() - } - return false -} - // AlreadyExistsError is an interface for errors when a resource already exists. type AlreadyExistsError interface { error @@ -282,14 +257,6 @@ func (e *Error) Error() string { return e.Message } -// ErrWithIsClientError is an interface for errors that explicitly specify -// whether they are client errors or not. By default, errors are treated as -// server errors. -type ErrWithIsClientError interface { - error - IsClientError() bool -} - // AuthFailedError is returned when authentication fails. type AuthFailedError struct { // internal is the reason that should only be logged internally diff --git a/server/platform/http/post_json.go b/server/platform/http/post_json.go new file mode 100644 index 0000000000..595e8ce506 --- /dev/null +++ b/server/platform/http/post_json.go @@ -0,0 +1,68 @@ +package http + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "time" + + "github.com/fleetdm/fleet/v4/pkg/fleethttp" +) + +// errWithStatus is an error with a particular status code. +type errWithStatus struct { + err string + statusCode int +} + +// Error implements the error interface. +func (e *errWithStatus) Error() string { + return e.err +} + +// StatusCode implements the StatusCoder interface for returning custom status codes. +func (e *errWithStatus) StatusCode() int { + return e.statusCode +} + +// PostJSONWithTimeout marshals v as JSON and POSTs it to the given URL with a 30-second timeout. +func PostJSONWithTimeout(ctx context.Context, url string, v any, logger *slog.Logger) error { + jsonBytes, err := json.Marshal(v) + if err != nil { + return err + } + + client := fleethttp.NewClient(fleethttp.WithTimeout(30 * time.Second)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonBytes)) + if err != nil { + return err + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to POST to %s: %s, request-size=%d", MaskSecretURLParams(url), MaskURLError(err), len(jsonBytes)) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode > 299 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 513)) + bodyStr := string(body) + if len(bodyStr) > 512 { + bodyStr = bodyStr[:512] + } + logger.DebugContext(ctx, "non-success response from POST", + "url", MaskSecretURLParams(url), + "status_code", resp.StatusCode, + "body", bodyStr, + ) + return &errWithStatus{err: fmt.Sprintf("error posting to %s", MaskSecretURLParams(url)), statusCode: resp.StatusCode} + } + + return nil +} diff --git a/server/platform/mysql/errors.go b/server/platform/mysql/errors.go index 2f6e16466a..a91effd856 100644 --- a/server/platform/mysql/errors.go +++ b/server/platform/mysql/errors.go @@ -6,7 +6,7 @@ import ( "fmt" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" - platform_http "github.com/fleetdm/fleet/v4/server/platform/http" + platform_errors "github.com/fleetdm/fleet/v4/server/platform/errors" "github.com/go-sql-driver/mysql" ) @@ -18,7 +18,7 @@ type NotFoundError struct { } // Compile-time interface check. -var _ platform_http.NotFoundError = &NotFoundError{} +var _ platform_errors.NotFoundError = &NotFoundError{} func NotFound(kind string) *NotFoundError { return &NotFoundError{ diff --git a/server/service/activities_test.go b/server/service/activities_test.go index d9b4344ee9..7e13eeaa78 100644 --- a/server/service/activities_test.go +++ b/server/service/activities_test.go @@ -2,62 +2,16 @@ package service import ( "context" - "encoding/json" - "log/slog" - "net/http" - "net/http/httptest" "testing" - "time" - fleetserver "github.com/fleetdm/fleet/v4/server" - "github.com/fleetdm/fleet/v4/server/activity" activity_api "github.com/fleetdm/fleet/v4/server/activity/api" - activity_bootstrap "github.com/fleetdm/fleet/v4/server/activity/bootstrap" "github.com/fleetdm/fleet/v4/server/contexts/viewer" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/mock" "github.com/fleetdm/fleet/v4/server/ptr" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// webhookTestProviders implements activity.DataProviders for webhook tests. -type webhookTestProviders struct { - getWebhookConfig func() (*activity.ActivitiesWebhookSettings, error) -} - -func (p *webhookTestProviders) GetActivitiesWebhookConfig(_ context.Context) (*activity.ActivitiesWebhookSettings, error) { - return p.getWebhookConfig() -} - -func (p *webhookTestProviders) ActivateNextUpcomingActivity(_ context.Context, _ uint, _ string) error { - return nil -} - -func (p *webhookTestProviders) MaskSecretURLParams(rawURL string) string { return rawURL } -func (p *webhookTestProviders) MaskURLError(err error) error { return err } -func (p *webhookTestProviders) UsersByIDs(_ context.Context, _ []uint) ([]*activity.User, error) { - return nil, nil -} -func (p *webhookTestProviders) FindUserIDs(_ context.Context, _ string) ([]uint, error) { - return nil, nil -} -func (p *webhookTestProviders) GetHostLite(_ context.Context, _ uint) (*activity.Host, error) { - return nil, nil -} - -type ActivityTypeTest struct { - Name string `json:"name"` -} - -func (a ActivityTypeTest) ActivityName() string { - return "test_activity" -} - -func (a ActivityTypeTest) Documentation() (activity string, details string, detailsExample string) { - return "test_activity", "test_activity", "test_activity" -} - func Test_logRoleChangeActivities(t *testing.T) { tests := []struct { name string @@ -149,220 +103,6 @@ func Test_logRoleChangeActivities(t *testing.T) { } } -func TestActivityWebhooks(t *testing.T) { - ds := new(mock.Store) - opts := &TestServerOpts{} - svc, ctx := newTestService(t, ds, nil, nil, opts) - var webhookBody = fleet.ActivityWebhookPayload{} - webhookChannel := make(chan struct{}, 1) - fail429 := false - startMockServer := func(t *testing.T) string { - // create a test http server - srv := httptest.NewServer( - http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - webhookBody = fleet.ActivityWebhookPayload{} - if r.Method != "POST" { - w.WriteHeader(http.StatusMethodNotAllowed) - return // don't send the channel signal - } - switch r.URL.Path { - case "/ok": - err := json.NewDecoder(r.Body).Decode(&webhookBody) - if err != nil { - t.Log(err) - w.WriteHeader(http.StatusBadRequest) - } - case "/error": - webhookBody.Type = "error" // to check for testing - w.WriteHeader(http.StatusTeapot) - case "/429": - // Only the first request will fail - fail429 = !fail429 - if fail429 { - w.WriteHeader(http.StatusTooManyRequests) - return // don't send the channel signal - } - err := json.NewDecoder(r.Body).Decode(&webhookBody) - if err != nil { - t.Log(err) - w.WriteHeader(http.StatusBadRequest) - } - default: - w.WriteHeader(http.StatusNotFound) - return // don't send the channel signal - } - webhookChannel <- struct{}{} - }, - ), - ) - t.Cleanup(srv.Close) - return srv.URL - } - mockUrl := startMockServer(t) - testUrl := mockUrl - - // Wire a real activity bounded context service as delegate so that webhook - // firing (which lives in the bounded context) is exercised. The mock still - // captures invocations and the user for assertions. - providers := &webhookTestProviders{ - getWebhookConfig: func() (*activity.ActivitiesWebhookSettings, error) { - return &activity.ActivitiesWebhookSettings{ - Enable: true, - DestinationURL: testUrl, - }, nil - }, - } - discardLogger := slog.New(slog.DiscardHandler) - realActivitySvc := activity_bootstrap.NewForUnitTests(providers, func(ctx context.Context, url string, payload any) error { - return fleetserver.PostJSONWithTimeout(ctx, url, payload, discardLogger) - }, discardLogger) - opts.ActivityMock.Delegate = realActivitySvc - - var activityUser *activity_api.User - opts.ActivityMock.NewActivityFunc = func(_ context.Context, user *activity_api.User, _ activity_api.ActivityDetails) error { - activityUser = user - return nil - } - - tests := []struct { - name string - user *fleet.User - url string - doError bool - }{ - { - name: "nil user", - url: mockUrl + "/ok", - user: nil, - }, - { - name: "real user", - url: mockUrl + "/ok", - user: &fleet.User{ - ID: 1, - Name: "testUser", - Email: "testUser@example.com", - }, - }, - { - name: "error", - url: mockUrl + "/error", - doError: true, - }, - { - name: "429", - url: mockUrl + "/429", - user: &fleet.User{ - ID: 2, - Name: "testUser2", - Email: "testUser2@example.com", - }, - }, - } - - for _, tt := range tests { - t.Run( - tt.name, func(t *testing.T) { - opts.ActivityMock.NewActivityFuncInvoked = false - testUrl = tt.url - startTime := time.Now() - act := ActivityTypeTest{Name: tt.name} - err := svc.NewActivity(ctx, tt.user, act) - require.NoError(t, err) - select { - case <-time.After(1 * time.Second): - t.Error("timeout") - case <-webhookChannel: - if tt.doError { - assert.Equal(t, "error", webhookBody.Type) - } else { - endTime := time.Now() - assert.False( - t, webhookBody.Timestamp.Before(startTime), "timestamp %s is before start time %s", - webhookBody.Timestamp.String(), startTime.String(), - ) - assert.False(t, webhookBody.Timestamp.After(endTime)) - if tt.user == nil { - assert.Nil(t, webhookBody.ActorFullName) - assert.Nil(t, webhookBody.ActorID) - assert.Nil(t, webhookBody.ActorEmail) - } else { - require.NotNil(t, webhookBody.ActorFullName) - assert.Equal(t, tt.user.Name, *webhookBody.ActorFullName) - require.NotNil(t, webhookBody.ActorID) - assert.Equal(t, tt.user.ID, *webhookBody.ActorID) - require.NotNil(t, webhookBody.ActorEmail) - assert.Equal(t, tt.user.Email, *webhookBody.ActorEmail) - } - assert.Equal(t, act.ActivityName(), webhookBody.Type) - var details map[string]string - require.NoError(t, json.Unmarshal(*webhookBody.Details, &details)) - assert.Len(t, details, 1) - assert.Equal(t, tt.name, details["name"]) - } - } - require.True(t, opts.ActivityMock.NewActivityFuncInvoked) - if tt.user == nil { - assert.Nil(t, activityUser) - } else { - require.NotNil(t, activityUser) - assert.Equal(t, tt.user.ID, activityUser.ID) - assert.Equal(t, tt.user.Name, activityUser.Name) - assert.Equal(t, tt.user.Email, activityUser.Email) - } - }, - ) - } -} - -func TestActivityWebhooksDisabled(t *testing.T) { - ds := new(mock.Store) - opts := &TestServerOpts{} - svc, ctx := newTestService(t, ds, nil, nil, opts) - startMockServer := func(t *testing.T) string { - // create a test http server - srv := httptest.NewServer( - http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - t.Error("should not be called") - }, - ), - ) - t.Cleanup(srv.Close) - return srv.URL - } - mockUrl := startMockServer(t) - - ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return &fleet.AppConfig{ - WebhookSettings: fleet.WebhookSettings{ - ActivitiesWebhook: fleet.ActivitiesWebhookSettings{ - Enable: false, - DestinationURL: mockUrl, - }, - }, - }, nil - } - var activityUser *activity_api.User - opts.ActivityMock.NewActivityFunc = func(_ context.Context, user *activity_api.User, _ activity_api.ActivityDetails) error { - activityUser = user - return nil - } - activity := ActivityTypeTest{Name: "no webhook"} - user := &fleet.User{ - ID: 1, - Name: "testUser", - Email: "testUser@example.com", - } - require.NoError(t, svc.NewActivity(ctx, user, activity)) - require.True(t, opts.ActivityMock.NewActivityFuncInvoked) - require.NotNil(t, activityUser) - assert.Equal(t, user.ID, activityUser.ID) - assert.Equal(t, user.Name, activityUser.Name) - assert.Equal(t, user.Email, activityUser.Email) -} - func TestCancelHostUpcomingActivityAuth(t *testing.T) { ds := new(mock.Store) svc, ctx := newTestService(t, ds, nil, nil, &TestServerOpts{License: &fleet.LicenseInfo{Tier: fleet.TierPremium}}) diff --git a/server/service/testing_utils.go b/server/service/testing_utils.go index 9552883ffe..986187c385 100644 --- a/server/service/testing_utils.go +++ b/server/service/testing_utils.go @@ -23,7 +23,6 @@ import ( "github.com/fleetdm/fleet/v4/ee/server/service/est" "github.com/fleetdm/fleet/v4/ee/server/service/hostidentity" "github.com/fleetdm/fleet/v4/ee/server/service/hostidentity/httpsig" - "github.com/fleetdm/fleet/v4/server" "github.com/fleetdm/fleet/v4/server/acl/activityacl" activity_api "github.com/fleetdm/fleet/v4/server/activity/api" activity_bootstrap "github.com/fleetdm/fleet/v4/server/activity/bootstrap" @@ -491,9 +490,6 @@ func RunServerForTestsWithServiceWithDS(t *testing.T, ctx context.Context, ds fl opts[0].DBConns, activityAuthorizer, activityACLAdapter, - func(ctx context.Context, url string, payload any) error { - return server.PostJSONWithTimeout(ctx, url, payload, slogLogger) - }, slogLogger, ) svc.SetActivityService(activitySvc) diff --git a/server/utils.go b/server/utils.go index ce843cba00..3137145cf2 100644 --- a/server/utils.go +++ b/server/utils.go @@ -1,24 +1,16 @@ package server import ( - "bytes" - "context" "crypto/rand" "crypto/rsa" "crypto/x509" "encoding/base64" - "encoding/json" "encoding/pem" "errors" "fmt" "html/template" - "io" - "log/slog" - "net/http" "strings" - "time" - "github.com/fleetdm/fleet/v4/pkg/fleethttp" "github.com/fleetdm/fleet/v4/server/bindata" platformhttp "github.com/fleetdm/fleet/v4/server/platform/http" ) @@ -45,62 +37,10 @@ func GenerateRandomURLSafeText(keySize int) (string, error) { return base64.URLEncoding.EncodeToString(key), nil } -func httpSuccessStatus(statusCode int) bool { - return statusCode >= 200 && statusCode <= 299 -} - -// errWithStatus is an error with a particular status code. -type errWithStatus struct { - err string - statusCode int -} - -// Error implements the error interface -func (e *errWithStatus) Error() string { - return e.err -} - -// StatusCode implements the StatusCoder interface for returning custom status codes. -func (e *errWithStatus) StatusCode() int { - return e.statusCode -} - -func PostJSONWithTimeout(ctx context.Context, url string, v any, logger *slog.Logger) error { - jsonBytes, err := json.Marshal(v) - if err != nil { - return err - } - - client := fleethttp.NewClient(fleethttp.WithTimeout(30 * time.Second)) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonBytes)) - if err != nil { - return err - } - - req.Header.Set("Content-Type", "application/json") - - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("failed to POST to %s: %s, request-size=%d", MaskSecretURLParams(url), MaskURLError(err), len(jsonBytes)) - } - defer resp.Body.Close() - - if !httpSuccessStatus(resp.StatusCode) { - body, _ := io.ReadAll(resp.Body) - bodyStr := string(body) - if len(bodyStr) > 512 { - bodyStr = bodyStr[:512] - } - logger.DebugContext(ctx, "non-success response from POST", - "url", MaskSecretURLParams(url), - "status_code", resp.StatusCode, - "body", bodyStr, - ) - return &errWithStatus{err: fmt.Sprintf("error posting to %s", MaskSecretURLParams(url)), statusCode: resp.StatusCode} - } - - return nil -} +// PostJSONWithTimeout marshals v as JSON and POSTs it to the given URL with a 30-second timeout. +// +// Deprecated: Use github.com/fleetdm/fleet/v4/server/platform/http.PostJSONWithTimeout instead. +var PostJSONWithTimeout = platformhttp.PostJSONWithTimeout // MaskSecretURLParams masks URL query values if the query param name includes "secret", "token", // "key", "password". It accepts a raw string and returns a redacted string if the raw string is