fleet/server/activity/internal/service/endpoint_utils.go
Juan Fernandez 2b35eabd5d
Added middleware for api-only users auth (#43772)
Fixes #42885

Added new middleware (APIOnlyEndpointCheck) that enforces 403 for
API-only users whose request either isn't in the API endpoint catalog or
falls outside their configured per-user endpoint restrictions.
2026-04-21 07:11:33 -04:00

164 lines
5.2 KiB
Go

package service
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"reflect"
"strconv"
"github.com/fleetdm/fleet/v4/server/activity/api"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
eu "github.com/fleetdm/fleet/v4/server/platform/endpointer"
platform_http "github.com/fleetdm/fleet/v4/server/platform/http"
"github.com/go-kit/kit/endpoint"
kithttp "github.com/go-kit/kit/transport/http"
"github.com/gorilla/mux"
)
// Pagination limits for list endpoints served by this package.
const (
	// maxPerPage is the largest per_page value a client may request; larger
	// values are rejected by listOptionsFromRequest.
	maxPerPage = 10000
	// defaultPerPage is the page size to use when a client asks for a page
	// without specifying per_page.
	defaultPerPage = 20
)
// encodeResponse writes response to w as indented JSON.
//
// Shared response handling is delegated to eu.EncodeCommonResponse; the
// closure below only performs the final JSON serialization.
func encodeResponse(ctx context.Context, w http.ResponseWriter, response any) error {
	writeJSON := func(out http.ResponseWriter, v any) error {
		encoder := json.NewEncoder(out)
		encoder.SetIndent("", " ")
		return encoder.Encode(v)
	}
	// The activity service has no domain-specific error encoder, hence nil.
	return eu.EncodeCommonResponse(ctx, w, response, writeJSON, nil)
}
// makeDecoder returns a go-kit request decoder for the request type iface.
//
// Request bodies are decoded as JSON, and URL tags with custom values are
// resolved by parseCustomTags. requestBodySizeLimit is passed through to
// eu.MakeDecoder unchanged.
func makeDecoder(iface any, requestBodySizeLimit int64) kithttp.DecodeRequestFunc {
	decodeJSONBody := func(body io.Reader, dst any) error {
		return json.NewDecoder(body).Decode(dst)
	}
	return eu.MakeDecoder(iface, decodeJSONBody, parseCustomTags, nil, nil, nil, requestBodySizeLimit)
}
// parseCustomTags resolves custom `url` tag values on activity request structs.
//
// Only the "list_options" tag is recognized. For that tag the list options are
// parsed from r's query string and stored into field, and the function reports
// true. Any other tag value is left for the caller to handle: it returns
// (false, nil) without touching field. A parse failure is returned as-is,
// together with false.
func parseCustomTags(urlTagValue string, r *http.Request, field reflect.Value) (bool, error) {
	if urlTagValue != "list_options" {
		return false, nil
	}

	listOpts, err := listOptionsFromRequest(r)
	if err != nil {
		return false, err
	}
	field.Set(reflect.ValueOf(listOpts))
	return true, nil
}
// listOptionsFromRequest builds an api.ListOptions from the query string of r.
//
// Parameters are validated in the order below. The first failure is returned
// as a *platform_http.BadRequestError wrapped with ctxerr:
//   - page, if present, must be an integer >= 0;
//   - per_page, if present, must be an integer in [1, maxPerPage];
//   - order_direction, if non-empty, requires order_key to be set;
//   - order_direction must be "asc", "desc" or empty (empty means ascending).
//
// A missing page or per_page is left as 0. This function does not substitute
// defaultPerPage itself (NOTE(review): presumably applied downstream — confirm).
// The after, activity_type, start_created_at, end_created_at and query
// parameters are copied through verbatim without validation.
func listOptionsFromRequest(r *http.Request) (api.ListOptions, error) {
	ctx := r.Context()
	query := r.URL.Query()

	page := 0
	if raw := query.Get("page"); raw != "" {
		n, err := strconv.Atoi(raw)
		switch {
		case err != nil:
			return api.ListOptions{}, ctxerr.Wrap(ctx, &platform_http.BadRequestError{Message: "non-int page value"})
		case n < 0:
			return api.ListOptions{}, ctxerr.Wrap(ctx, &platform_http.BadRequestError{Message: "negative page value"})
		}
		page = n
	}

	perPage := 0
	if raw := query.Get("per_page"); raw != "" {
		n, err := strconv.Atoi(raw)
		switch {
		case err != nil:
			return api.ListOptions{}, ctxerr.Wrap(ctx, &platform_http.BadRequestError{Message: "non-int per_page value"})
		case n <= 0:
			return api.ListOptions{}, ctxerr.Wrap(ctx, &platform_http.BadRequestError{Message: "invalid per_page value"})
		case n > maxPerPage:
			return api.ListOptions{}, ctxerr.Wrap(ctx, &platform_http.BadRequestError{
				Message: fmt.Sprintf("Request could not be processed. Please set a per_page limit of %d or less", maxPerPage),
			})
		}
		perPage = n
	}

	orderKey := query.Get("order_key")
	rawDirection := query.Get("order_direction")
	if rawDirection != "" && orderKey == "" {
		return api.ListOptions{}, ctxerr.Wrap(ctx, &platform_http.BadRequestError{Message: "order_key must be specified with order_direction"})
	}

	var direction api.OrderDirection
	switch rawDirection {
	case "", "asc":
		direction = api.OrderAscending
	case "desc":
		direction = api.OrderDescending
	default:
		return api.ListOptions{}, ctxerr.Wrap(ctx, &platform_http.BadRequestError{Message: "unknown order_direction: " + rawDirection})
	}

	// page and perPage were range-checked above, so the uint conversions are safe.
	opts := api.ListOptions{
		Page:           uint(page),    //nolint:gosec // dismiss G115
		PerPage:        uint(perPage), //nolint:gosec // dismiss G115
		OrderKey:       orderKey,
		OrderDirection: direction,
		After:          query.Get("after"),
		ActivityType:   query.Get("activity_type"),
		StartCreatedAt: query.Get("start_created_at"),
		EndCreatedAt:   query.Get("end_created_at"),
		MatchQuery:     query.Get("query"),
	}
	return opts, nil
}
// handlerFunc is the signature shared by Activity service endpoint handlers.
// A handler receives the decoded request and the api.Service to operate on,
// and returns a platform_http.Errorer as the endpoint response.
type handlerFunc func(ctx context.Context, req any, svc api.Service) platform_http.Errorer
// endpointer adapts handlerFunc to the generic eu.Endpointer interface for the
// Activity service.
type endpointer struct {
	svc api.Service
}

