mirror of
https://github.com/fleetdm/fleet
synced 2026-05-06 06:48:54 +00:00
Unversion the /setup endpoint, version the websocket endpoint (#5104)
This commit is contained in:
parent
6a5f7172ef
commit
b3fc0cd844
7 changed files with 145 additions and 78 deletions
|
|
@ -975,7 +975,7 @@ Note that live queries are automatically cancelled if this method is not called
|
|||
#### Example script to handle request and response
|
||||
|
||||
```
|
||||
const socket = new WebSocket('wss://<your-base-url>/api/v1/fleet/results/websocket');
|
||||
const socket = new WebSocket('wss://<your-base-url>/api/v1/fleet/results/websockets');
|
||||
|
||||
socket.onopen = () => {
|
||||
socket.send(JSON.stringify({ type: 'auth', data: { token: <auth-token> } }));
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ func (c *Client) LiveQueryWithContext(ctx context.Context, query string, labels
|
|||
if flag.Lookup("test.v") != nil {
|
||||
wssURL.Scheme = "ws"
|
||||
}
|
||||
wssURL.Path = c.urlPrefix + "/api/v1/fleet/results/websocket"
|
||||
wssURL.Path = c.urlPrefix + "/api/latest/fleet/results/websocket"
|
||||
conn, _, err := dialer.Dial(wssURL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "upgrade live query result websocket")
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ func TestLiveQueryWithContext(t *testing.T) {
|
|||
}
|
||||
err := json.NewEncoder(w).Encode(resp)
|
||||
assert.NoError(t, err)
|
||||
case "/api/v1/fleet/results/websocket":
|
||||
case "/api/latest/fleet/results/websocket":
|
||||
ws, _ := upgrader.Upgrade(w, r, nil)
|
||||
defer ws.Close()
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
|
|
@ -16,64 +18,96 @@ import (
|
|||
// Stream Distributed Query Campaign Results and Metadata
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
func makeStreamDistributedQueryCampaignResultsHandler(svc fleet.Service, logger kitlog.Logger) http.Handler {
|
||||
var reVersion = regexp.MustCompile(`\{fleetversion:\(\?:([^\}\)]+)\)\}`)
|
||||
|
||||
func makeStreamDistributedQueryCampaignResultsHandler(svc fleet.Service, logger kitlog.Logger) func(string) http.Handler {
|
||||
opt := sockjs.DefaultOptions
|
||||
opt.Websocket = true
|
||||
opt.RawWebsocket = true
|
||||
return sockjs.NewHandler("/api/v1/fleet/results", opt, func(session sockjs.Session) {
|
||||
conn := &websocket.Conn{Session: session}
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
logger.Log("err", p, "msg", "panic in result handler")
|
||||
conn.WriteJSONError("panic in result handler")
|
||||
|
||||
return func(path string) http.Handler {
|
||||
// expand the path's versions (with regex) to all literal paths (no regex),
|
||||
// because sockjs requires the (static, literal) path prefix as argument to
|
||||
// create the handler so that it can trim it from the request's URL to get
|
||||
// the special path values (such as the session id).
|
||||
matches := reVersion.FindStringSubmatch(path)
|
||||
if len(matches) == 0 {
|
||||
panic("unexpected path, could not expand fleetversion: " + path)
|
||||
}
|
||||
|
||||
versions := strings.Split(matches[1], "|")
|
||||
literalPaths := make([]string, len(versions))
|
||||
for i, ver := range versions {
|
||||
lp := reVersion.ReplaceAllStringFunc(path, func(_ string) string { return ver })
|
||||
literalPaths[i] = lp
|
||||
}
|
||||
|
||||
sockHandler := func(session sockjs.Session) {
|
||||
conn := &websocket.Conn{Session: session}
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
logger.Log("err", p, "msg", "panic in result handler")
|
||||
conn.WriteJSONError("panic in result handler")
|
||||
}
|
||||
session.Close(0, "none")
|
||||
}()
|
||||
|
||||
// Receive the auth bearer token
|
||||
token, err := conn.ReadAuthToken()
|
||||
if err != nil {
|
||||
logger.Log("err", err, "msg", "failed to read auth token")
|
||||
return
|
||||
}
|
||||
session.Close(0, "none")
|
||||
}()
|
||||
|
||||
// Receive the auth bearer token
|
||||
token, err := conn.ReadAuthToken()
|
||||
if err != nil {
|
||||
logger.Log("err", err, "msg", "failed to read auth token")
|
||||
return
|
||||
// Authenticate with the token
|
||||
vc, err := authViewer(context.Background(), string(token), svc)
|
||||
if err != nil || !vc.CanPerformActions() {
|
||||
logger.Log("err", err, "msg", "unauthorized viewer")
|
||||
conn.WriteJSONError("unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := viewer.NewContext(context.Background(), *vc)
|
||||
|
||||
msg, err := conn.ReadJSONMessage()
|
||||
if err != nil {
|
||||
logger.Log("err", err, "msg", "reading select_campaign JSON")
|
||||
conn.WriteJSONError("error reading select_campaign")
|
||||
return
|
||||
}
|
||||
if msg.Type != "select_campaign" {
|
||||
logger.Log("err", "unexpected msg type, expected select_campaign", "msg-type", msg.Type)
|
||||
conn.WriteJSONError("expected select_campaign")
|
||||
return
|
||||
}
|
||||
|
||||
var info struct {
|
||||
CampaignID uint `json:"campaign_id"`
|
||||
}
|
||||
err = json.Unmarshal(*(msg.Data.(*json.RawMessage)), &info)
|
||||
if err != nil {
|
||||
logger.Log("err", err, "msg", "unmarshaling select_campaign data")
|
||||
conn.WriteJSONError("error unmarshaling select_campaign data")
|
||||
return
|
||||
}
|
||||
if info.CampaignID == 0 {
|
||||
logger.Log("err", "campaign ID not set")
|
||||
conn.WriteJSONError("0 is not a valid campaign ID")
|
||||
return
|
||||
}
|
||||
|
||||
svc.StreamCampaignResults(ctx, conn, info.CampaignID)
|
||||
}
|
||||
|
||||
// Authenticate with the token
|
||||
vc, err := authViewer(context.Background(), string(token), svc)
|
||||
if err != nil || !vc.CanPerformActions() {
|
||||
logger.Log("err", err, "msg", "unauthorized viewer")
|
||||
conn.WriteJSONError("unauthorized")
|
||||
return
|
||||
// multiplex the requests to each literal path that this endpoint support,
|
||||
// with the corresponding sockjs handler to handle that specific path.
|
||||
mux := http.NewServeMux()
|
||||
for _, lp := range literalPaths {
|
||||
// important: sockjs' path must not have the trailing path, but the mux
|
||||
// needs it in order to match it as a path prefix (subtree).
|
||||
sockPath := strings.TrimSuffix(lp, "/")
|
||||
mux.Handle(lp, sockjs.NewHandler(sockPath, opt, sockHandler))
|
||||
}
|
||||
|
||||
ctx := viewer.NewContext(context.Background(), *vc)
|
||||
|
||||
msg, err := conn.ReadJSONMessage()
|
||||
if err != nil {
|
||||
logger.Log("err", err, "msg", "reading select_campaign JSON")
|
||||
conn.WriteJSONError("error reading select_campaign")
|
||||
return
|
||||
}
|
||||
if msg.Type != "select_campaign" {
|
||||
logger.Log("err", "unexpected msg type, expected select_campaign", "msg-type", msg.Type)
|
||||
conn.WriteJSONError("expected select_campaign")
|
||||
return
|
||||
}
|
||||
|
||||
var info struct {
|
||||
CampaignID uint `json:"campaign_id"`
|
||||
}
|
||||
err = json.Unmarshal(*(msg.Data.(*json.RawMessage)), &info)
|
||||
if err != nil {
|
||||
logger.Log("err", err, "msg", "unmarshaling select_campaign data")
|
||||
conn.WriteJSONError("error unmarshaling select_campaign data")
|
||||
return
|
||||
}
|
||||
if info.CampaignID == 0 {
|
||||
logger.Log("err", "campaign ID not set")
|
||||
conn.WriteJSONError("0 is not a valid campaign ID")
|
||||
return
|
||||
}
|
||||
|
||||
svc.StreamCampaignResults(ctx, conn, info.CampaignID)
|
||||
})
|
||||
return mux
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -287,6 +287,7 @@ type authEndpointer struct {
|
|||
endingAtVersion string
|
||||
alternativePaths []string
|
||||
customMiddleware []endpoint.Middleware
|
||||
usePathPrefix bool
|
||||
}
|
||||
|
||||
func newDeviceAuthenticatedEndpointer(svc fleet.Service, logger log.Logger, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer {
|
||||
|
|
@ -347,22 +348,30 @@ func getNameFromPathAndVerb(verb, path string) string {
|
|||
}
|
||||
|
||||
func (e *authEndpointer) POST(path string, f handlerFunc, v interface{}) {
|
||||
e.handle(path, f, v, "POST")
|
||||
e.handleEndpoint(path, f, v, "POST")
|
||||
}
|
||||
|
||||
func (e *authEndpointer) GET(path string, f handlerFunc, v interface{}) {
|
||||
e.handle(path, f, v, "GET")
|
||||
e.handleEndpoint(path, f, v, "GET")
|
||||
}
|
||||
|
||||
func (e *authEndpointer) PATCH(path string, f handlerFunc, v interface{}) {
|
||||
e.handle(path, f, v, "PATCH")
|
||||
e.handleEndpoint(path, f, v, "PATCH")
|
||||
}
|
||||
|
||||
func (e *authEndpointer) DELETE(path string, f handlerFunc, v interface{}) {
|
||||
e.handle(path, f, v, "DELETE")
|
||||
e.handleEndpoint(path, f, v, "DELETE")
|
||||
}
|
||||
|
||||
func (e *authEndpointer) handle(path string, f handlerFunc, v interface{}, verb string) {
|
||||
// PathHandler registers a handler for the verb and path. The pathHandler is
|
||||
// a function that receives the actual path to which it will be mounted, and
|
||||
// returns the actual http.Handler that will handle this endpoint. This is for
|
||||
// when the handler needs to know on which path it was called.
|
||||
func (e *authEndpointer) PathHandler(verb, path string, pathHandler func(path string) http.Handler) {
|
||||
e.handlePathHandler(path, pathHandler, verb)
|
||||
}
|
||||
|
||||
func (e *authEndpointer) handlePathHandler(path string, pathHandler func(path string) http.Handler, verb string) {
|
||||
versions := e.versions
|
||||
if e.startingAtVersion != "" {
|
||||
startIndex := -1
|
||||
|
|
@ -399,15 +408,32 @@ func (e *authEndpointer) handle(path string, f handlerFunc, v interface{}, verb
|
|||
|
||||
versionedPath := strings.Replace(path, "/_version_/", fmt.Sprintf("/{fleetversion:(?:%s)}/", strings.Join(versions, "|")), 1)
|
||||
nameAndVerb := getNameFromPathAndVerb(verb, path)
|
||||
endpoint := e.makeEndpoint(f, v)
|
||||
e.r.Handle(versionedPath, endpoint).Name(nameAndVerb).Methods(verb)
|
||||
if e.usePathPrefix {
|
||||
e.r.PathPrefix(versionedPath).Handler(pathHandler(versionedPath)).Name(nameAndVerb).Methods(verb)
|
||||
} else {
|
||||
e.r.Handle(versionedPath, pathHandler(versionedPath)).Name(nameAndVerb).Methods(verb)
|
||||
}
|
||||
for _, alias := range e.alternativePaths {
|
||||
nameAndVerb := getNameFromPathAndVerb(verb, alias)
|
||||
versionedPath := strings.Replace(alias, "/_version_/", fmt.Sprintf("/{fleetversion:(?:%s)}/", strings.Join(versions, "|")), 1)
|
||||
e.r.Handle(versionedPath, endpoint).Name(nameAndVerb).Methods(verb)
|
||||
if e.usePathPrefix {
|
||||
e.r.PathPrefix(versionedPath).Handler(pathHandler(versionedPath)).Name(nameAndVerb).Methods(verb)
|
||||
} else {
|
||||
e.r.Handle(versionedPath, pathHandler(versionedPath)).Name(nameAndVerb).Methods(verb)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *authEndpointer) handleHTTPHandler(path string, h http.Handler, verb string) {
|
||||
self := func(_ string) http.Handler { return h }
|
||||
e.handlePathHandler(path, self, verb)
|
||||
}
|
||||
|
||||
func (e *authEndpointer) handleEndpoint(path string, f handlerFunc, v interface{}, verb string) {
|
||||
endpoint := e.makeEndpoint(f, v)
|
||||
e.handleHTTPHandler(path, endpoint, verb)
|
||||
}
|
||||
|
||||
func (e *authEndpointer) makeEndpoint(f handlerFunc, v interface{}) http.Handler {
|
||||
next := func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
return f(ctx, request, e.svc)
|
||||
|
|
@ -446,3 +472,9 @@ func (e *authEndpointer) WithCustomMiddleware(mws ...endpoint.Middleware) *authE
|
|||
ae.customMiddleware = mws
|
||||
return &ae
|
||||
}
|
||||
|
||||
func (e *authEndpointer) UsePathPrefix() *authEndpointer {
|
||||
ae := *e
|
||||
ae.usePathPrefix = true
|
||||
return &ae
|
||||
}
|
||||
|
|
|
|||
|
|
@ -124,15 +124,6 @@ func MakeHandler(
|
|||
r.Use(publicIP)
|
||||
|
||||
attachFleetAPIRoutes(r, svc, config, logger, limitStore, fleetAPIOptions, eopts)
|
||||
|
||||
// Results endpoint is handled different due to websockets use
|
||||
|
||||
// TODO: this would not work once v1 is deprecated - note that the handler too uses the /v1/ path
|
||||
// and this routes on path prefix, not exact path (unlike the authendpointer struct).
|
||||
r.PathPrefix("/api/v1/fleet/results/").
|
||||
Handler(makeStreamDistributedQueryCampaignResultsHandler(svc, logger)).
|
||||
Name("distributed_query_results")
|
||||
|
||||
addMetrics(r)
|
||||
|
||||
return r
|
||||
|
|
@ -446,6 +437,12 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
|
|||
ne.POST("/api/v1/fleet/sso", initiateSSOEndpoint, initiateSSORequest{})
|
||||
ne.POST("/api/v1/fleet/sso/callback", makeCallbackSSOEndpoint(config.Server.URLPrefix), callbackSSORequest{})
|
||||
ne.GET("/api/v1/fleet/sso", settingsSSOEndpoint, nil)
|
||||
// the websocket distributed query results endpoint is a bit different - the
|
||||
// provided path is a prefix, not an exact match, and it is not a go-kit
|
||||
// endpoint but a raw http.Handler. It uses the NoAuthEndpointer because
|
||||
// authentication is done when the websocket session is established, inside
|
||||
// the handler.
|
||||
ne.UsePathPrefix().PathHandler("GET", "/api/_version_/fleet/results/", makeStreamDistributedQueryCampaignResultsHandler(svc, logger))
|
||||
|
||||
limiter := ratelimit.NewMiddleware(limitStore)
|
||||
ne.
|
||||
|
|
@ -477,13 +474,16 @@ func WithSetup(svc fleet.Service, logger kitlog.Logger, next http.Handler) http.
|
|||
rxOsquery := regexp.MustCompile(`^/api/[^/]+/osquery`)
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
configRouter := http.NewServeMux()
|
||||
// TODO: hard-codes v1 as a path fragment, which would probably not work once we
|
||||
// deprecate it for newer versions, unless we want to treat the setup differently (not versioned?)
|
||||
configRouter.Handle("/api/v1/setup", kithttp.NewServer(
|
||||
srv := kithttp.NewServer(
|
||||
makeSetupEndpoint(svc, logger),
|
||||
decodeSetupRequest,
|
||||
encodeResponse,
|
||||
))
|
||||
)
|
||||
// NOTE: support setup on both /v1/ and version-less, in the future /v1/
|
||||
// will be dropped.
|
||||
configRouter.Handle("/api/v1/setup", srv)
|
||||
configRouter.Handle("/api/setup", srv)
|
||||
|
||||
// whitelist osqueryd endpoints
|
||||
if rxOsquery.MatchString(r.URL.Path) {
|
||||
next.ServeHTTP(w, r)
|
||||
|
|
|
|||
|
|
@ -92,10 +92,11 @@ func TestStreamCampaignResultsClosesReditOnWSClose(t *testing.T) {
|
|||
_, err := svc.NewDistributedQueryCampaign(viewerCtx, q, nil, fleet.HostTargets{HostIDs: []uint{2}, LabelIDs: []uint{1}})
|
||||
require.NoError(t, err)
|
||||
|
||||
s := httptest.NewServer(makeStreamDistributedQueryCampaignResultsHandler(svc, kitlog.NewNopLogger()))
|
||||
pathHandler := makeStreamDistributedQueryCampaignResultsHandler(svc, kitlog.NewNopLogger())
|
||||
s := httptest.NewServer(pathHandler("/api/latest/fleet/results/"))
|
||||
defer s.Close()
|
||||
// Convert http://127.0.0.1 to ws://127.0.0.1
|
||||
u := "ws" + strings.TrimPrefix(s.URL, "http") + "/api/v1/fleet/results/websocket"
|
||||
u := "ws" + strings.TrimPrefix(s.URL, "http") + "/api/latest/fleet/results/websocket"
|
||||
|
||||
// Connect to the server
|
||||
dialer := &websocket.Dialer{
|
||||
|
|
|
|||
Loading…
Reference in a new issue