Move PostJSONWithTimeout to platform/http package and activity cleanup (#40561)

<!-- Add the related story/sub-task/bug number, like Resolves #123, or
remove if NA -->
**Related issue:** Resolves #38536

- Moved PostJSONWithTimeout to platform/http
- Created platform/errors package with only types needed by ctxerr. This
way, ctxerr did not need to import fleethttp.
- Made activity bounded context use PostJSONWithTimeout directly
- Removed some activity types from legacy code that were no longer
needed

# Checklist for submitter

- [ ] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.
- Changes file `38536-new-activity-bc` already present, and this is just
cleanup from that work.

## Testing

- [x] Added/updated automated tests
- [x] QA'd all new/changed functionality manually


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

## Release Notes

* **Refactor**
* Reorganized error handling utilities for improved clarity and
decoupling.
* Consolidated HTTP utilities to centralize JSON posting functionality
with timeout support.
* Simplified activity service initialization by removing unused internal
parameters.
* Cleaned up test utilities and removed webhook-related test
scaffolding.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
Victor Lyuboslavsky 2026-02-26 17:39:10 -06:00 committed by GitHub
parent 5b78ad3644
commit 92bc1c650e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 346 additions and 510 deletions

View file

@ -1813,9 +1813,6 @@ func createActivityBoundedContext(svc fleet.Service, dbConns *common_mysql.DBCon
dbConns,
activityAuthorizer,
activityACLAdapter,
func(ctx context.Context, url string, payload any) error {
return server.PostJSONWithTimeout(ctx, url, payload, logger)
},
logger,
)
// Create auth middleware for activity bounded context

View file

@ -28,6 +28,7 @@ var (
m + "/server/contexts/logging",
m + "/server/contexts/authz",
m + "/server/contexts/publicip",
m + "/pkg/fleethttp",
}
)
@ -67,7 +68,7 @@ func TestActivityPackageDependencies(t *testing.T) {
m + "/server/activity/api",
m + "/server/activity/internal/types",
m + "/server/activity/internal/testutils",
m + "/server/platform/http",
m + "/server/platform/errors",
m + "/server/platform/logging",
m + "/server/platform/mysql",
m + "/server/platform/mysql/testing_utils",

View file

@ -20,11 +20,10 @@ func New(
dbConns *platform_mysql.DBConnections,
authorizer platform_authz.Authorizer,
providers activity.DataProviders,
webhookSendFn activity.WebhookSendFunc,
logger *slog.Logger,
) (api.Service, func(authMiddleware endpoint.Middleware) eu.HandlerRoutesFunc) {
ds := mysql.NewDatastore(dbConns, logger)
svc := service.NewService(authorizer, ds, providers, webhookSendFn, logger)
svc := service.NewService(authorizer, ds, providers, logger)
routesFn := func(authMiddleware endpoint.Middleware) eu.HandlerRoutesFunc {
return service.GetRoutes(svc, authMiddleware)

View file

@ -1,48 +0,0 @@
package bootstrap
import (
"context"
"log/slog"
"time"
"github.com/fleetdm/fleet/v4/server/activity"
"github.com/fleetdm/fleet/v4/server/activity/api"
"github.com/fleetdm/fleet/v4/server/activity/internal/service"
"github.com/fleetdm/fleet/v4/server/activity/internal/types"
platform_authz "github.com/fleetdm/fleet/v4/server/platform/authz"
)
// NewForUnitTests creates an activity NewActivityService backed by a noop store (no database required).
func NewForUnitTests(
providers activity.DataProviders,
webhookSendFn activity.WebhookSendFunc,
logger *slog.Logger,
) api.NewActivityService {
return service.NewService(&noopAuthorizer{}, &noopStore{}, providers, webhookSendFn, logger)
}
// noopAuthorizer allows all actions (appropriate for unit tests).
type noopAuthorizer struct{}
func (a *noopAuthorizer) Authorize(_ context.Context, _ platform_authz.AuthzTyper, _ platform_authz.Action) error {
return nil
}
// noopStore is a datastore that does nothing (appropriate for unit tests that only need webhook behavior).
type noopStore struct{}
func (s *noopStore) ListActivities(_ context.Context, _ types.ListOptions) ([]*api.Activity, *api.PaginationMetadata, error) {
return nil, nil, nil
}
func (s *noopStore) ListHostPastActivities(_ context.Context, _ uint, _ types.ListOptions) ([]*api.Activity, *api.PaginationMetadata, error) {
return nil, nil, nil
}
func (s *noopStore) MarkActivitiesAsStreamed(_ context.Context, _ []uint) error {
return nil
}
func (s *noopStore) NewActivity(_ context.Context, _ *api.User, _ api.ActivityDetails, _ []byte, _ time.Time) error {
return nil
}

View file

@ -118,7 +118,7 @@ func (s *Service) fireActivityWebhook(
retryStrategy.MaxElapsedTime = 30 * time.Minute
err := backoff.Retry(
func() error {
if err := s.webhookSendFn(
if err := platformhttp.PostJSONWithTimeout(
spanCtx, webhookURL, &webhookPayload{
Timestamp: timestamp,
ActorFullName: userName,
@ -126,7 +126,7 @@ func (s *Service) fireActivityWebhook(
ActorEmail: userEmail,
Type: activityType,
Details: (*json.RawMessage)(&detailsBytes),
},
}, s.logger,
); err != nil {
var statusCoder kithttp.StatusCoder
if errors.As(err, &statusCoder) && statusCoder.StatusCode() == http.StatusTooManyRequests {

View file

@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"
@ -78,8 +80,7 @@ func (a activatorActivity) ActivateNextUpcomingActivityArgs() (uint, string) {
}
func newTestService(ds types.Datastore, providers activity.DataProviders) *Service {
noopWebhookSend := func(_ context.Context, _ string, _ any) error { return nil }
return NewService(&mockAuthorizer{}, ds, providers, noopWebhookSend, slog.New(slog.DiscardHandler))
return NewService(&mockAuthorizer{}, ds, providers, slog.New(slog.DiscardHandler))
}
func TestNewActivityStoresWithWebhookContextKey(t *testing.T) {
@ -213,3 +214,183 @@ func TestNewActivityNilUser(t *testing.T) {
require.True(t, ds.newActivityCalled)
assert.Nil(t, ds.lastUser)
}
// newTestServiceWithWebhook creates a service configured for webhook delivery tests.
func newTestServiceWithWebhook(ds types.Datastore, providers activity.DataProviders) *Service {
return NewService(&mockAuthorizer{}, ds, providers, slog.New(slog.DiscardHandler))
}
func TestNewActivityWebhook(t *testing.T) {
t.Parallel()
webhookChannel := make(chan struct{}, 1)
var webhookBody webhookPayload
fail429 := false
startMockServer := func(t *testing.T) string {
srv := httptest.NewServer(
http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
webhookBody = webhookPayload{}
if r.Method != "POST" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
switch r.URL.Path {
case "/ok":
err := json.NewDecoder(r.Body).Decode(&webhookBody)
if err != nil {
t.Log(err)
w.WriteHeader(http.StatusBadRequest)
}
case "/error":
webhookBody.Type = "error"
w.WriteHeader(http.StatusTeapot)
case "/429":
fail429 = !fail429
if fail429 {
w.WriteHeader(http.StatusTooManyRequests)
return
}
err := json.NewDecoder(r.Body).Decode(&webhookBody)
if err != nil {
t.Log(err)
w.WriteHeader(http.StatusBadRequest)
}
default:
w.WriteHeader(http.StatusNotFound)
return
}
webhookChannel <- struct{}{}
},
),
)
t.Cleanup(srv.Close)
return srv.URL
}
mockURL := startMockServer(t)
testURL := mockURL
ds := &newActivityMockDatastore{}
providers := &newActivityMockProviders{
mockDataProviders: mockDataProviders{
mockUserProvider: &mockUserProvider{},
mockHostProvider: &mockHostProvider{},
webhookConfig: &activity.ActivitiesWebhookSettings{
Enable: true,
DestinationURL: testURL,
},
},
}
svc := newTestServiceWithWebhook(ds, providers)
tests := []struct {
name string
user *api.User
url string
doError bool
}{
{
name: "nil user",
url: mockURL + "/ok",
user: nil,
},
{
name: "real user",
url: mockURL + "/ok",
user: &api.User{
ID: 1,
Name: "testUser",
Email: "testUser@example.com",
},
},
{
name: "error",
url: mockURL + "/error",
doError: true,
},
{
name: "429",
url: mockURL + "/429",
user: &api.User{
ID: 2,
Name: "testUserRetry",
Email: "testUserRetry@example.com",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ds.newActivityCalled = false
providers.webhookConfig.DestinationURL = tt.url
startTime := time.Now()
act := simpleActivity{Name: tt.name}
err := svc.NewActivity(t.Context(), tt.user, act)
require.NoError(t, err)
select {
case <-time.After(3 * time.Second):
t.Error("timeout waiting for webhook")
case <-webhookChannel:
if tt.doError {
assert.Equal(t, "error", webhookBody.Type)
} else {
endTime := time.Now()
assert.False(
t, webhookBody.Timestamp.Before(startTime), "timestamp %s is before start time %s",
webhookBody.Timestamp.String(), startTime.String(),
)
assert.False(t, webhookBody.Timestamp.After(endTime))
if tt.user == nil {
assert.Nil(t, webhookBody.ActorFullName)
assert.Nil(t, webhookBody.ActorID)
assert.Nil(t, webhookBody.ActorEmail)
} else {
require.NotNil(t, webhookBody.ActorFullName)
assert.Equal(t, tt.user.Name, *webhookBody.ActorFullName)
require.NotNil(t, webhookBody.ActorID)
assert.Equal(t, tt.user.ID, *webhookBody.ActorID)
require.NotNil(t, webhookBody.ActorEmail)
assert.Equal(t, tt.user.Email, *webhookBody.ActorEmail)
}
assert.Equal(t, act.ActivityName(), webhookBody.Type)
var details map[string]string
require.NoError(t, json.Unmarshal(*webhookBody.Details, &details))
assert.Len(t, details, 1)
assert.Equal(t, tt.name, details["name"])
}
}
require.True(t, ds.newActivityCalled)
})
}
}
func TestNewActivityWebhookDisabled(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("webhook server should not be called when webhook is disabled")
}),
)
t.Cleanup(srv.Close)
ds := &newActivityMockDatastore{}
providers := &newActivityMockProviders{
mockDataProviders: mockDataProviders{
mockUserProvider: &mockUserProvider{},
mockHostProvider: &mockHostProvider{},
webhookConfig: &activity.ActivitiesWebhookSettings{
Enable: false,
DestinationURL: srv.URL,
},
},
}
svc := newTestServiceWithWebhook(ds, providers)
err := svc.NewActivity(t.Context(), &api.User{ID: 1}, simpleActivity{Name: "no webhook"})
require.NoError(t, err)
require.True(t, ds.newActivityCalled)
}

View file

@ -46,11 +46,10 @@ func applyListOptionsDefaults(opt *api.ListOptions, defaultOrderKey string) {
// Service is the activity bounded context service implementation.
type Service struct {
authz platform_authz.Authorizer
store types.Datastore
providers activity.DataProviders
webhookSendFn activity.WebhookSendFunc
logger *slog.Logger
authz platform_authz.Authorizer
store types.Datastore
providers activity.DataProviders
logger *slog.Logger
}
// NewService creates a new activity service.
@ -58,15 +57,13 @@ func NewService(
authz platform_authz.Authorizer,
store types.Datastore,
providers activity.DataProviders,
webhookSendFn activity.WebhookSendFunc,
logger *slog.Logger,
) *Service {
return &Service{
authz: authz,
store: store,
providers: providers,
webhookSendFn: webhookSendFn,
logger: logger,
authz: authz,
store: store,
providers: providers,
logger: logger,
}
}

View file

@ -130,8 +130,7 @@ func setupTest(opts ...func(*testSetup)) *testSetup {
for _, opt := range opts {
opt(ts)
}
noopWebhookSend := func(_ context.Context, _ string, _ any) error { return nil }
ts.svc = NewService(ts.authz, ts.ds, ts.providers, noopWebhookSend, slog.New(slog.DiscardHandler))
ts.svc = NewService(ts.authz, ts.ds, ts.providers, slog.New(slog.DiscardHandler))
return ts
}
@ -532,9 +531,8 @@ func newTestActivity(id uint, actorName string, actorID uint, actType, details s
func TestStreamActivities(t *testing.T) {
t.Parallel()
noopWebhookSend := func(_ context.Context, _ string, _ any) error { return nil }
newStreamingService := func(ds *mockStreamingDatastore) *Service {
return NewService(&mockAuthorizer{}, ds, &mockDataProviders{mockUserProvider: &mockUserProvider{}, mockHostProvider: &mockHostProvider{}}, noopWebhookSend, slog.New(slog.DiscardHandler))
return NewService(&mockAuthorizer{}, ds, &mockDataProviders{mockUserProvider: &mockUserProvider{}, mockHostProvider: &mockHostProvider{}}, slog.New(slog.DiscardHandler))
}
t.Run("basic streaming", func(t *testing.T) {

View file

@ -1,7 +1,6 @@
package tests
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
@ -38,8 +37,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSuite {
providers := newMockDataProviders()
// Create service
noopWebhookSend := func(_ context.Context, _ string, _ any) error { return nil }
svc := service.NewService(authorizer, ds, providers, noopWebhookSend, tdb.Logger)
svc := service.NewService(authorizer, ds, providers, tdb.Logger)
// Create router with routes
router := mux.NewRouter()

View file

@ -10,9 +10,6 @@ type UpcomingActivityActivator interface {
ActivateNextUpcomingActivity(ctx context.Context, hostID uint, fromCompletedExecID string) error
}
// WebhookSendFunc is the function signature for sending a JSON payload to a URL.
type WebhookSendFunc = func(ctx context.Context, url string, payload any) error
// DataProviders combines all external dependency interfaces for the activity
// bounded context. The ACL adapter implements this single interface.
type DataProviders interface {

View file

@ -21,7 +21,7 @@ import (
"strings"
"time"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
platform_errors "github.com/fleetdm/fleet/v4/server/platform/errors"
"github.com/getsentry/sentry-go"
"go.elastic.co/apm/v2"
"go.opentelemetry.io/otel/attribute"
@ -202,7 +202,7 @@ func Wrapf(ctx context.Context, cause error, format string, args ...interface{})
// Cause returns the root error in err's chain.
func Cause(err error) error {
return platform_http.Cause(err)
return platform_errors.Cause(err)
}
// FleetCause is similar to Cause, but returns the root-most
@ -407,7 +407,7 @@ func isClientError(err error) bool {
// Check for explicit client error interface. All 4xx error types
// (not found, already exists, conflict, validation, permission,
// bad request, foreign key, etc.) should implement this interface.
var clientErr platform_http.ErrWithIsClientError
var clientErr platform_errors.ErrWithIsClientError
if errors.As(err, &clientErr) {
return clientErr.IsClientError()
}

View file

@ -8,6 +8,7 @@ import (
"sync"
"time"
platform_errors "github.com/fleetdm/fleet/v4/server/platform/errors"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
kithttp "github.com/go-kit/kit/transport/http"
)
@ -233,7 +234,7 @@ func (l *LoggingContext) setLevelError() bool {
}
if len(l.Errs) == 1 {
var ew platform_http.ErrWithIsClientError
var ew platform_errors.ErrWithIsClientError
if errors.As(l.Errs[0], &ew) && ew.IsClientError() {
return false
}

View file

@ -1736,7 +1736,7 @@ func testSoftwareTitleDisplayNameInHouse(t *testing.T, ds *Datastore) {
}
func testInHouseAppsCancelledOnUnenroll(t *testing.T, ds *Datastore) {
ctx := context.WithValue(context.Background(), fleet.ActivityWebhookContextKey, true)
ctx := t.Context()
test.CreateInsertGlobalVPPToken(t, ds)
user := test.NewUser(t, ds, "Alice", "alice@example.com", true)

View file

@ -26,7 +26,6 @@ import (
"time"
"github.com/WatchBeam/clock"
"github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/acl/activityacl"
activity_api "github.com/fleetdm/fleet/v4/server/activity/api"
activity_bootstrap "github.com/fleetdm/fleet/v4/server/activity/bootstrap"
@ -1033,9 +1032,7 @@ func NewTestActivityService(t testing.TB, ds *Datastore) activity_api.Service {
// Create service via bootstrap (the public API for creating the bounded context)
discardLogger := slog.New(slog.DiscardHandler)
svc, _ := activity_bootstrap.New(dbConns, &testingAuthorizer{}, aclAdapter, func(ctx context.Context, url string, payload any) error {
return server.PostJSONWithTimeout(ctx, url, payload, discardLogger)
}, discardLogger)
svc, _ := activity_bootstrap.New(dbConns, &testingAuthorizer{}, aclAdapter, discardLogger)
return svc
}

View file

@ -6,23 +6,9 @@ import (
"time"
)
type ContextKey string
// NewActivityFunc is the function signature for creating a new activity.
type NewActivityFunc func(ctx context.Context, user *User, activity ActivityDetails) error
type ActivityWebhookPayload struct {
Timestamp time.Time `json:"timestamp"`
ActorFullName *string `json:"actor_full_name"`
ActorID *uint `json:"actor_id"`
ActorEmail *string `json:"actor_email"`
Type string `json:"type"`
Details *json.RawMessage `json:"details"`
}
// ActivityWebhookContextKey is the context key to indicate that the activity webhook has been processed before saving the activity.
const ActivityWebhookContextKey = ContextKey("ActivityWebhook")
type Activity struct {
CreateTimestamp
@ -263,38 +249,6 @@ type ActivityDetails interface {
ActivityName() string
}
// ActivityHosts is the optional additional interface that can be implemented
// by activities that are related to hosts.
type ActivityHosts interface {
ActivityDetails
HostIDs() []uint
}
// AutomatableActivity is the optional additional interface that can be implemented
// by activities that are sometimes the result of automation ("Fleet did X"), starting with
// install/script run policy automations
type AutomatableActivity interface {
ActivityDetails
WasFromAutomation() bool
}
// ActivityHostOnly is the optional additional interface that can be implemented by activities that
// we want to exclude from the global activity feed, and only show on the Hosts details page
type ActivityHostOnly interface {
ActivityDetails
HostOnly() bool
}
// ActivityActivator is the optional additional interface that can be implemented by activities that
// may require activating the next upcoming activity when it gets created. Most upcoming activities get
// activated when the result of the previous one completes (such as scripts and software installs), but
// some can only be activated when the activity gets recorded (such as VPP and in-house apps).
type ActivityActivator interface {
ActivityDetails
MustActivateNextUpcomingActivity() bool
ActivateNextUpcomingActivityArgs() (hostID uint, cmdUUID string)
}
type ActivityTypeEnabledActivityAutomations struct {
WebhookUrl string `json:"webhook_url"`
}

View file

@ -18,6 +18,7 @@ import (
"github.com/fleetdm/fleet/v4/server/mdm/nanodep/godep"
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm"
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/storage"
platform_errors "github.com/fleetdm/fleet/v4/server/platform/errors"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
"github.com/jmoiron/sqlx"
)
@ -2916,11 +2917,11 @@ const (
// same in both (the other is currently NotFound), and ideally we'd just have
// one of those interfaces.
// NotFoundError is an alias for platform_http.NotFoundError.
type NotFoundError = platform_http.NotFoundError
// NotFoundError is an alias for platform_errors.NotFoundError.
type NotFoundError = platform_errors.NotFoundError
// IsNotFound is an alias for platform_http.IsNotFound.
var IsNotFound = platform_http.IsNotFound
// IsNotFound is an alias for platform_errors.IsNotFound.
var IsNotFound = platform_errors.IsNotFound
// AlreadyExistsError is an alias for platform_http.AlreadyExistsError.
type AlreadyExistsError = platform_http.AlreadyExistsError

View file

@ -6,6 +6,7 @@ import (
"net/http"
"time"
platform_errors "github.com/fleetdm/fleet/v4/server/platform/errors"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
"github.com/rs/zerolog"
)
@ -51,8 +52,8 @@ type ErrWithLogFields = platform_http.ErrWithLogFields
// ErrWithRetryAfter is an alias for platform_http.ErrWithRetryAfter.
type ErrWithRetryAfter = platform_http.ErrWithRetryAfter
// ErrWithIsClientError is an alias for platform_http.ErrWithIsClientError.
type ErrWithIsClientError = platform_http.ErrWithIsClientError
// ErrWithIsClientError is an alias for platform_errors.ErrWithIsClientError.
type ErrWithIsClientError = platform_errors.ErrWithIsClientError
type invalidArgWithStatusError struct {
InvalidArgumentError
@ -400,7 +401,7 @@ func GetJSONUnknownField(err error) *string {
}
// Cause returns the root error in err's chain.
var Cause = platform_http.Cause
var Cause = platform_errors.Cause
// FleetdError is an error that can be reported by any of the fleetd
// components.

View file

@ -21,6 +21,7 @@ func TestPlatformPackageDependencies(t *testing.T) {
// Platform packages can depend on each other
m+"/server/platform...",
// Infra packages
m+"/pkg/fleethttp",
m+"/server/contexts/authz",
m+"/server/contexts/ctxerr",
m+"/server/contexts/license",
@ -41,7 +42,8 @@ func TestEndpointerPackageDependencies(t *testing.T) {
IgnoreDeps(
// Platform packages
m+"/server/platform...",
// Other infra packages
// Infra packages
m+"/pkg/fleethttp",
m+"/server/contexts/authz",
m+"/server/contexts/ctxerr",
m+"/server/contexts/license",
@ -56,7 +58,11 @@ func TestHTTPPackageDependencies(t *testing.T) {
archtest.NewPackageTest(t, m+"/server/platform/http").
OnlyInclude(regexp.MustCompile(`^github\.com/fleetdm/`)).
WithTests().
ShouldNotDependOn(m + "/...").
ShouldNotDependOn(m+"/...").
IgnoreDeps(
m+"/pkg/fleethttp",
m+"/server/platform/errors",
).
Check()
}
@ -65,10 +71,13 @@ func TestAuthzCheckPackageDependencies(t *testing.T) {
archtest.NewPackageTest(t, m+"/server/platform/middleware/authzcheck").
OnlyInclude(regexp.MustCompile(`^github\.com/fleetdm/`)).
WithTests().
ShouldNotDependOn(m+"/...").
IgnoreDeps(
// Platform packages
m+"/server/platform/errors",
m+"/server/platform/http",
// Other infra packages
m+"/pkg/fleethttp",
m+"/server/contexts/authz",
).
Check()
@ -82,6 +91,7 @@ func TestRatelimitPackageDependencies(t *testing.T) {
ShouldNotDependOn(m+"/...").
IgnoreDeps(
// Platform packages
m+"/server/platform/errors",
m+"/server/platform/http",
// Other infra packages
m+"/server/contexts/authz",
@ -103,7 +113,7 @@ func TestMysqlPackageDependencies(t *testing.T) {
// Ignore our own packages
m+"/server/platform/mysql...",
// Other infra packages
m+"/server/platform/http",
m+"/server/platform/errors",
m+"/server/platform/logging",
m+"/server/contexts/ctxerr",
).

View file

@ -0,0 +1,41 @@
// Package errors provides error classification primitives used across the
// codebase. These are intentionally kept free of HTTP or other transport
// dependencies so that low-level packages (datastores, context helpers) can
// use them without pulling in higher-level concerns.
package errors
import "errors"
// Cause returns the root error in err's chain.
func Cause(err error) error {
for {
uerr := errors.Unwrap(err)
if uerr == nil {
return err
}
err = uerr
}
}
// ErrWithIsClientError is an interface for errors that explicitly specify
// whether they are client errors or not. By default, errors are treated as
// server errors.
type ErrWithIsClientError interface {
error
IsClientError() bool
}
// NotFoundError is an interface for errors when a resource cannot be found.
type NotFoundError interface {
error
IsNotFound() bool
}
// IsNotFound returns true if err is a not-found error.
func IsNotFound(err error) bool {
var nfe NotFoundError
if errors.As(err, &nfe) {
return nfe.IsNotFound()
}
return false
}

View file

@ -11,6 +11,7 @@ import (
"strings"
"github.com/docker/go-units"
platform_errors "github.com/fleetdm/fleet/v4/server/platform/errors"
"github.com/google/uuid"
)
@ -174,7 +175,7 @@ func IsJSONUnknownFieldError(err error) bool {
// GetJSONUnknownField returns the unknown field name from a JSON unknown field error.
func GetJSONUnknownField(err error) *string {
errCause := Cause(err)
errCause := platform_errors.Cause(err)
if IsJSONUnknownFieldError(errCause) {
substr := rxJSONUnknownField.FindStringSubmatch(errCause.Error())
return &substr[1]
@ -186,7 +187,7 @@ func GetJSONUnknownField(err error) *string {
// root cause is one of the supported types, otherwise it returns the error
// message.
func (e UserMessageError) UserMessage() string {
cause := Cause(e.error)
cause := platform_errors.Cause(e.error)
switch cause := cause.(type) {
case *json.UnmarshalTypeError:
var sb strings.Builder
@ -213,17 +214,6 @@ func (e UserMessageError) UserMessage() string {
}
}
// Cause returns the root error in err's chain.
func Cause(err error) error {
for {
uerr := errors.Unwrap(err)
if uerr == nil {
return err
}
err = uerr
}
}
// ErrWithRetryAfter is an interface for errors that should set a specific HTTP
// Header Retry-After value (see
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After)
@ -248,21 +238,6 @@ func IsForeignKey(err error) bool {
return false
}
// NotFoundError is an interface for errors when a resource cannot be found.
type NotFoundError interface {
error
IsNotFound() bool
}
// IsNotFound returns true if err is a not-found error.
func IsNotFound(err error) bool {
var nfe NotFoundError
if errors.As(err, &nfe) {
return nfe.IsNotFound()
}
return false
}
// AlreadyExistsError is an interface for errors when a resource already exists.
type AlreadyExistsError interface {
error
@ -282,14 +257,6 @@ func (e *Error) Error() string {
return e.Message
}
// ErrWithIsClientError is an interface for errors that explicitly specify
// whether they are client errors or not. By default, errors are treated as
// server errors.
type ErrWithIsClientError interface {
error
IsClientError() bool
}
// AuthFailedError is returned when authentication fails.
type AuthFailedError struct {
// internal is the reason that should only be logged internally

View file

@ -0,0 +1,68 @@
package http
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"time"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
)
// errWithStatus is an error with a particular status code.
type errWithStatus struct {
err string
statusCode int
}
// Error implements the error interface.
func (e *errWithStatus) Error() string {
return e.err
}
// StatusCode implements the StatusCoder interface for returning custom status codes.
func (e *errWithStatus) StatusCode() int {
return e.statusCode
}
// PostJSONWithTimeout marshals v as JSON and POSTs it to the given URL with a 30-second timeout.
func PostJSONWithTimeout(ctx context.Context, url string, v any, logger *slog.Logger) error {
jsonBytes, err := json.Marshal(v)
if err != nil {
return err
}
client := fleethttp.NewClient(fleethttp.WithTimeout(30 * time.Second))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonBytes))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to POST to %s: %s, request-size=%d", MaskSecretURLParams(url), MaskURLError(err), len(jsonBytes))
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode > 299 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 513))
bodyStr := string(body)
if len(bodyStr) > 512 {
bodyStr = bodyStr[:512]
}
logger.DebugContext(ctx, "non-success response from POST",
"url", MaskSecretURLParams(url),
"status_code", resp.StatusCode,
"body", bodyStr,
)
return &errWithStatus{err: fmt.Sprintf("error posting to %s", MaskSecretURLParams(url)), statusCode: resp.StatusCode}
}
return nil
}

View file

@ -6,7 +6,7 @@ import (
"fmt"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
platform_errors "github.com/fleetdm/fleet/v4/server/platform/errors"
"github.com/go-sql-driver/mysql"
)
@ -18,7 +18,7 @@ type NotFoundError struct {
}
// Compile-time interface check.
var _ platform_http.NotFoundError = &NotFoundError{}
var _ platform_errors.NotFoundError = &NotFoundError{}
func NotFound(kind string) *NotFoundError {
return &NotFoundError{

View file

@ -2,62 +2,16 @@ package service
import (
"context"
"encoding/json"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"
fleetserver "github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/activity"
activity_api "github.com/fleetdm/fleet/v4/server/activity/api"
activity_bootstrap "github.com/fleetdm/fleet/v4/server/activity/bootstrap"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// webhookTestProviders implements activity.DataProviders for webhook tests.
type webhookTestProviders struct {
getWebhookConfig func() (*activity.ActivitiesWebhookSettings, error)
}
func (p *webhookTestProviders) GetActivitiesWebhookConfig(_ context.Context) (*activity.ActivitiesWebhookSettings, error) {
return p.getWebhookConfig()
}
func (p *webhookTestProviders) ActivateNextUpcomingActivity(_ context.Context, _ uint, _ string) error {
return nil
}
func (p *webhookTestProviders) MaskSecretURLParams(rawURL string) string { return rawURL }
func (p *webhookTestProviders) MaskURLError(err error) error { return err }
func (p *webhookTestProviders) UsersByIDs(_ context.Context, _ []uint) ([]*activity.User, error) {
return nil, nil
}
func (p *webhookTestProviders) FindUserIDs(_ context.Context, _ string) ([]uint, error) {
return nil, nil
}
func (p *webhookTestProviders) GetHostLite(_ context.Context, _ uint) (*activity.Host, error) {
return nil, nil
}
type ActivityTypeTest struct {
Name string `json:"name"`
}
func (a ActivityTypeTest) ActivityName() string {
return "test_activity"
}
func (a ActivityTypeTest) Documentation() (activity string, details string, detailsExample string) {
return "test_activity", "test_activity", "test_activity"
}
func Test_logRoleChangeActivities(t *testing.T) {
tests := []struct {
name string
@ -149,220 +103,6 @@ func Test_logRoleChangeActivities(t *testing.T) {
}
}
func TestActivityWebhooks(t *testing.T) {
ds := new(mock.Store)
opts := &TestServerOpts{}
svc, ctx := newTestService(t, ds, nil, nil, opts)
var webhookBody = fleet.ActivityWebhookPayload{}
webhookChannel := make(chan struct{}, 1)
fail429 := false
startMockServer := func(t *testing.T) string {
// create a test http server
srv := httptest.NewServer(
http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
webhookBody = fleet.ActivityWebhookPayload{}
if r.Method != "POST" {
w.WriteHeader(http.StatusMethodNotAllowed)
return // don't send the channel signal
}
switch r.URL.Path {
case "/ok":
err := json.NewDecoder(r.Body).Decode(&webhookBody)
if err != nil {
t.Log(err)
w.WriteHeader(http.StatusBadRequest)
}
case "/error":
webhookBody.Type = "error" // to check for testing
w.WriteHeader(http.StatusTeapot)
case "/429":
// Only the first request will fail
fail429 = !fail429
if fail429 {
w.WriteHeader(http.StatusTooManyRequests)
return // don't send the channel signal
}
err := json.NewDecoder(r.Body).Decode(&webhookBody)
if err != nil {
t.Log(err)
w.WriteHeader(http.StatusBadRequest)
}
default:
w.WriteHeader(http.StatusNotFound)
return // don't send the channel signal
}
webhookChannel <- struct{}{}
},
),
)
t.Cleanup(srv.Close)
return srv.URL
}
mockUrl := startMockServer(t)
testUrl := mockUrl
// Wire a real activity bounded context service as delegate so that webhook
// firing (which lives in the bounded context) is exercised. The mock still
// captures invocations and the user for assertions.
providers := &webhookTestProviders{
getWebhookConfig: func() (*activity.ActivitiesWebhookSettings, error) {
return &activity.ActivitiesWebhookSettings{
Enable: true,
DestinationURL: testUrl,
}, nil
},
}
discardLogger := slog.New(slog.DiscardHandler)
realActivitySvc := activity_bootstrap.NewForUnitTests(providers, func(ctx context.Context, url string, payload any) error {
return fleetserver.PostJSONWithTimeout(ctx, url, payload, discardLogger)
}, discardLogger)
opts.ActivityMock.Delegate = realActivitySvc
var activityUser *activity_api.User
opts.ActivityMock.NewActivityFunc = func(_ context.Context, user *activity_api.User, _ activity_api.ActivityDetails) error {
activityUser = user
return nil
}
tests := []struct {
name string
user *fleet.User
url string
doError bool
}{
{
name: "nil user",
url: mockUrl + "/ok",
user: nil,
},
{
name: "real user",
url: mockUrl + "/ok",
user: &fleet.User{
ID: 1,
Name: "testUser",
Email: "testUser@example.com",
},
},
{
name: "error",
url: mockUrl + "/error",
doError: true,
},
{
name: "429",
url: mockUrl + "/429",
user: &fleet.User{
ID: 2,
Name: "testUser2",
Email: "testUser2@example.com",
},
},
}
for _, tt := range tests {
t.Run(
tt.name, func(t *testing.T) {
opts.ActivityMock.NewActivityFuncInvoked = false
testUrl = tt.url
startTime := time.Now()
act := ActivityTypeTest{Name: tt.name}
err := svc.NewActivity(ctx, tt.user, act)
require.NoError(t, err)
select {
case <-time.After(1 * time.Second):
t.Error("timeout")
case <-webhookChannel:
if tt.doError {
assert.Equal(t, "error", webhookBody.Type)
} else {
endTime := time.Now()
assert.False(
t, webhookBody.Timestamp.Before(startTime), "timestamp %s is before start time %s",
webhookBody.Timestamp.String(), startTime.String(),
)
assert.False(t, webhookBody.Timestamp.After(endTime))
if tt.user == nil {
assert.Nil(t, webhookBody.ActorFullName)
assert.Nil(t, webhookBody.ActorID)
assert.Nil(t, webhookBody.ActorEmail)
} else {
require.NotNil(t, webhookBody.ActorFullName)
assert.Equal(t, tt.user.Name, *webhookBody.ActorFullName)
require.NotNil(t, webhookBody.ActorID)
assert.Equal(t, tt.user.ID, *webhookBody.ActorID)
require.NotNil(t, webhookBody.ActorEmail)
assert.Equal(t, tt.user.Email, *webhookBody.ActorEmail)
}
assert.Equal(t, act.ActivityName(), webhookBody.Type)
var details map[string]string
require.NoError(t, json.Unmarshal(*webhookBody.Details, &details))
assert.Len(t, details, 1)
assert.Equal(t, tt.name, details["name"])
}
}
require.True(t, opts.ActivityMock.NewActivityFuncInvoked)
if tt.user == nil {
assert.Nil(t, activityUser)
} else {
require.NotNil(t, activityUser)
assert.Equal(t, tt.user.ID, activityUser.ID)
assert.Equal(t, tt.user.Name, activityUser.Name)
assert.Equal(t, tt.user.Email, activityUser.Email)
}
},
)
}
}
func TestActivityWebhooksDisabled(t *testing.T) {
ds := new(mock.Store)
opts := &TestServerOpts{}
svc, ctx := newTestService(t, ds, nil, nil, opts)
startMockServer := func(t *testing.T) string {
// create a test http server
srv := httptest.NewServer(
http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
t.Error("should not be called")
},
),
)
t.Cleanup(srv.Close)
return srv.URL
}
mockUrl := startMockServer(t)
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
return &fleet.AppConfig{
WebhookSettings: fleet.WebhookSettings{
ActivitiesWebhook: fleet.ActivitiesWebhookSettings{
Enable: false,
DestinationURL: mockUrl,
},
},
}, nil
}
var activityUser *activity_api.User
opts.ActivityMock.NewActivityFunc = func(_ context.Context, user *activity_api.User, _ activity_api.ActivityDetails) error {
activityUser = user
return nil
}
activity := ActivityTypeTest{Name: "no webhook"}
user := &fleet.User{
ID: 1,
Name: "testUser",
Email: "testUser@example.com",
}
require.NoError(t, svc.NewActivity(ctx, user, activity))
require.True(t, opts.ActivityMock.NewActivityFuncInvoked)
require.NotNil(t, activityUser)
assert.Equal(t, user.ID, activityUser.ID)
assert.Equal(t, user.Name, activityUser.Name)
assert.Equal(t, user.Email, activityUser.Email)
}
func TestCancelHostUpcomingActivityAuth(t *testing.T) {
ds := new(mock.Store)
svc, ctx := newTestService(t, ds, nil, nil, &TestServerOpts{License: &fleet.LicenseInfo{Tier: fleet.TierPremium}})

View file

@ -23,7 +23,6 @@ import (
"github.com/fleetdm/fleet/v4/ee/server/service/est"
"github.com/fleetdm/fleet/v4/ee/server/service/hostidentity"
"github.com/fleetdm/fleet/v4/ee/server/service/hostidentity/httpsig"
"github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/acl/activityacl"
activity_api "github.com/fleetdm/fleet/v4/server/activity/api"
activity_bootstrap "github.com/fleetdm/fleet/v4/server/activity/bootstrap"
@ -491,9 +490,6 @@ func RunServerForTestsWithServiceWithDS(t *testing.T, ctx context.Context, ds fl
opts[0].DBConns,
activityAuthorizer,
activityACLAdapter,
func(ctx context.Context, url string, payload any) error {
return server.PostJSONWithTimeout(ctx, url, payload, slogLogger)
},
slogLogger,
)
svc.SetActivityService(activitySvc)

View file

@ -1,24 +1,16 @@
package server
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"html/template"
"io"
"log/slog"
"net/http"
"strings"
"time"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/server/bindata"
platformhttp "github.com/fleetdm/fleet/v4/server/platform/http"
)
@ -45,62 +37,10 @@ func GenerateRandomURLSafeText(keySize int) (string, error) {
return base64.URLEncoding.EncodeToString(key), nil
}
func httpSuccessStatus(statusCode int) bool {
return statusCode >= 200 && statusCode <= 299
}
// errWithStatus is an error with a particular status code.
type errWithStatus struct {
err string
statusCode int
}
// Error implements the error interface
func (e *errWithStatus) Error() string {
return e.err
}
// StatusCode implements the StatusCoder interface for returning custom status codes.
func (e *errWithStatus) StatusCode() int {
return e.statusCode
}
func PostJSONWithTimeout(ctx context.Context, url string, v any, logger *slog.Logger) error {
jsonBytes, err := json.Marshal(v)
if err != nil {
return err
}
client := fleethttp.NewClient(fleethttp.WithTimeout(30 * time.Second))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonBytes))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to POST to %s: %s, request-size=%d", MaskSecretURLParams(url), MaskURLError(err), len(jsonBytes))
}
defer resp.Body.Close()
if !httpSuccessStatus(resp.StatusCode) {
body, _ := io.ReadAll(resp.Body)
bodyStr := string(body)
if len(bodyStr) > 512 {
bodyStr = bodyStr[:512]
}
logger.DebugContext(ctx, "non-success response from POST",
"url", MaskSecretURLParams(url),
"status_code", resp.StatusCode,
"body", bodyStr,
)
return &errWithStatus{err: fmt.Sprintf("error posting to %s", MaskSecretURLParams(url)), statusCode: resp.StatusCode}
}
return nil
}
// PostJSONWithTimeout marshals v as JSON and POSTs it to the given URL with a 30-second timeout.
//
// Deprecated: Use github.com/fleetdm/fleet/v4/server/platform/http.PostJSONWithTimeout instead.
var PostJSONWithTimeout = platformhttp.PostJSONWithTimeout
// MaskSecretURLParams masks URL query values if the query param name includes "secret", "token",
// "key", "password". It accepts a raw string and returns a redacted string if the raw string is