fleet/server/platform/endpointer/endpoint_utils.go
Magnus Jensen d4f48b6f9c
ACME MDM -> main (#42926)
<!-- Add the related story/sub-task/bug number, like Resolves #123, or
remove if NA -->
**Related issue:** The entire ACME feature branch merge

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

- [x] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.
See [Changes
files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/guides/committing-changes.md#changes-files)
for more information.

- [x] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements), JS
inline code is prevented especially for url redirects, and untrusted
data interpolated into shell scripts/commands is validated against shell
metacharacters.
- [x] Timeouts are implemented and retries are limited to avoid infinite
loops

## Testing

- [x] Added/updated automated tests
- [x] Where appropriate, [automated tests simulate multiple hosts and
test for host
isolation](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/reference/patterns-backend.md#unit-testing)
(updates to one host's records do not affect another)

- [x] QA'd all new/changed functionality manually

---------

Co-authored-by: Jordan Montgomery <elijah.jordan.montgomery@gmail.com>
Co-authored-by: Martin Angers <martin.n.angers@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Gabriel Hernandez <ghernandez345@gmail.com>
Co-authored-by: Sarah Gillespie <73313222+gillespi314@users.noreply.github.com>
2026-04-02 15:56:31 -05:00

1261 lines
41 KiB
Go

package endpointer
import (
"bufio"
"bytes"
"compress/gzip"
"context"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"reflect"
"strconv"
"strings"
"sync"
"time"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/contexts/license"
"github.com/fleetdm/fleet/v4/server/contexts/logging"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
platform_logging "github.com/fleetdm/fleet/v4/server/platform/logging"
"github.com/fleetdm/fleet/v4/server/platform/middleware/authzcheck"
"github.com/fleetdm/fleet/v4/server/platform/middleware/ratelimit"
"github.com/go-kit/kit/endpoint"
kithttp "github.com/go-kit/kit/transport/http"
"github.com/gorilla/mux"
)
// HandlerRoutesFunc is a function that receives a gorilla/mux router and the
// go-kit server options to use for it.
// NOTE(review): presumably implementations register a group of routes on r
// using opts — confirm against callers.
type HandlerRoutesFunc func(r *mux.Router, opts []kithttp.ServerOption)
// ParseTag splits a struct tag value such as "id" or "id,optional" into its
// name and a flag reporting whether the part after the comma is exactly
// "optional". Any other second part is accepted and ignored. A tag with more
// than one comma is rejected with an error.
func ParseTag(tag string) (string, bool, error) {
	name, option, hasOption := strings.Cut(tag, ",")
	if !hasOption {
		// No comma: the whole tag is the name.
		return tag, false, nil
	}
	if strings.Contains(option, ",") {
		return "", false, fmt.Errorf("Error parsing %s: too many parts", tag)
	}
	return name, option == "optional", nil
}
// fieldPair couples a struct field's static description (Sf) with the value
// of that field in a concrete struct instance (V).
type fieldPair struct {
	Sf reflect.StructField
	V  reflect.Value
}

// allFields flattens the fields of the struct held by ifv (or pointed to by
// ifv) into a single slice, in declaration order. An anonymous (embedded)
// field whose value is a struct is replaced by its own fields, recursively;
// embedded pointers are returned as ordinary fields. It returns nil when ifv
// is neither a struct nor a non-nil pointer to one.
func allFields(ifv reflect.Value) []fieldPair {
	if ifv.Kind() == reflect.Ptr {
		ifv = ifv.Elem()
	}
	// A nil pointer dereferences to the zero Value, whose Kind is Invalid, so
	// this single check also rejects invalid values.
	if ifv.Kind() != reflect.Struct {
		return nil
	}

	structType := ifv.Type()
	var out []fieldPair
	for idx, n := 0, structType.NumField(); idx < n; idx++ {
		sf := structType.Field(idx)
		fv := ifv.Field(idx)
		if sf.Anonymous && fv.Kind() == reflect.Struct {
			out = append(out, allFields(fv)...)
			continue
		}
		out = append(out, fieldPair{Sf: sf, V: fv})
	}
	return out
}
// aliasRulesCache memoizes ExtractAliasRules per struct type so the
// reflection walk runs once per type instead of once per request.
var aliasRulesCache sync.Map // reflect.Type → []AliasRule

// ExtractAliasRules returns the alias rules declared by `renameto` struct tags
// on the struct type of iface (a struct or a pointer to one), including tags
// found on embedded structs and on any struct type reachable through fields.
//
// For each field carrying both a `renameto` tag and a usable `json` name, the
// json name becomes OldKey (the current/deprecated name) and the renameto
// value becomes NewKey (the target name). `url` and `query` tags are not
// considered. Duplicate pairs are reported once.
//
// It returns nil for a nil iface or a non-struct type. Results are cached by
// type; callers must treat the returned slice as read-only because it is
// shared between calls.
func ExtractAliasRules(iface any) []AliasRule {
	if iface == nil {
		return nil
	}
	typ := reflect.TypeOf(iface)
	if typ.Kind() == reflect.Ptr {
		typ = typ.Elem()
	}
	if typ.Kind() != reflect.Struct {
		return nil
	}

	if hit, found := aliasRulesCache.Load(typ); found {
		return hit.([]AliasRule)
	}

	var collected []AliasRule
	extractAliasRulesFromType(typ, make(map[AliasRule]bool), &collected)
	aliasRulesCache.Store(typ, collected)
	return collected
}
// extractAliasRulesFromType starts a fresh recursive alias-rule walk at struct
// type t. A new visited set is created per walk so cyclic type references
// (e.g. type Node struct { Children []Node }) terminate.
func extractAliasRulesFromType(t reflect.Type, seen map[AliasRule]bool, rules *[]AliasRule) {
	extractAliasRulesRecursive(t, seen, rules, map[reflect.Type]bool{})
}
// elemType peels off every layer of pointer, slice, array and map (value)
// types from t and returns the innermost type, which may or may not be a
// struct.
func elemType(t reflect.Type) reflect.Type {
	for {
		switch t.Kind() {
		case reflect.Ptr, reflect.Slice, reflect.Array, reflect.Map:
			t = t.Elem()
		default:
			return t
		}
	}
}
// extractAliasRulesRecursive walks the fields of struct type t. For every
// field with a non-empty `renameto` tag and a json field name that is neither
// empty nor "-", it appends AliasRule{OldKey: jsonName, NewKey: renameto}
// unless that rule is already in seen. It then descends into every struct
// type reachable from a field through any chain of pointers, slices, arrays
// and map values. Types already in visited are skipped, which stops recursion
// on self-referential types. t must be a struct type.
func extractAliasRulesRecursive(t reflect.Type, seen map[AliasRule]bool, rules *[]AliasRule, visited map[reflect.Type]bool) {
	if visited[t] {
		return
	}
	visited[t] = true

	for idx, n := 0, t.NumField(); idx < n; idx++ {
		sf := t.Field(idx)

		if newKey := sf.Tag.Get("renameto"); newKey != "" {
			// Drop json options such as ",omitempty". A missing json tag, an
			// empty name and "-" all disqualify the field.
			oldKey, _, _ := strings.Cut(sf.Tag.Get("json"), ",")
			if oldKey != "" && oldKey != "-" {
				rule := AliasRule{OldKey: oldKey, NewKey: newKey}
				if !seen[rule] {
					seen[rule] = true
					*rules = append(*rules, rule)
				}
			}
		}

		if inner := elemType(sf.Type); inner.Kind() == reflect.Struct {
			extractAliasRulesRecursive(inner, seen, rules, visited)
		}
	}
}
// BadRequestErr converts internalErr into a *platform_http.BadRequestError
// whose client-facing message is publicMsg.
//
// Errors that wrap a *net.OpError (e.g. timeouts) are not turned into bad
// request errors. They are instead returned as a plain error that wraps
// internalErr, with the message "<publicMsg>, internal: <internalErr>".
func BadRequestErr(publicMsg string, internalErr error) error {
	// ensure timeout errors don't become BadRequestErrors.
	var opErr *net.OpError
	if errors.As(internalErr, &opErr) {
		// publicMsg is passed as an argument, not concatenated into the
		// format string, so a '%' inside it is not parsed as a verb.
		return fmt.Errorf("%s, internal: %w", publicMsg, internalErr)
	}
	return &platform_http.BadRequestError{
		Message:     publicMsg,
		InternalErr: internalErr,
	}
}
// UintFromRequest reads the mux path variable name from r and parses it as a
// base-10 unsigned 64-bit integer. It returns ErrBadRoute when the route has
// no such variable, and a context-wrapped parse error when the value is not a
// valid uint64.
func UintFromRequest(r *http.Request, name string) (uint64, error) {
	raw, found := mux.Vars(r)[name]
	if !found {
		return 0, ErrBadRoute
	}
	n, err := strconv.ParseUint(raw, 10, 64)
	if err != nil {
		return 0, ctxerr.Wrap(r.Context(), err, "UintFromRequest")
	}
	return n, nil
}
// IntFromRequest reads the mux path variable name from r and parses it as a
// base-10 signed 64-bit integer. It returns ErrBadRoute when the route has no
// such variable, and a context-wrapped parse error when the value is not a
// valid int64.
func IntFromRequest(r *http.Request, name string) (int64, error) {
	raw, found := mux.Vars(r)[name]
	if !found {
		return 0, ErrBadRoute
	}
	n, err := strconv.ParseInt(raw, 10, 64)
	if err != nil {
		return 0, ctxerr.Wrap(r.Context(), err, "IntFromRequest")
	}
	return n, nil
}
// StringFromRequest reads the mux path variable name from r and returns it
// path-unescaped. It returns ErrBadRoute when the route has no such variable,
// and a context-wrapped error when the value is not validly escaped.
func StringFromRequest(r *http.Request, name string) (string, error) {
	raw, found := mux.Vars(r)[name]
	if !found {
		return "", ErrBadRoute
	}
	decoded, err := url.PathUnescape(raw)
	if err != nil {
		return "", ctxerr.Wrap(r.Context(), err, "unescape value in path")
	}
	return decoded, nil
}
// DecodeURLTagValue sets field from the mux path variable named urlTagValue.
//
// Signed integers, unsigned integers (of any width) and strings are
// supported; string values are path-unescaped. If the route has no such
// variable and optional is true, field is left untouched and nil is
// returned. A missing required variable, a value that fails to parse, or a
// value that does not fit the field's width is reported through
// BadRequestErr. Any other field kind yields an error.
func DecodeURLTagValue(r *http.Request, field reflect.Value, urlTagValue string, optional bool) error {
	switch field.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		v, err := IntFromRequest(r, urlTagValue)
		if err != nil {
			if errors.Is(err, ErrBadRoute) && optional {
				return nil
			}
			return BadRequestErr("IntFromRequest", err)
		}
		// IntFromRequest parses a 64-bit value; SetInt would silently
		// truncate it for narrower fields such as int8, so reject it.
		if field.OverflowInt(v) {
			return BadRequestErr("IntFromRequest",
				fmt.Errorf("value %d for %q overflows %s", v, urlTagValue, field.Type()))
		}
		field.SetInt(v)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		v, err := UintFromRequest(r, urlTagValue)
		if err != nil {
			if errors.Is(err, ErrBadRoute) && optional {
				return nil
			}
			return BadRequestErr("UintFromRequest", err)
		}
		// Same truncation guard as above for narrower unsigned fields.
		if field.OverflowUint(v) {
			return BadRequestErr("UintFromRequest",
				fmt.Errorf("value %d for %q overflows %s", v, urlTagValue, field.Type()))
		}
		field.SetUint(v)
	case reflect.String:
		v, err := StringFromRequest(r, urlTagValue)
		if err != nil {
			if errors.Is(err, ErrBadRoute) && optional {
				return nil
			}
			return BadRequestErr("StringFromRequest", err)
		}
		field.SetString(v)
	default:
		return fmt.Errorf("unsupported type for field %s for 'url' decoding: %s", urlTagValue, field.Kind())
	}
	return nil
}
// DomainQueryFieldDecoder decodes a query parameter value into the target field.
// It returns true if it handled the field, false if default handling should be used.
//
// DecodeQueryTagValue calls it with the parsed `query` tag name, the raw
// non-empty parameter value, and the target field. For pointer fields the
// pointer has already been allocated, and field is the pointed-to element. A
// non-nil error aborts decoding and is returned to the caller unchanged.
type DomainQueryFieldDecoder func(queryTagName, queryVal string, field reflect.Value) (handled bool, err error)
// DecodeQueryTagValue decodes the URL query parameter named by fp's `query`
// struct tag into the field fp.V. Fields without a `query` tag are ignored.
//
// The tag has the form `query:"name"` or `query:"name,optional"`. A missing
// or empty non-optional parameter yields a BadRequestError.
//
// If the field also carries a `renameto` tag, the `query` name is the
// deprecated name and the `renameto` value is the new one:
//   - if both names are present in the query, a BadRequestError is returned;
//   - if only the deprecated name is present, its value is used and, when the
//     deprecated-field logging topic is enabled, a deprecation warning is
//     attached to the request log context;
//   - otherwise the value of the new name is used.
//
// Pointer fields are allocated before decoding. customDecoder, if non-nil, is
// consulted first. Otherwise string, uint, float64, bool and int fields are
// supported. Booleans are true only for "1" or "true". Negative values are
// rejected for uint fields.
func DecodeQueryTagValue(r *http.Request, fp fieldPair, customDecoder DomainQueryFieldDecoder, ctx context.Context) error {
	queryTag, ok := fp.Sf.Tag.Lookup("query")
	if !ok {
		return nil
	}
	queryTagValue, optional, err := ParseTag(queryTag)
	if err != nil {
		return err
	}

	// Parse the `renameto` tag once, reporting malformed tags consistently
	// regardless of which name the client used.
	var newName string
	renameTo, hasRenameTo := fp.Sf.Tag.Lookup("renameto")
	if hasRenameTo {
		if newName, _, err = ParseTag(renameTo); err != nil {
			return err
		}
	}

	// Parse the raw query string once and reuse it for every lookup below.
	query := r.URL.Query()
	queryVal := query.Get(queryTagValue)
	switch {
	case hasRenameTo && queryVal != "":
		// The deprecated (old) name was used. Supplying the new name as
		// well is ambiguous, so reject it.
		if query.Get(newName) != "" {
			return &platform_http.BadRequestError{
				Message: fmt.Sprintf("Specify only one of %q or %q", queryTagValue, newName),
			}
		}
		if platform_logging.TopicEnabled(platform_logging.DeprecatedFieldTopic) {
			logging.WithLevel(ctx, slog.LevelWarn)
			logging.WithExtras(ctx,
				"deprecated_param", queryTagValue,
				"deprecation_warning", fmt.Sprintf("'%s' is deprecated, use '%s' instead", queryTagValue, newName),
			)
		}
	case hasRenameTo:
		// Old name absent: fall back to the new name.
		queryVal = query.Get(newName)
	}

	if queryVal == "" {
		if optional {
			return nil
		}
		return &platform_http.BadRequestError{Message: fmt.Sprintf("Param %s is required", queryTagValue)}
	}

	field := fp.V
	if field.Kind() == reflect.Ptr {
		// Allocate the pointed-to value and decode into it.
		field.Set(reflect.New(field.Type().Elem()))
		field = field.Elem()
	}

	// Give the domain-specific decoder the first chance to handle the field.
	if customDecoder != nil {
		handled, err := customDecoder(queryTagValue, queryVal, field)
		if err != nil {
			return err
		}
		if handled {
			return nil
		}
	}

	switch field.Kind() {
	case reflect.String:
		field.SetString(queryVal)
	case reflect.Uint:
		// ParseUint rejects negative input. Converting an Atoi result
		// silently wrapped e.g. "-1" around to the maximum uint.
		queryValUint, err := strconv.ParseUint(queryVal, 10, strconv.IntSize)
		if err != nil {
			return BadRequestErr("parsing uint from query", err)
		}
		field.SetUint(queryValUint)
	case reflect.Float64:
		queryValFloat, err := strconv.ParseFloat(queryVal, 64)
		if err != nil {
			return BadRequestErr("parsing float from query", err)
		}
		field.SetFloat(queryValFloat)
	case reflect.Bool:
		field.SetBool(queryVal == "1" || queryVal == "true")
	case reflect.Int:
		queryValInt, err := strconv.Atoi(queryVal)
		if err != nil {
			return BadRequestErr("parsing int from query", err)
		}
		field.SetInt(int64(queryValInt))
	default:
		return fmt.Errorf("Cant handle type for field %s %s", fp.Sf.Name, field.Kind())
	}
	return nil
}
// Canonical header names consulted by extractIP.
// Adapted from https://github.com/go-chi/chi/blob/c97bc988430d623a14f50b7019fb40529036a35a/middleware/realip.go#L42
var (
	trueClientIP  = http.CanonicalHeaderKey("True-Client-IP")
	xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For")
	xRealIP       = http.CanonicalHeaderKey("X-Real-IP")
)

// extractIP returns a best-effort client IP address for r.
//
// It prefers, in order:
//  1. the True-Client-IP header;
//  2. the X-Real-IP header;
//  3. the left-most entry of X-Forwarded-For;
//  4. the host part of r.RemoteAddr, as a fallback.
//
// Header values are trimmed of surrounding whitespace and skipped when empty.
// The headers can be set by the client itself, so the result may be spoofed
// (see the X-Forwarded-For notes below).
func extractIP(r *http.Request) string {
	// net.SplitHostPort drops the port and also strips the brackets from
	// IPv6 literals ("[::1]:8080" -> "::1"). Cutting at the last ':' left
	// the brackets in place and mangled bare IPv6 addresses. If RemoteAddr
	// has no port, it is used unchanged.
	remote := r.RemoteAddr
	if host, _, err := net.SplitHostPort(remote); err == nil {
		remote = host
	}

	// Prefer True-Client-IP and X-Real-IP headers before X-Forwarded-For:
	// - True-Client-IP: set by some CDNs (e.g., Akamai) to indicate the real client IP early in the chain
	// - X-Real-IP: set by Nginx or similar proxies as a simpler alternative to X-Forwarded-For
	// These headers are less likely to be spoofed or malformed compared to X-Forwarded-For.
	if tcip := strings.TrimSpace(r.Header.Get(trueClientIP)); tcip != "" {
		return tcip
	}
	if xrip := strings.TrimSpace(r.Header.Get(xRealIP)); xrip != "" {
		return xrip
	}
	if xff := r.Header.Get(xForwardedFor); xff != "" {
		// X-Forwarded-For is a comma-separated list of IP addresses representing the chain of proxies
		// that a request has passed through. This is not a standard, but a convention.
		// The convention is to treat the left-most IP address as the original client IP.
		// For example:
		// X-Forwarded-For: 198.51.100.1, 203.0.113.5, 127.0.0.1
		// Means:
		// - 198.51.100.1 is the client IP
		// - 127.0.0.1 is the last proxy (likely this server or a local proxy)
		//
		// If the left-most IP is a private or loopback address (e.g., 127.0.0.1 or 10.x.x.x), it may indicate:
		// - The request originated from a local proxy, or
		// - The header was spoofed by a client (untrusted source)
		//
		// Having multiple X-Forwarded-For headers is non-standard, so we do not handle it here.
		//
		// Here, we grab the left-most IP address by convention. An empty
		// left-most entry falls back to RemoteAddr instead of returning "".
		first, _, _ := strings.Cut(xff, ",")
		if first = strings.TrimSpace(first); first != "" {
			return first
		}
	}
	return remote
}
// ErrorHandler logs request errors through Logger at Info level, enriched
// with request metadata.
type ErrorHandler struct {
	Logger *slog.Logger
}

// Handle logs err under the message "request error". The log record carries:
//   - the request path, and the elapsed time when a start time is in ctx;
//   - any internal error, extra log fields and error UUID that err exposes;
//   - for rate-limit errors, the retry delay and the fixed text
//     "limit exceeded" in place of the error value itself.
func (h *ErrorHandler) Handle(ctx context.Context, err error) {
	reqPath, _ := ctx.Value(kithttp.ContextKeyRequestPath).(string)
	logAttrs := []any{"path", reqPath}

	if began, ok := logging.StartTime(ctx); ok && !began.IsZero() {
		logAttrs = append(logAttrs, "took", time.Since(began))
	}

	var internalErr platform_http.ErrWithInternal
	if errors.As(err, &internalErr) {
		logAttrs = append(logAttrs, "internal", internalErr.Internal())
	}
	var fieldsErr platform_http.ErrWithLogFields
	if errors.As(err, &fieldsErr) {
		logAttrs = append(logAttrs, fieldsErr.LogFields()...)
	}
	var uuidErr platform_http.ErrorUUIDer
	if errors.As(err, &uuidErr) {
		logAttrs = append(logAttrs, "uuid", uuidErr.UUID())
	}

	var errValue any = err
	var limitErr ratelimit.Error
	if errors.As(err, &limitErr) {
		if wait := limitErr.Result().RetryAfter; wait > 0 {
			logAttrs = append(logAttrs, "retry_after", wait)
		}
		errValue = "limit exceeded"
	}
	logAttrs = append(logAttrs, "err", errValue)

	h.Logger.InfoContext(ctx, "request error", logAttrs...)
}
// RequestDecoder is implemented by request types that decode the whole
// request themselves, including the body and any url or query arguments.
type RequestDecoder interface {
	DecodeRequest(ctx context.Context, r *http.Request) (any, error)
}

// requestValidator is implemented by request types that apply further
// validation after their values have been decoded.
type requestValidator interface {
	ValidateRequest() error
}
// MakeDecoder creates a kithttp.DecodeRequestFunc for the request type of
// iface.
//
// If iface is nil, the returned decoder always returns (nil, nil).
//
// If iface implements RequestDecoder, the returned function calls
// iface.DecodeRequest(ctx, r) - i.e. the value itself fully controls its own
// decoding. Before the call, the body size limit is applied to r.Body, and
// the body is decompressed if the Content-Encoding header is "gzip".
//
// Otherwise iface must be a struct value (not a pointer); MakeDecoder panics
// for any other kind. Each request gets a new *T (T being iface's type),
// filled as follows:
//
//   - Unless isBodyDecoder reports true for the value, a non-empty body is
//     unmarshalled with jsonUnmarshal.
//   - Fields with a `url` tag are first offered to parseCustomTags (if
//     non-nil). Fields it does not handle are treated as path variables (of
//     the form /path/{name} in the route's path) and decoded via
//     DecodeURLTagValue. Variables can be optional by setting the tag as
//     follows: `url:"some-id,optional"`.
//   - Fields with a `query` tag are decoded via DecodeQueryTagValue.
//     customQueryDecoder (if non-nil) is tried first, which lets services
//     inject domain-specific query parameter decoding logic.
//   - If isBodyDecoder reports true, decodeBody is called after the url and
//     query fields have been decoded.
//   - `renameto` tags become alias rules (see ExtractAliasRules), used to
//     rewrite deprecated JSON keys in the body and to detect conflicts.
//   - Fields tagged `premium:"true"` must be zero unless the license in ctx
//     is premium.
//   - Finally, if the value implements requestValidator, ValidateRequest is
//     called.
//
// maxRequestBodySize is the limit in bytes applied to both the raw and the
// decompressed body; -1 disables the limit.
//
// NOTE(review): an earlier version of this comment said these behaviors do
// not work for embedded structs. allFields does recurse into embedded
// (non-pointer) struct fields, so that caveat may be stale — confirm.
//
// If adding a new way to parse/decode the request, make sure to wrap the body with http.MaxBytesReader using the maxRequestBodySize
func MakeDecoder(
	iface interface{},
	jsonUnmarshal func(body io.Reader, req any) error,
	parseCustomTags func(urlTagValue string, r *http.Request, field reflect.Value) (bool, error),
	isBodyDecoder func(reflect.Value) bool,
	decodeBody func(ctx context.Context, r *http.Request, v reflect.Value, body io.Reader) error,
	customQueryDecoder DomainQueryFieldDecoder,
	maxRequestBodySize int64,
) kithttp.DecodeRequestFunc {
	// Infer alias rules from `renameto` struct tags on the request type.
	aliasRules := ExtractAliasRules(iface)
	// No request type: there is nothing to decode.
	if iface == nil {
		return func(ctx context.Context, r *http.Request) (interface{}, error) {
			return nil, nil
		}
	}
	if rd, ok := iface.(RequestDecoder); ok {
		return func(ctx context.Context, r *http.Request) (interface{}, error) {
			// Limit the raw (possibly compressed) body size.
			if maxRequestBodySize != -1 {
				r.Body = http.MaxBytesReader(nil, r.Body, maxRequestBodySize)
			}
			//
			// We take care of gzip encoding here to prevent any future DecodeRequest
			// implementations from missing gzip bomb checks.
			//
			gzipped := false
			if strings.EqualFold(r.Header.Get("content-encoding"), "gzip") {
				gzipped = true
				gzr, err := gzip.NewReader(r.Body)
				if err != nil {
					return nil, BadRequestErr("gzip decoder error", err)
				}
				defer gzr.Close()
				if maxRequestBodySize != -1 {
					// Limit decompressed bytes to prevent gzip bombs from bypassing
					// the raw body size limit applied above.
					r.Body = http.MaxBytesReader(nil, gzr, maxRequestBodySize)
				} else {
					r.Body = io.NopCloser(gzr)
				}
				// Clear so implementations don't try to decompress again.
				r.Header.Del("Content-Encoding")
			}
			ret, err := rd.DecodeRequest(ctx, r)
			// Some DecodeRequest implementations (like getHostSoftwareRequest)
			// themselves return platform_http.PayloadTooLargeError.
			if inner, isPayloadTooLargeError := errors.AsType[platform_http.PayloadTooLargeError](err); isPayloadTooLargeError {
				// Preserve the inner error's MaxRequestSize and ContentLength
				// (it knows the actual limit that was hit), only add Gzipped.
				inner.Gzipped = gzipped
				return nil, inner
			}
			// This is the DecodeRequest implementation returning http.MaxBytesError
			// (e.g. there's a size limit when uploading installers.)
			if _, isMaxBytesError := errors.AsType[*http.MaxBytesError](err); isMaxBytesError {
				return nil, platform_http.PayloadTooLargeError{
					ContentLength:  r.Header.Get("Content-Length"),
					MaxRequestSize: maxRequestBodySize,
					Gzipped:        gzipped,
				}
			}
			return ret, err
		}
	}
	t := reflect.TypeOf(iface)
	if t.Kind() != reflect.Struct {
		panic(fmt.Sprintf("MakeDecoder only understands structs, not %T", iface))
	}
	return func(ctx context.Context, r *http.Request) (interface{}, error) {
		// v holds a *T; its Interface() is the decoded request returned below.
		v := reflect.New(t)
		nilBody := false
		var rewriter *JSONKeyRewriteReader
		// Limit the raw (possibly compressed) body size.
		if maxRequestBodySize != -1 {
			r.Body = http.MaxBytesReader(nil, r.Body, maxRequestBodySize)
		}
		buf := bufio.NewReader(r.Body)
		var body io.Reader = buf
		gzipped := false
		// Peek one byte to detect an empty body without consuming input.
		// NOTE(review): Peek errors other than io.EOF are not treated as an
		// empty body; they surface from the subsequent reads instead.
		if _, err := buf.Peek(1); err == io.EOF {
			nilBody = true
		} else {
			if strings.EqualFold(r.Header.Get("content-encoding"), "gzip") {
				gzipped = true
				gzr, err := gzip.NewReader(buf)
				if err != nil {
					return nil, BadRequestErr("gzip decoder error", err)
				}
				defer gzr.Close()
				if maxRequestBodySize != -1 {
					// Limit decompressed bytes to prevent gzip bombs from bypassing
					// the raw body size limit applied above.
					body = http.MaxBytesReader(nil, gzr, maxRequestBodySize)
				} else {
					body = gzr
				}
			}
			// Insert the JSON key rewriter into the reader pipeline
			// (after gzip decompression, before JSON decoding) to rename
			// deprecated field names and detect alias conflicts.
			if len(aliasRules) > 0 {
				rewriter = NewJSONKeyRewriteReader(body, aliasRules)
				//nolint:errcheck // nothing to do on .Close() error.
				defer rewriter.Close()
				body = rewriter
			}
			// Types without a custom body decoder get the whole body
			// unmarshalled as JSON into the new value.
			if isBodyDecoder == nil || !isBodyDecoder(v) {
				req := v.Interface()
				err := jsonUnmarshal(body, req)
				if err != nil {
					// Check for alias conflict errors from the rewriter.
					var ace *AliasConflictError
					if errors.As(err, &ace) {
						return nil, &platform_http.BadRequestError{
							Message:     fmt.Sprintf("Specify only one of %q or %q", ace.Old, ace.New),
							InternalErr: ace,
						}
					}
					// Raw or decompressed size limit exceeded.
					if _, ok := errors.AsType[*http.MaxBytesError](err); ok {
						return nil, platform_http.PayloadTooLargeError{
							ContentLength:  r.Header.Get("Content-Length"),
							MaxRequestSize: maxRequestBodySize,
							Gzipped:        gzipped,
						}
					}
					return nil, BadRequestErr("json decoder error", err)
				}
				// req holds the same pointer as v, so this does not change
				// which value is decoded into below.
				v = reflect.ValueOf(req)
			}
		}
		fields := allFields(v)
		for _, fp := range fields {
			field := fp.V
			urlTagValue, ok := fp.Sf.Tag.Lookup("url")
			var err error
			if ok {
				optional := false
				urlTagValue, optional, err = ParseTag(urlTagValue)
				if err != nil {
					return nil, err
				}
				foundValue := false
				if parseCustomTags != nil {
					foundValue, err = parseCustomTags(urlTagValue, r, field)
					if err != nil {
						return nil, err
					}
				}
				// Not handled by parseCustomTags: decode as a path variable
				// and skip the JSON/query checks below for this field.
				if !foundValue {
					err := DecodeURLTagValue(r, field, urlTagValue, optional)
					if err != nil {
						return nil, err
					}
					continue
				}
			}
			// Any field with a json tag requires a non-empty body.
			_, jsonExpected := fp.Sf.Tag.Lookup("json")
			if jsonExpected && nilBody {
				return nil, &platform_http.BadRequestError{Message: "Expected JSON Body"}
			}
			// Reject requests that carry an Origin or Referer header and a
			// JSON-tagged field but not an application/json Content-Type
			// (presumably a cross-site request forgery defence — confirm).
			// NOTE(review): the comparison is exact, so a Content-Type of
			// "application/json; charset=utf-8" is rejected here; these two
			// header checks are also loop-invariant and could be hoisted.
			isContentJson := r.Header.Get("Content-Type") == "application/json"
			isCrossSite := r.Header.Get("Origin") != "" || r.Header.Get("Referer") != ""
			if jsonExpected && isCrossSite && !isContentJson {
				return nil, platform_http.NewUserMessageError(errors.New("Expected Content-Type \"application/json\""), http.StatusUnsupportedMediaType)
			}
			err = DecodeQueryTagValue(r, fp, customQueryDecoder, ctx)
			if err != nil {
				return nil, err
			}
		}
		// Custom body decoders run after url/query fields are populated.
		if isBodyDecoder != nil && isBodyDecoder(v) {
			err := decodeBody(ctx, r, v, body)
			if err != nil {
				// Check for alias conflict errors from the rewriter.
				var ace *AliasConflictError
				if errors.As(err, &ace) {
					return nil, &platform_http.BadRequestError{
						Message:     fmt.Sprintf("Specify only one of %q or %q", ace.Old, ace.New),
						InternalErr: ace,
					}
				}
				if _, ok := errors.AsType[*http.MaxBytesError](err); ok {
					return nil, platform_http.PayloadTooLargeError{
						ContentLength:  r.Header.Get("Content-Length"),
						MaxRequestSize: maxRequestBodySize,
						Gzipped:        gzipped,
					}
				}
				// A truncated body is the client's fault.
				if errors.Is(err, io.ErrUnexpectedEOF) {
					return nil, BadRequestErr("json decoder error", err)
				}
				return nil, err
			}
		}
		// Log deprecation warnings when deprecated field names are used.
		if rewriter != nil && platform_logging.TopicEnabled(platform_logging.DeprecatedFieldTopic) {
			if deprecated := rewriter.UsedDeprecatedKeys(); len(deprecated) > 0 {
				// Map each deprecated key to its replacement via aliasRules.
				newNames := make([]string, len(deprecated))
				for i, old := range deprecated {
					for _, rule := range aliasRules {
						if rule.OldKey == old {
							newNames[i] = rule.NewKey
							break
						}
					}
				}
				logging.WithLevel(ctx, slog.LevelWarn)
				logging.WithExtras(ctx,
					"deprecated_fields", fmt.Sprintf("%v", deprecated),
					"deprecation_warning", fmt.Sprintf("use the updated field names (%s) instead", newNames),
				)
			}
		}
		// Without a premium license, reject any non-zero value in a field
		// tagged `premium:"true"`. A malformed premium tag value is returned
		// as an error.
		if !license.IsPremium(ctx) {
			for _, fp := range fields {
				if prem, ok := fp.Sf.Tag.Lookup("premium"); ok {
					val, err := strconv.ParseBool(prem)
					if err != nil {
						return nil, err
					}
					if val && !fp.V.IsZero() {
						return nil, &platform_http.BadRequestError{Message: fmt.Sprintf(
							"option %s requires a premium license",
							fp.Sf.Name,
						)}
					}
					// NOTE(review): this continue is redundant (last statement in the loop body).
					continue
				}
			}
		}
		if rv, ok := v.Interface().(requestValidator); ok {
			if err := rv.ValidateRequest(); err != nil {
				return nil, err
			}
		}
		return v.Interface(), nil
	}
}
// newNonce returns 16 bytes from crypto/rand, encoded as unpadded URL-safe
// base64 (22 characters). It is suitable for a CSP nonce.
func newNonce() (string, error) {
	var raw [16]byte
	if _, err := io.ReadFull(rand.Reader, raw[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(raw[:]), nil
}
// WriteBrowserSecurityHeaders sets HSTS, X-Frame-Options,
// X-Content-Type-Options and Referrer-Policy on w. When serveCSP is true it
// also sets a Content-Security-Policy header.
//
// When includeNonce is true, a fresh random nonce is generated, added to the
// CSP style-src and script-src directives, and returned. Only HTML responses
// need it; API and static asset responses should not include it. Otherwise
// the placeholder "disabled" is returned. It is still a syntactically valid
// value for HTML templates that substitute the nonce unconditionally.
func WriteBrowserSecurityHeaders(w http.ResponseWriter, serveCSP, includeNonce bool) (string, error) {
	nonce, nonceSource := "disabled", ""
	if includeNonce {
		generated, err := newNonce()
		if err != nil {
			return "", err
		}
		nonce = generated
		nonceSource = " 'nonce-" + generated + "'"
	}

	hdr := w.Header()
	// HSTS: browsers must use HTTPS for this site (and subdomains) from now on.
	hdr.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains;")
	// Forbid embedding in other origins' frames (clickjacking protection).
	hdr.Set("X-Frame-Options", "SAMEORIGIN")
	// Stop MIME sniffing from turning non-executable content executable.
	hdr.Set("X-Content-Type-Options", "nosniff")
	// Do not leak full referrer URLs cross-origin.
	hdr.Set("Referrer-Policy", "strict-origin-when-cross-origin")

	if serveCSP {
		// TODO Is https OK for img-src? We allow customers to upload their own images and we have to reach out to gravatar for others.
		// NB: If default-src ever changes from 'none' make sure to add object-src 'none'
		policy := "default-src 'none'; base-uri 'self'; connect-src 'self' www.gravatar.com ws: wss:; img-src 'self' www.gravatar.com data: https:; style-src 'self'" +
			nonceSource +
			"; font-src 'self'; script-src 'self'" +
			nonceSource
		hdr.Set("Content-Security-Policy", policy)
	}
	return nonce, nil
}
// BrowserSecurityHeadersHandler wraps h so that every response first gets the
// browser security headers (with a CSP only when serveCSP is true). No nonce
// is generated here; the few endpoints that need one handle it themselves. If
// the headers cannot be written, it responds with 500 and h is not called.
func BrowserSecurityHeadersHandler(serveCSP bool, h http.Handler) http.Handler {
	wrapped := func(w http.ResponseWriter, r *http.Request) {
		if _, err := WriteBrowserSecurityHeaders(w, serveCSP, false); err != nil {
			http.Error(w, "failed to write browser security headers", http.StatusInternalServerError)
			return
		}
		h.ServeHTTP(w, r)
	}
	return http.HandlerFunc(wrapped)
}
// handlerKey identifies a registered handler by HTTP method and unversioned
// path template (e.g. "/api/_version_/fleet/fleets").
type handlerKey struct {
	method, path string
}

// HandlerRegistry records HTTP handlers by method+path while endpoints are
// registered, so deprecated path aliases can later look them up.
type HandlerRegistry struct {
	handlers map[handlerKey]http.Handler
}

// NewHandlerRegistry returns a HandlerRegistry with an empty, ready-to-use map.
func NewHandlerRegistry() *HandlerRegistry {
	return &HandlerRegistry{handlers: map[handlerKey]http.Handler{}}
}
// DeprecatedPathAlias maps a primary (canonical) path to one or more deprecated
// paths that should serve the same handler.
type DeprecatedPathAlias struct {
Method string
PrimaryPath string // canonical path (must already be registered)
DeprecatedPaths []string // old paths to alias
}
// deprecatedPathInfoKey is the context key for deprecated URL path info.
type deprecatedPathInfoKey struct{}
// deprecatedPathInfo holds the deprecated and canonical paths for logging.
type deprecatedPathInfo struct {
deprecatedPath string
primaryPath string
}
// LogDeprecatedPathAlias is a kithttp.RequestFunc (ServerBefore function).
// When the deprecated-field logging topic is enabled and the request arrived
// through a deprecated URL path alias, it does two things:
//   - raises the request log level to Warn;
//   - adds the deprecated path and a deprecation warning to the request log.
//
// ctx is returned unchanged. It must run after the LoggingContext is created
// (i.e. after SetRequestsContexts).
func LogDeprecatedPathAlias(ctx context.Context, _ *http.Request) context.Context {
	if platform_logging.TopicEnabled(platform_logging.DeprecatedFieldTopic) {
		if info, viaAlias := ctx.Value(deprecatedPathInfoKey{}).(deprecatedPathInfo); viaAlias {
			logging.WithLevel(ctx, slog.LevelWarn)
			logging.WithExtras(ctx,
				"deprecated_path", info.deprecatedPath,
				"deprecation_warning", fmt.Sprintf("API `%s` is deprecated, use `%s` instead", info.deprecatedPath, info.primaryPath),
			)
		}
	}
	return ctx
}
// RegisterDeprecatedPathAliases registers each deprecated path in aliases on
// r, serving the handler recorded in registry for the alias's primary path.
// The "_version_" segment of a deprecated path matches any of versions or
// "latest". Requests served through an alias carry deprecatedPathInfo in
// their context, so LogDeprecatedPathAlias can log a deprecation warning. It
// panics if no handler is registered for an alias's method and primary path.
func RegisterDeprecatedPathAliases(r *mux.Router, versions []string, registry *HandlerRegistry, aliases []DeprecatedPathAlias) {
	// Route pattern segment accepting every known version plus "latest".
	versionNames := make([]string, 0, len(versions)+1)
	versionNames = append(versionNames, versions...)
	versionNames = append(versionNames, "latest")
	versionSegment := fmt.Sprintf("/{fleetversion:(?:%s)}/", strings.Join(versionNames, "|"))

	for _, alias := range aliases {
		primary := registry.handlers[handlerKey{alias.Method, alias.PrimaryPath}]
		if primary == nil {
			panic(fmt.Sprintf("deprecated alias: no handler registered for %s %s", alias.Method, alias.PrimaryPath))
		}
		for _, oldPath := range alias.DeprecatedPaths {
			info := deprecatedPathInfo{deprecatedPath: oldPath, primaryPath: alias.PrimaryPath}
			// Delegate to the primary handler, tagging the request context
			// with the deprecation info.
			aliased := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
				tagged := context.WithValue(req.Context(), deprecatedPathInfoKey{}, info)
				primary.ServeHTTP(w, req.WithContext(tagged))
			})
			route := strings.Replace(oldPath, "/_version_/", versionSegment, 1)
			r.Handle(route, aliased).
				Name(getNameFromPathAndVerb(alias.Method, oldPath, "")).
				Methods(alias.Method)
		}
	}
}
// CommonEndpointer registers HTTP endpoints whose handler functions have type
// H. Builder methods such as StartingAtVersion return a modified shallow copy
// rather than changing the receiver.
type CommonEndpointer[H any] struct {
	// EP supplies the service value and invokes handler functions of type H.
	EP Endpointer[H]
	// MakeDecoderFn builds the request decoder for a request value and a
	// request body size limit in bytes.
	MakeDecoderFn func(iface any, requestBodyLimit int64) kithttp.DecodeRequestFunc
	// EncodeFn encodes handler responses.
	EncodeFn kithttp.EncodeResponseFunc
	// Opts are passed to every go-kit server created by this endpointer.
	Opts []kithttp.ServerOption
	// NOTE(review): Router and Versions are presumably used when routes are
	// registered (outside this view) — confirm.
	Router   *mux.Router
	Versions []string
	// AuthMiddleware is a pre-built authentication middleware.
	AuthMiddleware endpoint.Middleware
	// CustomMiddleware are middlewares that run before authentication.
	CustomMiddleware []endpoint.Middleware
	// CustomMiddlewareAfterAuth are middlewares that run after authentication.
	CustomMiddlewareAfterAuth []endpoint.Middleware
	// HandlerRegistry, if set, records handlers by method+path for deprecated
	// path alias lookup. The pointer is shared across shallow copies (created
	// by builder methods like WithAltPaths) so all registrations land in the
	// same map.
	HandlerRegistry *HandlerRegistry
	// Version range and path options, set through builder methods
	// (e.g. StartingAtVersion sets startingAtVersion).
	startingAtVersion string
	endingAtVersion   string
	alternativePaths  []string
	usePathPrefix     bool
	// The limit of the request body size in bytes, if set to -1 there is no limit.
	requestBodySizeLimit int64
}

// Endpointer adapts handler functions of type H to go-kit endpoints. Service
// returns the service value passed to every handler. CallHandlerFunc invokes
// f with the decoded request and that service.
type Endpointer[H any] interface {
	CallHandlerFunc(f H, ctx context.Context, request any, svc any) (platform_http.Errorer, error)
	Service() any
}
// POST registers f to handle POST requests on path, decoding requests into
// the type of v.
func (e *CommonEndpointer[H]) POST(path string, f H, v any) {
	e.handleEndpoint(path, f, v, http.MethodPost)
}

// GET registers f to handle GET requests on path, decoding requests into the
// type of v.
func (e *CommonEndpointer[H]) GET(path string, f H, v any) {
	e.handleEndpoint(path, f, v, http.MethodGet)
}

// PUT registers f to handle PUT requests on path, decoding requests into the
// type of v.
func (e *CommonEndpointer[H]) PUT(path string, f H, v any) {
	e.handleEndpoint(path, f, v, http.MethodPut)
}

// PATCH registers f to handle PATCH requests on path, decoding requests into
// the type of v.
func (e *CommonEndpointer[H]) PATCH(path string, f H, v any) {
	e.handleEndpoint(path, f, v, http.MethodPatch)
}

// DELETE registers f to handle DELETE requests on path, decoding requests
// into the type of v.
func (e *CommonEndpointer[H]) DELETE(path string, f H, v any) {
	e.handleEndpoint(path, f, v, http.MethodDelete)
}

// HEAD registers f to handle HEAD requests on path, decoding requests into
// the type of v.
func (e *CommonEndpointer[H]) HEAD(path string, f H, v any) {
	e.handleEndpoint(path, f, v, http.MethodHead)
}
// handleEndpoint builds the HTTP handler for f and registers it for verb on
// path.
func (e *CommonEndpointer[H]) handleEndpoint(path string, f H, v interface{}, verb string) {
	e.HandleHTTPHandler(path, e.makeEndpoint(f, v), verb)
}
// makeEndpoint wraps handler function f into an http.Handler.
//
// The endpoint chain, from outermost to innermost, is:
//  1. CustomMiddleware, in slice order;
//  2. AuthMiddleware;
//  3. CustomMiddlewareAfterAuth, in slice order;
//  4. the call to e.EP.CallHandlerFunc.
//
// Requests are decoded with e.MakeDecoderFn(v, limit), where limit is the
// effective request body size limit computed below.
//
// It panics if AuthMiddleware is nil. The receiver is not modified.
func (e *CommonEndpointer[H]) makeEndpoint(f H, v interface{}) http.Handler {
	if e.AuthMiddleware == nil {
		// This panic catches potential security issues during development.
		panic("AuthMiddleware must be set on CommonEndpointer")
	}

	var endp endpoint.Endpoint = func(ctx context.Context, request interface{}) (interface{}, error) {
		return e.EP.CallHandlerFunc(f, ctx, request, e.EP.Service())
	}

	// Wrap from the inside out (iterating in reverse) so that the first
	// middleware in each slice ends up outermost.
	for i := len(e.CustomMiddlewareAfterAuth) - 1; i >= 0; i-- {
		endp = e.CustomMiddlewareAfterAuth[i](endp)
	}
	endp = e.AuthMiddleware(endp)
	for i := len(e.CustomMiddleware) - 1; i >= 0; i-- {
		endp = e.CustomMiddleware[i](endp)
	}

	// Effective request body limit:
	//   - -1 means "no limit" (it can only be set with the appropriate SKIP
	//     method) and is kept as-is;
	//   - 0 (unset) or any value below platform_http.MaxRequestBodySize is
	//     raised to that global default, so no endpointer is forgotten and
	//     an endpoint can only raise the limit.
	// It is computed locally: writing it back to e.requestBodySizeLimit
	// would make building a handler silently change the endpointer's
	// configuration.
	bodyLimit := e.requestBodySizeLimit
	if bodyLimit != -1 && (bodyLimit == 0 || bodyLimit < platform_http.MaxRequestBodySize) {
		bodyLimit = platform_http.MaxRequestBodySize
	}
	return newServer(endp, e.MakeDecoderFn(v, bodyLimit), e.EncodeFn, e.Opts)
}
// newServer wraps endpoint e in the authzcheck middleware (outermost) and
// exposes it as a go-kit HTTP server using the given decoder, encoder and
// server options.
func newServer(e endpoint.Endpoint, decodeFn kithttp.DecodeRequestFunc, encodeFn kithttp.EncodeResponseFunc,
	opts []kithttp.ServerOption,
) http.Handler {
	// TODO: some handlers don't have authz checks, and because the SkipAuth call is done only in the
	// endpoint handler, any middleware that raises errors before the handler is reached will end up
	// returning authz check missing instead of the more relevant error. Should be addressed as part
	// of #4406.
	authzWrap := authzcheck.NewMiddleware().AuthzCheck()
	checked := authzWrap(e)
	return kithttp.NewServer(checked, decodeFn, encodeFn, opts...)
}
// StartingAtVersion returns a shallow copy of the endpointer whose routes are
// registered only for e.Versions from version onward (inclusive). Route
// registration panics if version is not one of e.Versions.
func (e *CommonEndpointer[H]) StartingAtVersion(version string) *CommonEndpointer[H] {
	clone := *e
	clone.startingAtVersion = version
	return &clone
}
// EndingAtVersion returns a shallow copy of the endpointer whose routes are
// registered only for versions up to and including version. The "latest"
// alias is added only when version is the last entry of e.Versions.
func (e *CommonEndpointer[H]) EndingAtVersion(version string) *CommonEndpointer[H] {
	clone := *e
	clone.endingAtVersion = version
	return &clone
}
// WithAltPaths returns a shallow copy of the endpointer that also registers
// each endpoint under the given alternative paths (replacing any previously
// set). Only the primary path is recorded in the HandlerRegistry.
func (e *CommonEndpointer[H]) WithAltPaths(paths ...string) *CommonEndpointer[H] {
	clone := *e
	clone.alternativePaths = paths
	return &clone
}
// WithCustomMiddleware returns a shallow copy of the endpointer whose
// before-authentication middleware list is replaced by mws (stored as given,
// not copied). mws[0] runs first.
func (e *CommonEndpointer[H]) WithCustomMiddleware(mws ...endpoint.Middleware) *CommonEndpointer[H] {
	clone := *e
	clone.CustomMiddleware = mws
	return &clone
}
// AppendCustomMiddleware returns a shallow copy of the endpointer whose
// before-authentication middleware is the current CustomMiddleware followed
// by mws.
//
// The existing slice is clipped to its length before appending, so the copy
// never writes into spare capacity of the parent's backing array. Without the
// clip, two endpointers derived from the same parent could silently overwrite
// each other's appended middleware. That happens whenever the parent's
// CustomMiddleware has spare capacity, e.g. because it was itself produced by
// append.
func (e *CommonEndpointer[H]) AppendCustomMiddleware(mws ...endpoint.Middleware) *CommonEndpointer[H] {
	ae := *e
	n := len(ae.CustomMiddleware)
	ae.CustomMiddleware = append(ae.CustomMiddleware[:n:n], mws...)
	return &ae
}
// WithCustomMiddlewareAfterAuth returns a shallow copy of the endpointer whose
// after-authentication middleware list is replaced by mws (stored as given).
// mws[0] runs first, right after AuthMiddleware.
func (e *CommonEndpointer[H]) WithCustomMiddlewareAfterAuth(mws ...endpoint.Middleware) *CommonEndpointer[H] {
	clone := *e
	clone.CustomMiddlewareAfterAuth = mws
	return &clone
}
// UsePathPrefix returns a shallow copy of the endpointer that registers its
// routes via Router.PathPrefix instead of Router.Handle.
func (e *CommonEndpointer[H]) UsePathPrefix() *CommonEndpointer[H] {
	clone := *e
	clone.usePathPrefix = true
	return &clone
}
// WithRequestBodySizeLimit returns a shallow copy of the endpointer using
// limit (in bytes) as the request body size limit.
//
// Non-positive values are ignored, and the copy keeps the current limit. Use
// SkipRequestBodySizeLimit to disable the limit. Note that makeEndpoint raises
// any limit below platform_http.MaxRequestBodySize to that default, so in
// practice this can only raise the limit.
func (e *CommonEndpointer[H]) WithRequestBodySizeLimit(limit int64) *CommonEndpointer[H] {
	clone := *e
	if limit <= 0 {
		return &clone
	}
	clone.requestBodySizeLimit = limit
	return &clone
}
// SkipRequestBodySizeLimit returns a shallow copy of the endpointer with the
// request body size limit disabled (stored as the -1 sentinel).
func (e *CommonEndpointer[H]) SkipRequestBodySizeLimit() *CommonEndpointer[H] {
	clone := *e
	clone.requestBodySizeLimit = -1
	return &clone
}
// PathHandler registers a handler for verb and path, for handlers that need
// to know where they are mounted. pathHandler receives the concrete
// (versioned) path it is being mounted on and returns the http.Handler to
// serve it. It is called once for the primary path and once per alternative
// path.
func (e *CommonEndpointer[H]) PathHandler(verb, path string, pathHandler func(path string) http.Handler) {
	mountedAt := path
	e.HandlePathHandler(mountedAt, pathHandler, verb)
}
// HandleHTTPHandler registers the fixed handler h for verb and path. The same
// h serves the primary path and every alternative path.
func (e *CommonEndpointer[H]) HandleHTTPHandler(path string, h http.Handler, verb string) {
	e.HandlePathHandler(path, func(string) http.Handler {
		return h
	}, verb)
}
// pathReplacer turns route path separators and mux variable braces into
// underscores so a path can be embedded in a route name.
var pathReplacer = strings.NewReplacer("/", "_", "{", "_", "}", "_")

// getNameFromPathAndVerb builds a route name of the form
// "<lowercased verb>_[<startAt>_]<path>". For the path part, trailing slashes
// are removed, the "/api/_version_/fleet/" prefix is stripped, and "/", "{"
// and "}" are replaced by "_". The same replacement is applied to startAt.
func getNameFromPathAndVerb(verb, path, startAt string) string {
	trimmed := strings.TrimRight(path, "/")
	trimmed = strings.TrimPrefix(trimmed, "/api/_version_/fleet/")

	var b strings.Builder
	b.WriteString(strings.ToLower(verb))
	b.WriteByte('_')
	if startAt != "" {
		b.WriteString(pathReplacer.Replace(startAt))
		b.WriteByte('_')
	}
	b.WriteString(pathReplacer.Replace(trimmed))
	return b.String()
}
// HandlePathHandler registers the handler produced by pathHandler on e.Router
// for verb on path and on every alternative path set via WithAltPaths.
//
// The first "/_version_/" segment of each path is replaced by the mux
// variable "/{fleetversion:(?:v1|v2|...)}/". It matches the versions this
// endpoint is valid for:
//   - e.Versions, narrowed to the inclusive range
//     [startingAtVersion, endingAtVersion] when those are set;
//   - plus "latest" when no ending version is set or the ending version is
//     the last entry of e.Versions.
//
// pathHandler is called once per registered path with that path's versioned
// form. When e.HandlerRegistry is set, the primary path's handler is recorded
// under (verb, path), using the unversioned path.
//
// It panics if startingAtVersion is set but not in e.Versions. It also panics
// if endingAtVersion is set but not in the list remaining after
// startingAtVersion is applied.
func (e *CommonEndpointer[H]) HandlePathHandler(path string, pathHandler func(path string) http.Handler, verb string) {
	indexOf := func(list []string, want string) int {
		for i, v := range list {
			if v == want {
				return i
			}
		}
		return -1
	}

	versions := e.Versions
	if e.startingAtVersion != "" {
		startIndex := indexOf(versions, e.startingAtVersion)
		if startIndex == -1 {
			panic("StartAtVersion is not part of the valid versions")
		}
		versions = versions[startIndex:]
	}
	if e.endingAtVersion != "" {
		endIndex := indexOf(versions, e.endingAtVersion)
		if endIndex == -1 {
			panic("EndAtVersion is not part of the valid versions")
		}
		versions = versions[:endIndex+1]
	}
	// If the endpoint has no ending version, or it ends at the newest version,
	// it is also part of "latest".
	if e.endingAtVersion == "" || e.endingAtVersion == e.Versions[len(e.Versions)-1] {
		// Clip capacity first so append allocates a fresh array. Otherwise it
		// could write into spare capacity of e.Versions' backing array, which
		// is shared by every copy of the endpointer.
		versions = append(versions[:len(versions):len(versions)], "latest")
	}

	// The version segment is identical for the primary path and all aliases.
	versionSegment := fmt.Sprintf("/{fleetversion:(?:%s)}/", strings.Join(versions, "|"))

	// register mounts pathHandler's handler for rawPath and returns it.
	register := func(rawPath string) http.Handler {
		versionedPath := strings.Replace(rawPath, "/_version_/", versionSegment, 1)
		routeName := getNameFromPathAndVerb(verb, rawPath, e.startingAtVersion)
		h := pathHandler(versionedPath)
		if e.usePathPrefix {
			e.Router.PathPrefix(versionedPath).Handler(h).Name(routeName).Methods(verb)
		} else {
			e.Router.Handle(versionedPath, h).Name(routeName).Methods(verb)
		}
		return h
	}

	handler := register(path)
	if e.HandlerRegistry != nil {
		e.HandlerRegistry.handlers[handlerKey{verb, path}] = handler
	}
	for _, alias := range e.alternativePaths {
		register(alias)
	}
}
// EncodeCommonResponse renders response to w. Optional interfaces implemented
// by response select the rendering path, checked in this order:
//
//   - beforeRenderer and cookieSetter: their hooks run (both, if implemented)
//     and rendering continues.
//   - htmlPage: the page HTML is written with an HTML content type and the
//     browser security headers. The status is taken from page.Error() when
//     that error implements kithttp.StatusCoder.
//   - platform_http.Errorer with a non-nil Error(): the error is rendered by
//     EncodeError using domainErrorEncoder.
//   - renderHijacker: the response renders itself.
//   - statuser: its status is written; for 204 No Content nothing else is
//     written.
//
// Otherwise the body is produced by jsonMarshal. If the response type
// declares alias rules (via `renameto` struct tags), the JSON is captured in
// a buffer and passed through DuplicateJSONKeys before being written to w.
//
// Errors carried by the response are rendered, not returned. The returned
// error comes from writing the HTML page, from jsonMarshal, or from the final
// write of the transformed JSON.
func EncodeCommonResponse(
	ctx context.Context,
	w http.ResponseWriter,
	response interface{},
	jsonMarshal func(w http.ResponseWriter, response interface{}) error,
	domainErrorEncoder DomainErrorEncoder,
) error {
	if br, ok := response.(beforeRenderer); ok {
		br.BeforeRender(ctx, w)
	}
	if cs, ok := response.(cookieSetter); ok {
		cs.SetCookies(ctx, w)
	}

	// This has to happen before the Errorer check: an error on an HTML page is
	// rendered as the page itself (an error page) instead of going through
	// EncodeError.
	if page, ok := response.(htmlPage); ok {
		w.Header().Set("Content-Type", "text/html; charset=UTF-8")
		// This will not return an error if disabled
		_, _ = WriteBrowserSecurityHeaders(w, false, false)
		if coder, ok := page.Error().(kithttp.StatusCoder); ok {
			w.WriteHeader(coder.StatusCode())
		}
		_, err := io.WriteString(w, page.Html())
		return err
	}
	if e, ok := response.(platform_http.Errorer); ok && e.Error() != nil {
		EncodeError(ctx, e.Error(), w, domainErrorEncoder)
		return nil
	}
	if render, ok := response.(renderHijacker); ok {
		render.HijackRender(ctx, w)
		return nil
	}
	if s, ok := response.(statuser); ok {
		status := s.Status()
		w.WriteHeader(status)
		if status == http.StatusNoContent {
			return nil
		}
	}

	// Alias rules only affect the JSON body. Extract them here, so the HTML,
	// error, hijack and 204 paths above don't pay for it.
	aliasRules := ExtractAliasRules(response)
	if len(aliasRules) == 0 {
		return jsonMarshal(w, response)
	}

	// Buffer the JSON output so keys can be duplicated (old→new) for forwards
	// compatibility before anything reaches the client.
	var buf bytes.Buffer
	if err := jsonMarshal(&bufferedResponseWriter{ResponseWriter: w, buf: &buf}, response); err != nil {
		return err
	}
	_, err := w.Write(DuplicateJSONKeys(buf.Bytes(), aliasRules))
	return err
}
// statuser is implemented by response values that need a success status other
// than the default 200 OK. EncodeCommonResponse writes Status() before the
// body and writes no body when it is 204 No Content.
type statuser interface {
	Status() int
}
// htmlPage is implemented by response values that render as an HTML page
// instead of JSON. Html returns the page markup. Error returns the error
// associated with the page, if any; when it implements kithttp.StatusCoder,
// its status code is used for the response.
type htmlPage interface {
	Html() string
	Error() error
}
// renderHijacker is implemented by response values that write the entire HTTP
// response themselves. When a response implements it, EncodeCommonResponse
// does no further rendering.
type renderHijacker interface {
	HijackRender(ctx context.Context, w http.ResponseWriter)
}
// cookieSetter is implemented by response values that add cookies to the
// response. The hook runs before the body or status is written, and
// rendering continues normally afterwards.
type cookieSetter interface {
	SetCookies(ctx context.Context, w http.ResponseWriter)
}
// beforeRenderer is implemented by response values that need access to the raw
// ResponseWriter before anything is written, e.g. to set headers. Rendering
// continues normally after the hook returns.
//
// The hook runs before the Errorer check, so it may also fail the request by
// storing an error on the response's Errorer. It should not set the status
// code; response types should implement statuser for that.
//
// Unlike renderHijacker and htmlPage, it does not stop further processing.
// It behaves like cookieSetter but is general-purpose rather than
// cookie-specific.
type beforeRenderer interface {
	BeforeRender(ctx context.Context, w http.ResponseWriter)
}
// bufferedResponseWriter is an http.ResponseWriter whose body writes are
// captured in buf instead of reaching the client. Header and WriteHeader still
// go to the embedded ResponseWriter, so headers and status set through it
// affect the real response. This lets marshaled output be captured and
// transformed before it is sent.
type bufferedResponseWriter struct {
	http.ResponseWriter
	buf *bytes.Buffer
}

// Write appends p to the capture buffer; nothing is written to the embedded
// ResponseWriter.
func (bw *bufferedResponseWriter) Write(p []byte) (int, error) {
	n, err := bw.buf.Write(p)
	return n, err
}