mirror of
https://github.com/fleetdm/fleet
synced 2026-04-21 21:47:20 +00:00
- [X] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`. See [Changes files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/guides/committing-changes.md#changes-files) for more information. ## Testing - [X] Added/updated automated tests - [X] QA'd all new/changed functionality manually --------- Co-authored-by: Juan Fernandez <juan-fdz-hawa@users.noreply.github.com>
473 lines
16 KiB
Go
473 lines
16 KiB
Go
package endpointer
|
|
|
|
import (
|
|
"bytes"
|
|
"compress/gzip"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"reflect"
|
|
"strings"
|
|
"testing"
|
|
|
|
authz_ctx "github.com/fleetdm/fleet/v4/server/contexts/authz"
|
|
"github.com/fleetdm/fleet/v4/server/contexts/logging"
|
|
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
|
|
"github.com/go-kit/kit/endpoint"
|
|
kithttp "github.com/go-kit/kit/transport/http"
|
|
"github.com/gorilla/mux"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// testHandlerFunc is a handler function type used for testing.
|
|
type testHandlerFunc func(ctx context.Context, request any) (platform_http.Errorer, error)
|
|
|
|
func TestCustomMiddlewareAfterAuth(t *testing.T) {
|
|
var (
|
|
i = 0
|
|
beforeIndex = 0
|
|
authIndex = 0
|
|
afterFirstIndex = 0
|
|
afterSecondIndex = 0
|
|
)
|
|
beforeAuthMiddleware := func(next endpoint.Endpoint) endpoint.Endpoint {
|
|
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
i++
|
|
beforeIndex = i
|
|
return next(ctx, req)
|
|
}
|
|
}
|
|
|
|
authMiddleware := func(next endpoint.Endpoint) endpoint.Endpoint {
|
|
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
i++
|
|
authIndex = i
|
|
if authctx, ok := authz_ctx.FromContext(ctx); ok {
|
|
authctx.SetChecked()
|
|
}
|
|
return next(ctx, req)
|
|
}
|
|
}
|
|
|
|
afterAuthMiddlewareFirst := func(next endpoint.Endpoint) endpoint.Endpoint {
|
|
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
i++
|
|
afterFirstIndex = i
|
|
return next(ctx, req)
|
|
}
|
|
}
|
|
afterAuthMiddlewareSecond := func(next endpoint.Endpoint) endpoint.Endpoint {
|
|
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
i++
|
|
afterSecondIndex = i
|
|
return next(ctx, req)
|
|
}
|
|
}
|
|
|
|
r := mux.NewRouter()
|
|
ce := &CommonEndpointer[testHandlerFunc]{
|
|
EP: nopEP{},
|
|
MakeDecoderFn: func(iface any, requestBodySizeLimit int64) kithttp.DecodeRequestFunc {
|
|
return func(ctx context.Context, r *http.Request) (request any, err error) {
|
|
return nopRequest{}, nil
|
|
}
|
|
},
|
|
EncodeFn: func(ctx context.Context, w http.ResponseWriter, i any) error {
|
|
w.WriteHeader(http.StatusOK)
|
|
return nil
|
|
},
|
|
AuthMiddleware: authMiddleware,
|
|
CustomMiddleware: []endpoint.Middleware{
|
|
beforeAuthMiddleware,
|
|
},
|
|
CustomMiddlewareAfterAuth: []endpoint.Middleware{
|
|
afterAuthMiddlewareFirst,
|
|
afterAuthMiddlewareSecond,
|
|
},
|
|
Router: r,
|
|
}
|
|
ce.handleEndpoint("/", func(ctx context.Context, request any) (platform_http.Errorer, error) {
|
|
fmt.Printf("handler\n")
|
|
return nopResponse{}, nil
|
|
}, nil, "GET")
|
|
|
|
s := httptest.NewServer(r)
|
|
t.Cleanup(func() {
|
|
s.Close()
|
|
})
|
|
|
|
req, err := http.NewRequest("GET", s.URL+"/", nil)
|
|
require.NoError(t, err)
|
|
resp, err := http.DefaultClient.Do(req)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
resp.Body.Close()
|
|
})
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
require.Equal(t, 1, beforeIndex)
|
|
require.Equal(t, 2, authIndex)
|
|
require.Equal(t, 3, afterFirstIndex)
|
|
require.Equal(t, 4, afterSecondIndex)
|
|
}
|
|
|
|
type nopRequest struct{}
|
|
|
|
type nopResponse struct{}
|
|
|
|
func (n nopResponse) Error() error {
|
|
return nil
|
|
}
|
|
|
|
type nopEP struct{}
|
|
|
|
func (n nopEP) CallHandlerFunc(f testHandlerFunc, ctx context.Context, request any, svc any) (platform_http.Errorer, error) {
|
|
return f(ctx, request)
|
|
}
|
|
|
|
func (n nopEP) Service() any {
|
|
return nil
|
|
}
|
|
|
|
func TestRegisterDeprecatedPathAliases(t *testing.T) {
|
|
// Set up a router and register a primary endpoint via CommonEndpointer.
|
|
r := mux.NewRouter()
|
|
registry := NewHandlerRegistry()
|
|
versions := []string{"v1", "2022-04"}
|
|
|
|
authMiddleware := func(next endpoint.Endpoint) endpoint.Endpoint {
|
|
return func(ctx context.Context, req any) (any, error) {
|
|
if authctx, ok := authz_ctx.FromContext(ctx); ok {
|
|
authctx.SetChecked()
|
|
}
|
|
return next(ctx, req)
|
|
}
|
|
}
|
|
|
|
ce := &CommonEndpointer[testHandlerFunc]{
|
|
EP: nopEP{},
|
|
MakeDecoderFn: func(iface any, requestBodySizeLimit int64) kithttp.DecodeRequestFunc {
|
|
return func(ctx context.Context, r *http.Request) (request any, err error) {
|
|
return nopRequest{}, nil
|
|
}
|
|
},
|
|
EncodeFn: func(ctx context.Context, w http.ResponseWriter, i any) error {
|
|
w.WriteHeader(http.StatusOK)
|
|
return nil
|
|
},
|
|
AuthMiddleware: authMiddleware,
|
|
Router: r,
|
|
Versions: versions,
|
|
HandlerRegistry: registry,
|
|
}
|
|
|
|
// Register the primary endpoint.
|
|
ce.GET("/api/_version_/fleet/fleets", func(ctx context.Context, request any) (platform_http.Errorer, error) {
|
|
return nopResponse{}, nil
|
|
}, nil)
|
|
|
|
// Register a deprecated alias for it.
|
|
RegisterDeprecatedPathAliases(r, versions, registry, []DeprecatedPathAlias{
|
|
{
|
|
Method: "GET",
|
|
PrimaryPath: "/api/_version_/fleet/fleets",
|
|
DeprecatedPaths: []string{"/api/_version_/fleet/teams"},
|
|
},
|
|
})
|
|
|
|
s := httptest.NewServer(r)
|
|
t.Cleanup(s.Close)
|
|
|
|
// Both the primary and deprecated paths should return 200.
|
|
for _, path := range []string{"/api/v1/fleet/fleets", "/api/v1/fleet/teams", "/api/latest/fleet/teams"} {
|
|
resp, err := http.Get(s.URL + path)
|
|
require.NoError(t, err)
|
|
resp.Body.Close()
|
|
require.Equal(t, http.StatusOK, resp.StatusCode, "path %s should return 200", path)
|
|
}
|
|
}
|
|
|
|
func TestLogDeprecatedPathAlias(t *testing.T) {
|
|
// Without deprecated path info in context, LogDeprecatedPathAlias is a no-op.
|
|
lc := &logging.LoggingContext{}
|
|
ctx := logging.NewContext(context.Background(), lc)
|
|
ctx2 := LogDeprecatedPathAlias(ctx, nil)
|
|
require.Equal(t, ctx, ctx2, "should return same context when no deprecated path info")
|
|
require.Empty(t, lc.Extras)
|
|
|
|
// With deprecated path info, it should set warn level and extras.
|
|
ctx = context.WithValue(ctx, deprecatedPathInfoKey{}, deprecatedPathInfo{
|
|
deprecatedPath: "/api/_version_/fleet/teams",
|
|
primaryPath: "/api/_version_/fleet/fleets",
|
|
})
|
|
LogDeprecatedPathAlias(ctx, nil)
|
|
|
|
// Extras is a flat []interface{} of key-value pairs.
|
|
require.Len(t, lc.Extras, 4) // "deprecated_path", value, "deprecation_warning", value
|
|
require.Equal(t, "deprecated_path", lc.Extras[0])
|
|
require.Equal(t, "/api/_version_/fleet/teams", lc.Extras[1])
|
|
require.Equal(t, "deprecation_warning", lc.Extras[2])
|
|
require.Contains(t, lc.Extras[3], "deprecated")
|
|
|
|
// ForceLevel should be set to Warn.
|
|
require.NotNil(t, lc.ForceLevel)
|
|
require.Equal(t, slog.LevelWarn, *lc.ForceLevel)
|
|
}
|
|
|
|
func TestRegisterDeprecatedPathAliasesPanicsOnMissing(t *testing.T) {
|
|
r := mux.NewRouter()
|
|
registry := NewHandlerRegistry()
|
|
versions := []string{"v1"}
|
|
|
|
require.Panics(t, func() {
|
|
RegisterDeprecatedPathAliases(r, versions, registry, []DeprecatedPathAlias{
|
|
{
|
|
Method: "GET",
|
|
PrimaryPath: "/api/_version_/fleet/nonexistent",
|
|
DeprecatedPaths: []string{"/api/_version_/fleet/old"},
|
|
},
|
|
})
|
|
})
|
|
}
|
|
|
|
func defaultJSONUnmarshal(body io.Reader, req any) error {
|
|
return json.NewDecoder(body).Decode(req)
|
|
}
|
|
|
|
type testRequestDecoderType struct {
|
|
Data string `json:"data"`
|
|
}
|
|
|
|
func (d *testRequestDecoderType) DecodeRequest(ctx context.Context, r *http.Request) (any, error) {
|
|
err := json.NewDecoder(r.Body).Decode(d)
|
|
return d, err
|
|
}
|
|
|
|
// TestMakeDecoderRequestDecoderFalsePositive verifies that a body containing
|
|
// malformed JSON that is within the size limit does not produce a
|
|
// PayloadTooLargeError (false positive)
|
|
func TestMakeDecoderRequestDecoderFalsePositive(t *testing.T) {
|
|
const limit = 50
|
|
|
|
makeDecoder := func(limit int64) kithttp.DecodeRequestFunc {
|
|
return MakeDecoder(&testRequestDecoderType{}, defaultJSONUnmarshal, nil, nil, nil, nil, limit)
|
|
}
|
|
|
|
t.Run("malformed JSON within limit returns decode error, not 413", func(t *testing.T) {
|
|
body := strings.NewReader(`{"data": "truncated`) // malformed, within limit
|
|
r := httptest.NewRequest("POST", "/", body)
|
|
_, err := makeDecoder(limit)(context.Background(), r)
|
|
require.Error(t, err)
|
|
var ple platform_http.PayloadTooLargeError
|
|
require.False(t, errors.As(err, &ple), "malformed body within limit must not produce PayloadTooLargeError")
|
|
})
|
|
|
|
t.Run("body over limit returns 413", func(t *testing.T) {
|
|
big := `{"data":"` + strings.Repeat("x", limit+10) + `"}`
|
|
body := strings.NewReader(big)
|
|
r := httptest.NewRequest("POST", "/", body)
|
|
_, err := makeDecoder(limit)(context.Background(), r)
|
|
require.Error(t, err)
|
|
var ple platform_http.PayloadTooLargeError
|
|
require.True(t, errors.As(err, &ple), "body over limit must produce PayloadTooLargeError, got: %v", err)
|
|
})
|
|
|
|
t.Run("malformed JSON exactly at limit returns decode error, not 413", func(t *testing.T) {
|
|
// Build a body of exactly `limit` bytes that is malformed JSON (no closing
|
|
// brace). The MaxBytesReader allows the full read (body fits within limit),
|
|
// so the JSON decoder sees io.ErrUnexpectedEOF — not *http.MaxBytesError.
|
|
// Must not produce PayloadTooLargeError.
|
|
prefix := `{"data":"`
|
|
body := strings.NewReader(prefix + strings.Repeat("x", limit-len(prefix))) // exactly limit bytes, no closing
|
|
r := httptest.NewRequest("POST", "/", body)
|
|
_, err := makeDecoder(limit)(context.Background(), r)
|
|
require.Error(t, err)
|
|
var ple platform_http.PayloadTooLargeError
|
|
require.False(t, errors.As(err, &ple), "malformed body exactly at limit must not produce PayloadTooLargeError")
|
|
})
|
|
|
|
t.Run("body over limit without Content-Length returns 413", func(t *testing.T) {
|
|
// Simulate a chunked request (no Content-Length) whose body exceeds the
|
|
// limit. The peek at the underlying reader finds more data → 413.
|
|
big := `{"data":"` + strings.Repeat("x", limit+10) + `"}`
|
|
r := httptest.NewRequest("POST", "/", strings.NewReader(big))
|
|
r.ContentLength = -1 // strip the Content-Length that httptest set
|
|
_, err := makeDecoder(limit)(context.Background(), r)
|
|
require.Error(t, err)
|
|
var ple platform_http.PayloadTooLargeError
|
|
require.True(t, errors.As(err, &ple), "over-limit body without Content-Length must produce PayloadTooLargeError, got: %v", err)
|
|
})
|
|
|
|
t.Run("valid body within limit is decoded successfully", func(t *testing.T) {
|
|
body := strings.NewReader(`{"data":"hello"}`)
|
|
r := httptest.NewRequest("POST", "/", body)
|
|
result, err := makeDecoder(limit)(context.Background(), r)
|
|
require.NoError(t, err)
|
|
rd, ok := result.(*testRequestDecoderType)
|
|
require.True(t, ok)
|
|
assert.Equal(t, "hello", rd.Data)
|
|
})
|
|
}
|
|
|
|
func TestMakeDecoderRequestDecoderGzipBomb(t *testing.T) {
|
|
const limit = 100
|
|
|
|
makeDecoder := func() kithttp.DecodeRequestFunc {
|
|
return MakeDecoder(&testRequestDecoderType{}, defaultJSONUnmarshal, nil, nil, nil, nil, limit)
|
|
}
|
|
|
|
gzipBody := func(data string) *bytes.Buffer {
|
|
var buf bytes.Buffer
|
|
gw := gzip.NewWriter(&buf)
|
|
_, err := gw.Write([]byte(data))
|
|
require.NoError(t, err)
|
|
require.NoError(t, gw.Close())
|
|
return &buf
|
|
}
|
|
|
|
t.Run("gzip bomb exceeding decompressed limit returns 413", func(t *testing.T) {
|
|
big := `{"data":"` + strings.Repeat("x", limit*10) + `"}`
|
|
r := httptest.NewRequest("POST", "/", gzipBody(big))
|
|
r.Header.Set("Content-Encoding", "gzip")
|
|
_, err := makeDecoder()(context.Background(), r)
|
|
require.Error(t, err)
|
|
var ple platform_http.PayloadTooLargeError
|
|
require.True(t, errors.As(err, &ple), "gzip bomb via RequestDecoder must produce PayloadTooLargeError, got: %v", err)
|
|
assert.True(t, ple.Gzipped, "PayloadTooLargeError from gzip bomb must have Gzipped set")
|
|
})
|
|
|
|
t.Run("valid gzip body within limit is decoded successfully", func(t *testing.T) {
|
|
r := httptest.NewRequest("POST", "/", gzipBody(`{"data":"hi"}`))
|
|
r.Header.Set("Content-Encoding", "gzip")
|
|
result, err := makeDecoder()(context.Background(), r)
|
|
require.NoError(t, err)
|
|
rd, ok := result.(*testRequestDecoderType)
|
|
require.True(t, ok)
|
|
assert.Equal(t, "hi", rd.Data)
|
|
})
|
|
|
|
t.Run("Content-Encoding header is cleared after decompression", func(t *testing.T) {
|
|
r := httptest.NewRequest("POST", "/", gzipBody(`{"data":"hi"}`))
|
|
r.Header.Set("Content-Encoding", "gzip")
|
|
_, err := makeDecoder()(context.Background(), r)
|
|
require.NoError(t, err)
|
|
assert.Empty(t, r.Header.Get("Content-Encoding"), "Content-Encoding should be cleared after framework decompression")
|
|
})
|
|
}
|
|
|
|
type testGzipRequestType struct {
|
|
Data string `json:"data"`
|
|
}
|
|
|
|
// testRequestDecoderPayloadTooLargeType implements RequestDecoder and returns
|
|
// a PayloadTooLargeError directly from DecodeRequest (simulating implementations
|
|
// like getHostSoftwareRequest that enforce their own size limits).
|
|
type testRequestDecoderPayloadTooLargeType struct {
|
|
Data string `json:"data"`
|
|
}
|
|
|
|
func (d *testRequestDecoderPayloadTooLargeType) DecodeRequest(_ context.Context, r *http.Request) (any, error) {
|
|
return nil, platform_http.PayloadTooLargeError{
|
|
ContentLength: r.Header.Get("Content-Length"),
|
|
MaxRequestSize: 42,
|
|
}
|
|
}
|
|
|
|
type testGzipBodyDecoderType struct {
|
|
Data string `json:"data"`
|
|
}
|
|
|
|
func TestMakeDecoderGzipBomb(t *testing.T) {
|
|
const limit = 100
|
|
|
|
makeDecoder := func() kithttp.DecodeRequestFunc {
|
|
return MakeDecoder(testGzipRequestType{}, defaultJSONUnmarshal, nil, nil, nil, nil, limit)
|
|
}
|
|
|
|
gzipBody := func(data string) *bytes.Buffer {
|
|
var buf bytes.Buffer
|
|
gw := gzip.NewWriter(&buf)
|
|
_, err := gw.Write([]byte(data))
|
|
require.NoError(t, err)
|
|
require.NoError(t, gw.Close())
|
|
return &buf
|
|
}
|
|
|
|
t.Run("gzip bomb exceeding decompressed limit returns 413", func(t *testing.T) {
|
|
// Compressed payload is small but decompresses well beyond the limit.
|
|
big := `{"data":"` + strings.Repeat("x", limit*10) + `"}`
|
|
r := httptest.NewRequest("POST", "/", gzipBody(big))
|
|
r.Header.Set("Content-Encoding", "gzip")
|
|
_, err := makeDecoder()(context.Background(), r)
|
|
require.Error(t, err)
|
|
var ple platform_http.PayloadTooLargeError
|
|
require.True(t, errors.As(err, &ple), "gzip bomb must produce PayloadTooLargeError, got: %v", err)
|
|
assert.True(t, ple.Gzipped, "PayloadTooLargeError from gzip bomb must have Gzipped set")
|
|
})
|
|
|
|
t.Run("valid gzip body within limit is decoded successfully", func(t *testing.T) {
|
|
r := httptest.NewRequest("POST", "/", gzipBody(`{"data":"hi"}`))
|
|
r.Header.Set("Content-Encoding", "gzip")
|
|
result, err := makeDecoder()(context.Background(), r)
|
|
require.NoError(t, err)
|
|
rd, ok := result.(*testGzipRequestType)
|
|
require.True(t, ok)
|
|
assert.Equal(t, "hi", rd.Data)
|
|
})
|
|
|
|
// Sub-tests for the bodyDecoder (DecodeBody) code path, where isBodyDecoder
|
|
// returns true and decodeBody is called instead of jsonUnmarshal.
|
|
isBodyDecoder := func(v reflect.Value) bool {
|
|
_, ok := v.Interface().(*testGzipBodyDecoderType)
|
|
return ok
|
|
}
|
|
decodeBodyFn := func(_ context.Context, _ *http.Request, v reflect.Value, body io.Reader) error {
|
|
bd := v.Interface().(*testGzipBodyDecoderType)
|
|
return json.NewDecoder(body).Decode(bd)
|
|
}
|
|
makeBodyDecoder := func() kithttp.DecodeRequestFunc {
|
|
return MakeDecoder(testGzipBodyDecoderType{}, defaultJSONUnmarshal, nil, isBodyDecoder, decodeBodyFn, nil, limit)
|
|
}
|
|
|
|
t.Run("DecodeBody gzip bomb exceeding decompressed limit returns 413", func(t *testing.T) {
|
|
big := `{"data":"` + strings.Repeat("x", limit*10) + `"}`
|
|
r := httptest.NewRequest("POST", "/", gzipBody(big))
|
|
r.Header.Set("Content-Encoding", "gzip")
|
|
_, err := makeBodyDecoder()(context.Background(), r)
|
|
require.Error(t, err)
|
|
var ple platform_http.PayloadTooLargeError
|
|
require.True(t, errors.As(err, &ple), "gzip bomb via DecodeBody must produce PayloadTooLargeError, got: %v", err)
|
|
assert.True(t, ple.Gzipped, "PayloadTooLargeError from gzip bomb must have Gzipped set")
|
|
})
|
|
|
|
t.Run("DecodeBody valid gzip body within limit is decoded successfully", func(t *testing.T) {
|
|
r := httptest.NewRequest("POST", "/", gzipBody(`{"data":"hi"}`))
|
|
r.Header.Set("Content-Encoding", "gzip")
|
|
result, err := makeBodyDecoder()(context.Background(), r)
|
|
require.NoError(t, err)
|
|
rd, ok := result.(*testGzipBodyDecoderType)
|
|
require.True(t, ok)
|
|
assert.Equal(t, "hi", rd.Data)
|
|
})
|
|
|
|
// Sub-test for the RequestDecoder code path where DecodeRequest itself
|
|
// returns a PayloadTooLargeError (covers lines 545-546).
|
|
t.Run("DecodeRequest returning PayloadTooLargeError preserves inner fields and sets Gzipped", func(t *testing.T) {
|
|
makePayloadDecoder := func() kithttp.DecodeRequestFunc {
|
|
return MakeDecoder(&testRequestDecoderPayloadTooLargeType{}, defaultJSONUnmarshal, nil, nil, nil, nil, limit)
|
|
}
|
|
r := httptest.NewRequest("POST", "/", gzipBody(`{"data":"hi"}`))
|
|
r.Header.Set("Content-Encoding", "gzip")
|
|
_, err := makePayloadDecoder()(context.Background(), r)
|
|
require.Error(t, err)
|
|
var ple platform_http.PayloadTooLargeError
|
|
require.True(t, errors.As(err, &ple), "DecodeRequest returning PayloadTooLargeError must propagate, got: %v", err)
|
|
assert.True(t, ple.Gzipped, "Gzipped must be set when the request was gzip-encoded")
|
|
assert.Equal(t, int64(42), ple.MaxRequestSize, "MaxRequestSize from inner error must be preserved")
|
|
})
|
|
}
|