mirror of
https://github.com/fleetdm/fleet
synced 2026-04-21 21:47:20 +00:00
<!-- Add the related story/sub-task/bug number, like Resolves #123, or remove if NA --> **Related issue:** Resolves #40538 This is the initial iteration of CSP functionality, currently gated behind FLEET_SERVER_ENABLE_CSP. If disabled, no CSP is served. Nonces are still injected into pages however a dummy nonce is used and has no effect. With this setting turned on things break and will be addressed by mainly frontend changes in https://github.com/fleetdm/fleet/issues/41577 # Checklist for submitter If some of the following don't apply, delete the relevant line. - [x] Input data is properly validated, `SELECT *` is avoided, SQL injection is prevented (using placeholders for values in statements), JS inline code is prevented especially for url redirects - [x] If paths of existing endpoints are modified without backwards compatibility, checked the frontend/CLI for any necessary changes ## Testing - [x] Added/updated automated tests - [x] Where appropriate, [automated tests simulate multiple hosts and test for host isolation](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/reference/patterns-backend.md#unit-testing) (updates to one hosts's records do not affect another) - [x] QA'd all new/changed functionality manually --------- Co-authored-by: Gabriel Hernandez <ghernandez345@gmail.com>
641 lines
22 KiB
Go
641 lines
22 KiB
Go
package service
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"log/slog"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/fleetdm/fleet/v4/server/contexts/installersize"
|
|
"github.com/fleetdm/fleet/v4/server/contexts/logging"
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
"github.com/fleetdm/fleet/v4/server/mock"
|
|
"github.com/fleetdm/fleet/v4/server/platform/endpointer"
|
|
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
|
|
"github.com/fleetdm/fleet/v4/server/ptr"
|
|
"github.com/fleetdm/fleet/v4/server/service/middleware/auth"
|
|
"github.com/fleetdm/fleet/v4/server/service/middleware/log"
|
|
"github.com/go-kit/kit/endpoint"
|
|
kithttp "github.com/go-kit/kit/transport/http"
|
|
"github.com/gorilla/mux"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestUniversalDecoderIDs(t *testing.T) {
|
|
type universalStruct struct {
|
|
ID1 uint `url:"some-id"`
|
|
OptionalID uint `url:"some-other-id,optional"`
|
|
}
|
|
decoder := makeDecoder(universalStruct{}, installersize.MaxSoftwareInstallerSize)
|
|
|
|
req := httptest.NewRequest("POST", "/target", nil)
|
|
req = mux.SetURLVars(req, map[string]string{"some-id": "999"})
|
|
|
|
decoded, err := decoder(context.Background(), req)
|
|
require.NoError(t, err)
|
|
casted, ok := decoded.(*universalStruct)
|
|
require.True(t, ok)
|
|
|
|
assert.Equal(t, uint(999), casted.ID1)
|
|
assert.Equal(t, uint(0), casted.OptionalID)
|
|
|
|
// fails if non optional IDs are not provided
|
|
req = httptest.NewRequest("POST", "/target", nil)
|
|
_, err = decoder(context.Background(), req)
|
|
require.Error(t, err)
|
|
}
|
|
|
|
func TestUniversalDecoderIDsAndJSON(t *testing.T) {
|
|
type universalStruct struct {
|
|
ID1 uint `url:"some-id"`
|
|
SomeString string `json:"some_string"`
|
|
}
|
|
decoder := makeDecoder(universalStruct{}, installersize.MaxSoftwareInstallerSize)
|
|
|
|
body := `{"some_string": "hello"}`
|
|
req := httptest.NewRequest("POST", "/target", strings.NewReader(body))
|
|
req = mux.SetURLVars(req, map[string]string{"some-id": "999"})
|
|
|
|
decoded, err := decoder(context.Background(), req)
|
|
require.NoError(t, err)
|
|
casted, ok := decoded.(*universalStruct)
|
|
require.True(t, ok)
|
|
|
|
assert.Equal(t, uint(999), casted.ID1)
|
|
assert.Equal(t, "hello", casted.SomeString)
|
|
}
|
|
|
|
func TestUniversalDecoderIDsAndJSONEmbedded(t *testing.T) {
|
|
type EmbeddedJSON struct {
|
|
SomeString string `json:"some_string"`
|
|
}
|
|
type UniversalStruct struct {
|
|
ID1 uint `url:"some-id"`
|
|
EmbeddedJSON
|
|
}
|
|
decoder := makeDecoder(UniversalStruct{}, installersize.MaxSoftwareInstallerSize)
|
|
|
|
body := `{"some_string": "hello"}`
|
|
req := httptest.NewRequest("POST", "/target", strings.NewReader(body))
|
|
req = mux.SetURLVars(req, map[string]string{"some-id": "999"})
|
|
|
|
decoded, err := decoder(context.Background(), req)
|
|
require.NoError(t, err)
|
|
casted, ok := decoded.(*UniversalStruct)
|
|
require.True(t, ok)
|
|
|
|
assert.Equal(t, uint(999), casted.ID1)
|
|
assert.Equal(t, "hello", casted.SomeString)
|
|
}
|
|
|
|
func TestUniversalDecoderIDsAndListOptions(t *testing.T) {
|
|
type universalStruct struct {
|
|
ID1 uint `url:"some-id"`
|
|
Opts fleet.ListOptions `url:"list_options"`
|
|
SomeString string `json:"some_string"`
|
|
}
|
|
decoder := makeDecoder(universalStruct{}, installersize.MaxSoftwareInstallerSize)
|
|
|
|
body := `{"some_string": "bye"}`
|
|
req := httptest.NewRequest("POST", "/target?per_page=77&page=4", strings.NewReader(body))
|
|
req = mux.SetURLVars(req, map[string]string{"some-id": "123"})
|
|
|
|
decoded, err := decoder(context.Background(), req)
|
|
require.NoError(t, err)
|
|
casted, ok := decoded.(*universalStruct)
|
|
require.True(t, ok)
|
|
|
|
assert.Equal(t, uint(123), casted.ID1)
|
|
assert.Equal(t, "bye", casted.SomeString)
|
|
assert.Equal(t, uint(77), casted.Opts.PerPage)
|
|
assert.Equal(t, uint(4), casted.Opts.Page)
|
|
}
|
|
|
|
func TestUniversalDecoderHandlersEmbeddedAndNot(t *testing.T) {
|
|
type EmbeddedJSON struct {
|
|
SomeString string `json:"some_string"`
|
|
}
|
|
type universalStruct struct {
|
|
ID1 uint `url:"some-id"`
|
|
Opts fleet.ListOptions `url:"list_options"`
|
|
EmbeddedJSON
|
|
}
|
|
decoder := makeDecoder(universalStruct{}, installersize.MaxSoftwareInstallerSize)
|
|
|
|
body := `{"some_string": "o/"}`
|
|
req := httptest.NewRequest("POST", "/target?per_page=77&page=4", strings.NewReader(body))
|
|
req = mux.SetURLVars(req, map[string]string{"some-id": "123"})
|
|
|
|
decoded, err := decoder(context.Background(), req)
|
|
require.NoError(t, err)
|
|
casted, ok := decoded.(*universalStruct)
|
|
require.True(t, ok)
|
|
|
|
assert.Equal(t, uint(123), casted.ID1)
|
|
assert.Equal(t, "o/", casted.SomeString)
|
|
assert.Equal(t, uint(77), casted.Opts.PerPage)
|
|
assert.Equal(t, uint(4), casted.Opts.Page)
|
|
}
|
|
|
|
func TestUniversalDecoderListOptions(t *testing.T) {
|
|
type universalStruct struct {
|
|
ID1 uint `url:"some-id"`
|
|
Opts fleet.ListOptions `url:"list_options"`
|
|
}
|
|
decoder := makeDecoder(universalStruct{}, installersize.MaxSoftwareInstallerSize)
|
|
|
|
req := httptest.NewRequest("POST", "/target", nil)
|
|
req = mux.SetURLVars(req, map[string]string{"some-id": "123"})
|
|
|
|
decoded, err := decoder(context.Background(), req)
|
|
require.NoError(t, err)
|
|
_, ok := decoded.(*universalStruct)
|
|
require.True(t, ok)
|
|
}
|
|
|
|
func TestUniversalDecoderOptionalQueryParams(t *testing.T) {
|
|
type universalStruct struct {
|
|
ID1 *uint `query:"some_id,optional"`
|
|
}
|
|
decoder := makeDecoder(universalStruct{}, installersize.MaxSoftwareInstallerSize)
|
|
|
|
req := httptest.NewRequest("POST", "/target", nil)
|
|
|
|
decoded, err := decoder(context.Background(), req)
|
|
require.NoError(t, err)
|
|
casted, ok := decoded.(*universalStruct)
|
|
require.True(t, ok)
|
|
|
|
assert.Nil(t, casted.ID1)
|
|
|
|
req = httptest.NewRequest("POST", "/target?some_id=321", nil)
|
|
|
|
decoded, err = decoder(context.Background(), req)
|
|
require.NoError(t, err)
|
|
casted, ok = decoded.(*universalStruct)
|
|
require.True(t, ok)
|
|
|
|
require.NotNil(t, casted.ID1)
|
|
assert.Equal(t, uint(321), *casted.ID1)
|
|
}
|
|
|
|
func TestUniversalDecoderOptionalQueryParamString(t *testing.T) {
|
|
type universalStruct struct {
|
|
ID1 *string `query:"some_val,optional"`
|
|
}
|
|
decoder := makeDecoder(universalStruct{}, installersize.MaxSoftwareInstallerSize)
|
|
|
|
req := httptest.NewRequest("POST", "/target", nil)
|
|
|
|
decoded, err := decoder(context.Background(), req)
|
|
require.NoError(t, err)
|
|
casted, ok := decoded.(*universalStruct)
|
|
require.True(t, ok)
|
|
|
|
assert.Nil(t, casted.ID1)
|
|
|
|
req = httptest.NewRequest("POST", "/target?some_val=321", nil)
|
|
|
|
decoded, err = decoder(context.Background(), req)
|
|
require.NoError(t, err)
|
|
casted, ok = decoded.(*universalStruct)
|
|
require.True(t, ok)
|
|
|
|
require.NotNil(t, casted.ID1)
|
|
assert.Equal(t, "321", *casted.ID1)
|
|
}
|
|
|
|
func TestUniversalDecoderOptionalQueryParamNotPtr(t *testing.T) {
|
|
type universalStruct struct {
|
|
ID1 string `query:"some_val,optional"`
|
|
}
|
|
decoder := makeDecoder(universalStruct{}, installersize.MaxSoftwareInstallerSize)
|
|
|
|
req := httptest.NewRequest("POST", "/target", nil)
|
|
|
|
decoded, err := decoder(context.Background(), req)
|
|
require.NoError(t, err)
|
|
casted, ok := decoded.(*universalStruct)
|
|
require.True(t, ok)
|
|
|
|
assert.Equal(t, "", casted.ID1)
|
|
|
|
req = httptest.NewRequest("POST", "/target?some_val=321", nil)
|
|
|
|
decoded, err = decoder(context.Background(), req)
|
|
require.NoError(t, err)
|
|
casted, ok = decoded.(*universalStruct)
|
|
require.True(t, ok)
|
|
|
|
assert.Equal(t, "321", casted.ID1)
|
|
}
|
|
|
|
func TestUniversalDecoderQueryAndListPlayNice(t *testing.T) {
|
|
type universalStruct struct {
|
|
ID1 *uint `query:"some_id"`
|
|
Opts fleet.ListOptions `url:"list_options"`
|
|
}
|
|
decoder := makeDecoder(universalStruct{}, installersize.MaxSoftwareInstallerSize)
|
|
|
|
req := httptest.NewRequest("POST", "/target?per_page=77&page=4&some_id=444", nil)
|
|
|
|
decoded, err := decoder(context.Background(), req)
|
|
require.NoError(t, err)
|
|
casted, ok := decoded.(*universalStruct)
|
|
require.True(t, ok)
|
|
|
|
assert.Equal(t, uint(77), casted.Opts.PerPage)
|
|
assert.Equal(t, uint(4), casted.Opts.Page)
|
|
require.NotNil(t, casted.ID1)
|
|
assert.Equal(t, uint(444), *casted.ID1)
|
|
}
|
|
|
|
func TestUniversalDecoderSizeLimit(t *testing.T) {
|
|
type universalStruct struct {
|
|
ID1 uint `url:"some-id"`
|
|
Opts fleet.ListOptions `url:"list_options"`
|
|
}
|
|
decoder := makeDecoder(universalStruct{}, platform_http.MaxRequestBodySize)
|
|
|
|
// Body larger than the limit should return PayloadTooLargeError.
|
|
largeBody := `{"key": "` + strings.Repeat("A", int(platform_http.MaxRequestBodySize)+1) + `"}`
|
|
req := httptest.NewRequest("POST", "/target?per_page=77&page=4", strings.NewReader(largeBody))
|
|
req = mux.SetURLVars(req, map[string]string{"some-id": "123"})
|
|
|
|
_, err := decoder(context.Background(), req)
|
|
require.Error(t, err)
|
|
require.IsType(t, platform_http.PayloadTooLargeError{}, err)
|
|
|
|
// Body within the limit but with broken JSON
|
|
incompleteBody := `{"key": "` + strings.Repeat("A", 100) // missing closing "}
|
|
req = httptest.NewRequest("POST", "/target?per_page=77&page=4", strings.NewReader(incompleteBody))
|
|
req = mux.SetURLVars(req, map[string]string{"some-id": "123"})
|
|
|
|
_, err = decoder(context.Background(), req)
|
|
require.Error(t, err)
|
|
require.True(t, errors.Is(err, io.ErrUnexpectedEOF), "expected io.ErrUnexpectedEOF, got %T: %v", err, err)
|
|
_, isPayloadTooLarge := err.(platform_http.PayloadTooLargeError)
|
|
require.False(t, isPayloadTooLarge, "incomplete body within size limit must not produce PayloadTooLargeError, got %T: %v", err, err)
|
|
|
|
// Body within the limit and complete ... OK
|
|
|
|
largeBody = `{"key": "` + strings.Repeat("A", int(platform_http.MaxRequestBodySize)-11) + `"}` // -11 to account for the wrapping JSON
|
|
req = httptest.NewRequest("POST", "/target?per_page=77&page=4", strings.NewReader(largeBody))
|
|
req = mux.SetURLVars(req, map[string]string{"some-id": "123"})
|
|
|
|
_, err = decoder(context.Background(), req)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// stringErrorer is a minimal endpoint response used by the endpointer tests:
// handlers return it as their fleet.Errorer, and its string value is what the
// tests expect to find (JSON-encoded) in the response body.
type stringErrorer string

// Error always reports that the response carries no error.
func (stringErrorer) Error() error {
	return nil
}
|
|
|
|
func TestEndpointer(t *testing.T) {
|
|
r := mux.NewRouter()
|
|
ds := new(mock.Store)
|
|
ds.SessionByKeyFunc = func(ctx context.Context, key string) (*fleet.Session, error) {
|
|
return &fleet.Session{
|
|
ID: 3,
|
|
UserID: 42,
|
|
Key: key,
|
|
AccessedAt: time.Now(),
|
|
}, nil
|
|
}
|
|
ds.DestroySessionFunc = func(ctx context.Context, session *fleet.Session) error {
|
|
return nil
|
|
}
|
|
ds.MarkSessionAccessedFunc = func(ctx context.Context, session *fleet.Session) error {
|
|
return nil
|
|
}
|
|
ds.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
|
|
return &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}, nil
|
|
}
|
|
ds.ListUsersFunc = func(ctx context.Context, opt fleet.UserListOptions) ([]*fleet.User, error) {
|
|
return []*fleet.User{{GlobalRole: ptr.String(fleet.RoleAdmin)}}, nil
|
|
}
|
|
|
|
svc, _ := newTestService(t, ds, nil, nil)
|
|
|
|
fleetAPIOptions := []kithttp.ServerOption{
|
|
kithttp.ServerBefore(
|
|
kithttp.PopulateRequestContext, // populate the request context with common fields
|
|
auth.SetRequestsContexts(svc),
|
|
),
|
|
kithttp.ServerErrorHandler(&endpointer.ErrorHandler{Logger: slog.New(slog.DiscardHandler)}),
|
|
kithttp.ServerErrorEncoder(fleetErrorEncoder),
|
|
kithttp.ServerAfter(
|
|
kithttp.SetContentType("application/json; charset=utf-8"),
|
|
log.LogRequestEnd(slog.New(slog.DiscardHandler)),
|
|
checkLicenseExpiration(svc),
|
|
),
|
|
}
|
|
|
|
e := newUserAuthenticatedEndpointer(svc, fleetAPIOptions, r, "v1", "2021-11")
|
|
nopHandler := func(ctx context.Context, request interface{}, svc fleet.Service) (fleet.Errorer, error) {
|
|
setAuthCheckedOnPreAuthErr(ctx)
|
|
return stringErrorer("nop"), nil
|
|
}
|
|
overrideHandler := func(ctx context.Context, request interface{}, svc fleet.Service) (fleet.Errorer, error) {
|
|
setAuthCheckedOnPreAuthErr(ctx)
|
|
return stringErrorer("override"), nil
|
|
}
|
|
|
|
// Regular path, no plan to deprecate
|
|
e.GET("/api/_version_/fleet/path1", nopHandler, struct{}{})
|
|
|
|
// New path, we want it only available starting from the specified version
|
|
e.StartingAtVersion("2021-11").GET("/api/_version_/fleet/newpath", nopHandler, struct{}{})
|
|
|
|
// Path that was in v1, but was changed in 2021-11
|
|
e.EndingAtVersion("v1").GET("/api/_version_/fleet/overriddenpath", nopHandler, struct{}{})
|
|
e.StartingAtVersion("2021-11").GET("/api/_version_/fleet/overriddenpath", overrideHandler, struct{}{})
|
|
|
|
// Path that got deprecated
|
|
e.EndingAtVersion("v1").GET("/api/_version_/fleet/deprecated", nopHandler, struct{}{})
|
|
// Path that got deprecated but in the latest version
|
|
e.EndingAtVersion("2021-11").GET("/api/_version_/fleet/deprecated-soon", nopHandler, struct{}{})
|
|
|
|
// Aliasing works with versioning too
|
|
e.WithAltPaths("/api/_version_/fleet/something/{fff}").GET("/api/_version_/fleet/somethings/{fff}", nopHandler, struct{}{})
|
|
|
|
mustMatch := []struct {
|
|
method string
|
|
path string
|
|
overridden bool
|
|
}{
|
|
{method: "GET", path: "/api/v1/fleet/path1"},
|
|
{method: "GET", path: "/api/2021-11/fleet/path1"},
|
|
{method: "GET", path: "/api/latest/fleet/path1"},
|
|
|
|
{method: "GET", path: "/api/2021-11/fleet/newpath"},
|
|
{method: "GET", path: "/api/latest/fleet/newpath"},
|
|
|
|
{method: "GET", path: "/api/v1/fleet/deprecated"},
|
|
|
|
{method: "GET", path: "/api/v1/fleet/deprecated-soon"},
|
|
{method: "GET", path: "/api/2021-11/fleet/deprecated-soon"},
|
|
{method: "GET", path: "/api/latest/fleet/deprecated-soon"},
|
|
|
|
{method: "GET", path: "/api/v1/fleet/overriddenpath"},
|
|
{method: "GET", path: "/api/2021-11/fleet/overriddenpath", overridden: true},
|
|
{method: "GET", path: "/api/latest/fleet/overriddenpath", overridden: true},
|
|
|
|
{method: "GET", path: "/api/v1/fleet/something/aaa"},
|
|
{method: "GET", path: "/api/2021-11/fleet/something/aaa"},
|
|
{method: "GET", path: "/api/latest/fleet/something/aaa"},
|
|
{method: "GET", path: "/api/v1/fleet/somethings/aaa"},
|
|
{method: "GET", path: "/api/2021-11/fleet/somethings/aaa"},
|
|
{method: "GET", path: "/api/latest/fleet/somethings/aaa"},
|
|
}
|
|
|
|
mustNotMatch := []struct {
|
|
method string
|
|
path string
|
|
handler http.Handler
|
|
}{
|
|
{method: "POST", path: "/api/v1/fleet/path1"},
|
|
{method: "GET", path: "/api/v1/fleet/qwejoqiwejqiowehioqwe"},
|
|
{method: "GET", path: "/api/v1/qwejoqiwejqiowehioqwe"},
|
|
|
|
{method: "GET", path: "/api/v1/fleet/newpath"},
|
|
|
|
{method: "GET", path: "/api/2021-11/fleet/deprecated"},
|
|
{method: "GET", path: "/api/latest/fleet/deprecated"},
|
|
}
|
|
|
|
doesItMatch := func(method, path string, override bool) bool {
|
|
testURL := url.URL{Path: path}
|
|
request := http.Request{Method: method, URL: &testURL, Header: map[string][]string{"Authorization": {"Bearer asd"}}, Body: io.NopCloser(strings.NewReader(""))}
|
|
routeMatch := mux.RouteMatch{}
|
|
|
|
res := r.Match(&request, &routeMatch)
|
|
if routeMatch.Route != nil {
|
|
rec := httptest.NewRecorder()
|
|
routeMatch.Handler.ServeHTTP(rec, &request)
|
|
got := rec.Body.String()
|
|
if override {
|
|
require.Equal(t, "\"override\"\n", got)
|
|
} else {
|
|
require.Equal(t, "\"nop\"\n", got)
|
|
}
|
|
}
|
|
return res && routeMatch.MatchErr == nil && routeMatch.Route != nil
|
|
}
|
|
|
|
for _, route := range mustMatch {
|
|
require.True(t, doesItMatch(route.method, route.path, route.overridden), route)
|
|
}
|
|
|
|
for _, route := range mustNotMatch {
|
|
require.False(t, doesItMatch(route.method, route.path, false), route)
|
|
}
|
|
}
|
|
|
|
func TestEndpointerCustomMiddleware(t *testing.T) {
|
|
r := mux.NewRouter()
|
|
ds := new(mock.Store)
|
|
svc, _ := newTestService(t, ds, nil, nil)
|
|
|
|
fleetAPIOptions := []kithttp.ServerOption{
|
|
kithttp.ServerBefore(
|
|
kithttp.PopulateRequestContext,
|
|
auth.SetRequestsContexts(svc),
|
|
),
|
|
kithttp.ServerErrorHandler(&endpointer.ErrorHandler{Logger: slog.New(slog.DiscardHandler)}),
|
|
kithttp.ServerErrorEncoder(fleetErrorEncoder),
|
|
kithttp.ServerAfter(
|
|
kithttp.SetContentType("application/json; charset=utf-8"),
|
|
log.LogRequestEnd(slog.New(slog.DiscardHandler)),
|
|
checkLicenseExpiration(svc),
|
|
),
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
e := newNoAuthEndpointer(svc, fleetAPIOptions, r, "v1")
|
|
e.GET("/none/", func(ctx context.Context, request interface{}, svc fleet.Service) (fleet.Errorer, error) {
|
|
buf.WriteString("H1")
|
|
return nil, nil
|
|
}, nil)
|
|
|
|
e.WithCustomMiddleware(
|
|
func(e endpoint.Endpoint) endpoint.Endpoint {
|
|
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
|
|
buf.WriteString("A")
|
|
return e(ctx, request)
|
|
}
|
|
},
|
|
func(e endpoint.Endpoint) endpoint.Endpoint {
|
|
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
|
|
buf.WriteString("B")
|
|
return e(ctx, request)
|
|
}
|
|
},
|
|
func(e endpoint.Endpoint) endpoint.Endpoint {
|
|
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
|
|
buf.WriteString("C")
|
|
return e(ctx, request)
|
|
}
|
|
},
|
|
).
|
|
GET("/mw/", func(ctx context.Context, request interface{}, svc fleet.Service) (fleet.Errorer, error) {
|
|
buf.WriteString("H2")
|
|
return nil, nil
|
|
}, nil)
|
|
|
|
req := httptest.NewRequest("GET", "/none/", nil)
|
|
var m1 mux.RouteMatch
|
|
|
|
require.True(t, r.Match(req, &m1))
|
|
rec := httptest.NewRecorder()
|
|
m1.Handler.ServeHTTP(rec, req)
|
|
require.Equal(t, "H1", buf.String())
|
|
|
|
buf.Reset()
|
|
req = httptest.NewRequest("GET", "/mw/", nil)
|
|
var m2 mux.RouteMatch
|
|
|
|
require.True(t, r.Match(req, &m2))
|
|
rec = httptest.NewRecorder()
|
|
m2.Handler.ServeHTTP(rec, req)
|
|
require.Equal(t, "ABCH2", buf.String())
|
|
}
|
|
|
|
func TestWriteBrowserSecurityHeadersNoCSP(t *testing.T) {
|
|
w := httptest.NewRecorder()
|
|
_, err := endpointer.WriteBrowserSecurityHeaders(w, false, false)
|
|
require.NoError(t, err)
|
|
headers := w.Header()
|
|
require.Equal(
|
|
t,
|
|
http.Header{
|
|
"X-Content-Type-Options": {"nosniff"},
|
|
"X-Frame-Options": {"SAMEORIGIN"},
|
|
"Strict-Transport-Security": {"max-age=31536000; includeSubDomains;"},
|
|
"Referrer-Policy": {"strict-origin-when-cross-origin"},
|
|
},
|
|
headers,
|
|
)
|
|
}
|
|
|
|
func TestWriteBrowserSecurityHeadersCSPNoNonce(t *testing.T) {
|
|
w := httptest.NewRecorder()
|
|
_, err := endpointer.WriteBrowserSecurityHeaders(w, true, false)
|
|
require.NoError(t, err)
|
|
headers := w.Header()
|
|
require.Equal(
|
|
t,
|
|
http.Header{
|
|
"X-Content-Type-Options": {"nosniff"},
|
|
"X-Frame-Options": {"SAMEORIGIN"},
|
|
"Strict-Transport-Security": {"max-age=31536000; includeSubDomains;"},
|
|
"Referrer-Policy": {"strict-origin-when-cross-origin"},
|
|
"Content-Security-Policy": {"default-src 'none'; base-uri 'self'; connect-src 'self' www.gravatar.com ws: wss:; img-src 'self' www.gravatar.com data: https:; style-src 'self'; font-src 'self'; script-src 'self'"},
|
|
},
|
|
headers,
|
|
)
|
|
}
|
|
|
|
func TestWriteBrowserSecurityHeadersCSPAndNonce(t *testing.T) {
|
|
w := httptest.NewRecorder()
|
|
nonce, err := endpointer.WriteBrowserSecurityHeaders(w, true, true)
|
|
require.NoError(t, err)
|
|
headers := w.Header()
|
|
require.Equal(
|
|
t,
|
|
http.Header{
|
|
"X-Content-Type-Options": {"nosniff"},
|
|
"X-Frame-Options": {"SAMEORIGIN"},
|
|
"Strict-Transport-Security": {"max-age=31536000; includeSubDomains;"},
|
|
"Referrer-Policy": {"strict-origin-when-cross-origin"},
|
|
"Content-Security-Policy": {"default-src 'none'; base-uri 'self'; connect-src 'self' www.gravatar.com ws: wss:; img-src 'self' www.gravatar.com data: https:; style-src 'self' 'nonce-" + nonce + "'; font-src 'self'; script-src 'self' 'nonce-" + nonce + "'"},
|
|
},
|
|
headers,
|
|
)
|
|
}
|
|
|
|
// newMultipartRequest creates an *http.Request with multipart/form-data body
|
|
// containing the given field key/value pairs.
|
|
func newMultipartRequest(t *testing.T, fields map[string]string) *http.Request {
|
|
t.Helper()
|
|
var buf bytes.Buffer
|
|
w := multipart.NewWriter(&buf)
|
|
for k, v := range fields {
|
|
require.NoError(t, w.WriteField(k, v))
|
|
}
|
|
require.NoError(t, w.Close())
|
|
req := httptest.NewRequest("POST", "/target", &buf)
|
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
|
return req
|
|
}
|
|
|
|
func TestParseMultipartForm(t *testing.T) {
|
|
t.Run("passes through fleet_id unchanged", func(t *testing.T) {
|
|
req := newMultipartRequest(t, map[string]string{"fleet_id": "42"})
|
|
logCtx := &logging.LoggingContext{}
|
|
ctx := logging.NewContext(context.Background(), logCtx)
|
|
|
|
err := parseMultipartForm(ctx, req, platform_http.MaxMultipartFormSize)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, []string{"42"}, req.MultipartForm.Value["fleet_id"])
|
|
assert.Empty(t, req.MultipartForm.Value["team_id"])
|
|
assert.Nil(t, logCtx.ForceLevel)
|
|
assert.Empty(t, logCtx.Extras)
|
|
})
|
|
|
|
t.Run("rewrites team_id to fleet_id and logs deprecation", func(t *testing.T) {
|
|
req := newMultipartRequest(t, map[string]string{"team_id": "7"})
|
|
logCtx := &logging.LoggingContext{}
|
|
ctx := logging.NewContext(context.Background(), logCtx)
|
|
|
|
err := parseMultipartForm(ctx, req, platform_http.MaxMultipartFormSize)
|
|
require.NoError(t, err)
|
|
|
|
// team_id should be removed, fleet_id should be set
|
|
assert.Equal(t, []string{"7"}, req.MultipartForm.Value["fleet_id"])
|
|
assert.Empty(t, req.MultipartForm.Value["team_id"])
|
|
|
|
// r.Form should also be updated
|
|
assert.Equal(t, "7", req.Form.Get("fleet_id"))
|
|
assert.Empty(t, req.Form.Get("team_id"))
|
|
|
|
// deprecation should be logged
|
|
require.NotNil(t, logCtx.ForceLevel)
|
|
assert.Equal(t, slog.LevelWarn, *logCtx.ForceLevel)
|
|
assert.Contains(t, logCtx.Extras, "deprecated_param")
|
|
assert.Contains(t, logCtx.Extras, "team_id")
|
|
})
|
|
|
|
t.Run("no team_id or fleet_id", func(t *testing.T) {
|
|
req := newMultipartRequest(t, map[string]string{"other_field": "hello"})
|
|
logCtx := &logging.LoggingContext{}
|
|
ctx := logging.NewContext(context.Background(), logCtx)
|
|
|
|
err := parseMultipartForm(ctx, req, platform_http.MaxMultipartFormSize)
|
|
require.NoError(t, err)
|
|
|
|
assert.Empty(t, req.MultipartForm.Value["fleet_id"])
|
|
assert.Empty(t, req.MultipartForm.Value["team_id"])
|
|
assert.Equal(t, []string{"hello"}, req.MultipartForm.Value["other_field"])
|
|
assert.Nil(t, logCtx.ForceLevel)
|
|
assert.Empty(t, logCtx.Extras)
|
|
})
|
|
|
|
t.Run("invalid body returns error", func(t *testing.T) {
|
|
req := httptest.NewRequest("POST", "/target", strings.NewReader("not multipart"))
|
|
req.Header.Set("Content-Type", "multipart/form-data; boundary=bogus")
|
|
|
|
err := parseMultipartForm(context.Background(), req, platform_http.MaxMultipartFormSize)
|
|
require.Error(t, err)
|
|
})
|
|
}
|