// Compile-time assertion that *endpointer satisfies eu.Endpointer[handlerFunc].
var _ eu.Endpointer[handlerFunc] = (*endpointer)(nil)

// CallHandlerFunc invokes f with svc converted to api.Service.
//
// The type assertion panics if svc does not hold an api.Service. The returned
// error is always nil; f's result is passed back as the platform_http.Errorer.
func (ep *endpointer) CallHandlerFunc(f handlerFunc, ctx context.Context,
	request any,
	svc any,
) (platform_http.Errorer, error) {
	activitySvc := svc.(api.Service)
	return f(ctx, request, activitySvc), nil
}

// Service returns the api.Service this endpointer was constructed with.
func (ep *endpointer) Service() any { return ep.svc }
// newUserAuthenticatedEndpointer builds the CommonEndpointer used to register
// Activity endpoints that require an authenticated user.
//
// authMiddleware, r and versions are passed through unchanged. The server
// options are opts followed by eu.RouteTemplateRequestFunc, registered via
// kithttp.ServerBefore. This lets the api_only endpoint middleware read the
// matched mux route template from the request context.
func newUserAuthenticatedEndpointer(svc api.Service, authMiddleware endpoint.Middleware, opts []kithttp.ServerOption, r *mux.Router,
	versions ...string,
) *eu.CommonEndpointer[handlerFunc] {
	// Build a brand-new slice instead of appending to opts, so the caller's
	// backing array is never written even if it has spare capacity.
	serverOpts := make([]kithttp.ServerOption, 0, len(opts)+1)
	serverOpts = append(serverOpts, opts...)
	serverOpts = append(serverOpts, kithttp.ServerBefore(eu.RouteTemplateRequestFunc))

	ep := &endpointer{svc: svc}
	return &eu.CommonEndpointer[handlerFunc]{
		EP:             ep,
		MakeDecoderFn:  makeDecoder,
		EncodeFn:       encodeResponse,
		Opts:           serverOpts,
		AuthMiddleware: authMiddleware,
		Router:         r,
		Versions:       versions,
	}
}