diff --git a/docs/Contributing/API-for-contributors.md b/docs/Contributing/API-for-contributors.md index ffd66f3b23..ed3bfe6eb4 100644 --- a/docs/Contributing/API-for-contributors.md +++ b/docs/Contributing/API-for-contributors.md @@ -975,7 +975,7 @@ Note that live queries are automatically cancelled if this method is not called #### Example script to handle request and response ``` -const socket = new WebSocket('wss:///api/v1/fleet/results/websocket'); +const socket = new WebSocket('wss:///api/v1/fleet/results/websockets'); socket.onopen = () => { socket.send(JSON.stringify({ type: 'auth', data: { token: } })); diff --git a/server/service/client_live_query.go b/server/service/client_live_query.go index 81b3c2e2ed..a47656e6bf 100644 --- a/server/service/client_live_query.go +++ b/server/service/client_live_query.go @@ -90,7 +90,7 @@ func (c *Client) LiveQueryWithContext(ctx context.Context, query string, labels if flag.Lookup("test.v") != nil { wssURL.Scheme = "ws" } - wssURL.Path = c.urlPrefix + "/api/v1/fleet/results/websocket" + wssURL.Path = c.urlPrefix + "/api/latest/fleet/results/websocket" conn, _, err := dialer.Dial(wssURL.String(), nil) if err != nil { return nil, ctxerr.Wrap(ctx, err, "upgrade live query result websocket") diff --git a/server/service/client_live_query_test.go b/server/service/client_live_query_test.go index 5e9f8fc0bf..f082ce237b 100644 --- a/server/service/client_live_query_test.go +++ b/server/service/client_live_query_test.go @@ -42,7 +42,7 @@ func TestLiveQueryWithContext(t *testing.T) { } err := json.NewEncoder(w).Encode(resp) assert.NoError(t, err) - case "/api/v1/fleet/results/websocket": + case "/api/latest/fleet/results/websocket": ws, _ := upgrader.Upgrade(w, r, nil) defer ws.Close() diff --git a/server/service/endpoint_campaigns.go b/server/service/endpoint_campaigns.go index d5b3475a18..d0e4fcad61 100644 --- a/server/service/endpoint_campaigns.go +++ b/server/service/endpoint_campaigns.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "net/http" + "regexp" + "strings" "github.com/fleetdm/fleet/v4/server/contexts/viewer" "github.com/fleetdm/fleet/v4/server/fleet" @@ -16,64 +18,96 @@ import ( // Stream Distributed Query Campaign Results and Metadata //////////////////////////////////////////////////////////////////////////////// -func makeStreamDistributedQueryCampaignResultsHandler(svc fleet.Service, logger kitlog.Logger) http.Handler { +var reVersion = regexp.MustCompile(`\{fleetversion:\(\?:([^\}\)]+)\)\}`) + +func makeStreamDistributedQueryCampaignResultsHandler(svc fleet.Service, logger kitlog.Logger) func(string) http.Handler { opt := sockjs.DefaultOptions opt.Websocket = true opt.RawWebsocket = true - return sockjs.NewHandler("/api/v1/fleet/results", opt, func(session sockjs.Session) { - conn := &websocket.Conn{Session: session} - defer func() { - if p := recover(); p != nil { - logger.Log("err", p, "msg", "panic in result handler") - conn.WriteJSONError("panic in result handler") + + return func(path string) http.Handler { + // expand the path's versions (with regex) to all literal paths (no regex), + // because sockjs requires the (static, literal) path prefix as argument to + // create the handler so that it can trim it from the request's URL to get + // the special path values (such as the session id). + matches := reVersion.FindStringSubmatch(path) + if len(matches) == 0 { + panic("unexpected path, could not expand fleetversion: " + path) + } + + versions := strings.Split(matches[1], "|") + literalPaths := make([]string, len(versions)) + for i, ver := range versions { + lp := reVersion.ReplaceAllStringFunc(path, func(_ string) string { return ver }) + literalPaths[i] = lp + } + + sockHandler := func(session sockjs.Session) { + conn := &websocket.Conn{Session: session} + defer func() { + if p := recover(); p != nil { + logger.Log("err", p, "msg", "panic in result handler") + conn.WriteJSONError("panic in result handler") + } + session.Close(0, "none") + }() + + // Receive the auth bearer token + token, err := conn.ReadAuthToken() + if err != nil { + logger.Log("err", err, "msg", "failed to read auth token") + return } - session.Close(0, "none") - }() - // Receive the auth bearer token - token, err := conn.ReadAuthToken() - if err != nil { - logger.Log("err", err, "msg", "failed to read auth token") - return + // Authenticate with the token + vc, err := authViewer(context.Background(), string(token), svc) + if err != nil || !vc.CanPerformActions() { + logger.Log("err", err, "msg", "unauthorized viewer") + conn.WriteJSONError("unauthorized") + return + } + + ctx := viewer.NewContext(context.Background(), *vc) + + msg, err := conn.ReadJSONMessage() + if err != nil { + logger.Log("err", err, "msg", "reading select_campaign JSON") + conn.WriteJSONError("error reading select_campaign") + return + } + if msg.Type != "select_campaign" { + logger.Log("err", "unexpected msg type, expected select_campaign", "msg-type", msg.Type) + conn.WriteJSONError("expected select_campaign") + return + } + + var info struct { + CampaignID uint `json:"campaign_id"` + } + err = json.Unmarshal(*(msg.Data.(*json.RawMessage)), &info) + if err != nil { + logger.Log("err", err, "msg", "unmarshaling select_campaign data") + conn.WriteJSONError("error unmarshaling select_campaign data") + return + } + if info.CampaignID == 0 { + logger.Log("err", "campaign ID not set") + conn.WriteJSONError("0 is not a valid campaign ID") + return + } + + svc.StreamCampaignResults(ctx, conn, info.CampaignID) } - // Authenticate with the token - vc, err := authViewer(context.Background(), string(token), svc) - if err != nil || !vc.CanPerformActions() { - logger.Log("err", err, "msg", "unauthorized viewer") - conn.WriteJSONError("unauthorized") - return + // multiplex the requests to each literal path that this endpoint support, + // with the corresponding sockjs handler to handle that specific path. + mux := http.NewServeMux() + for _, lp := range literalPaths { + // important: sockjs' path must not have the trailing path, but the mux + // needs it in order to match it as a path prefix (subtree). + sockPath := strings.TrimSuffix(lp, "/") + mux.Handle(lp, sockjs.NewHandler(sockPath, opt, sockHandler)) } - - ctx := viewer.NewContext(context.Background(), *vc) - - msg, err := conn.ReadJSONMessage() - if err != nil { - logger.Log("err", err, "msg", "reading select_campaign JSON") - conn.WriteJSONError("error reading select_campaign") - return - } - if msg.Type != "select_campaign" { - logger.Log("err", "unexpected msg type, expected select_campaign", "msg-type", msg.Type) - conn.WriteJSONError("expected select_campaign") - return - } - - var info struct { - CampaignID uint `json:"campaign_id"` - } - err = json.Unmarshal(*(msg.Data.(*json.RawMessage)), &info) - if err != nil { - logger.Log("err", err, "msg", "unmarshaling select_campaign data") - conn.WriteJSONError("error unmarshaling select_campaign data") - return - } - if info.CampaignID == 0 { - logger.Log("err", "campaign ID not set") - conn.WriteJSONError("0 is not a valid campaign ID") - return - } - - svc.StreamCampaignResults(ctx, conn, info.CampaignID) - }) + return mux + } } diff --git a/server/service/endpoint_utils.go b/server/service/endpoint_utils.go index c3f702da72..87e13c580b 100644 --- a/server/service/endpoint_utils.go +++ b/server/service/endpoint_utils.go @@ -287,6 +287,7 @@ type authEndpointer struct { endingAtVersion string alternativePaths []string customMiddleware []endpoint.Middleware + usePathPrefix bool } func newDeviceAuthenticatedEndpointer(svc fleet.Service, logger log.Logger, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer { @@ -347,22 +348,30 @@ func getNameFromPathAndVerb(verb, path string) string { } func (e *authEndpointer) POST(path string, f handlerFunc, v interface{}) { - e.handle(path, f, v, "POST") + e.handleEndpoint(path, f, v, "POST") } func (e *authEndpointer) GET(path string, f handlerFunc, v interface{}) { - e.handle(path, f, v, "GET") + e.handleEndpoint(path, f, v, "GET") } func (e *authEndpointer) PATCH(path string, f handlerFunc, v interface{}) { - e.handle(path, f, v, "PATCH") + e.handleEndpoint(path, f, v, "PATCH") } func (e *authEndpointer) DELETE(path string, f handlerFunc, v interface{}) { - e.handle(path, f, v, "DELETE") + e.handleEndpoint(path, f, v, "DELETE") } -func (e *authEndpointer) handle(path string, f handlerFunc, v interface{}, verb string) { +// PathHandler registers a handler for the verb and path. The pathHandler is +// a function that receives the actual path to which it will be mounted, and +// returns the actual http.Handler that will handle this endpoint. This is for +// when the handler needs to know on which path it was called. +func (e *authEndpointer) PathHandler(verb, path string, pathHandler func(path string) http.Handler) { + e.handlePathHandler(path, pathHandler, verb) +} + +func (e *authEndpointer) handlePathHandler(path string, pathHandler func(path string) http.Handler, verb string) { versions := e.versions if e.startingAtVersion != "" { startIndex := -1 @@ -399,15 +408,32 @@ func (e *authEndpointer) handle(path string, f handlerFunc, v interface{}, verb versionedPath := strings.Replace(path, "/_version_/", fmt.Sprintf("/{fleetversion:(?:%s)}/", strings.Join(versions, "|")), 1) nameAndVerb := getNameFromPathAndVerb(verb, path) - endpoint := e.makeEndpoint(f, v) - e.r.Handle(versionedPath, endpoint).Name(nameAndVerb).Methods(verb) + if e.usePathPrefix { + e.r.PathPrefix(versionedPath).Handler(pathHandler(versionedPath)).Name(nameAndVerb).Methods(verb) + } else { + e.r.Handle(versionedPath, pathHandler(versionedPath)).Name(nameAndVerb).Methods(verb) + } for _, alias := range e.alternativePaths { nameAndVerb := getNameFromPathAndVerb(verb, alias) versionedPath := strings.Replace(alias, "/_version_/", fmt.Sprintf("/{fleetversion:(?:%s)}/", strings.Join(versions, "|")), 1) - e.r.Handle(versionedPath, endpoint).Name(nameAndVerb).Methods(verb) + if e.usePathPrefix { + e.r.PathPrefix(versionedPath).Handler(pathHandler(versionedPath)).Name(nameAndVerb).Methods(verb) + } else { + e.r.Handle(versionedPath, pathHandler(versionedPath)).Name(nameAndVerb).Methods(verb) + } } } +func (e *authEndpointer) handleHTTPHandler(path string, h http.Handler, verb string) { + self := func(_ string) http.Handler { return h } + e.handlePathHandler(path, self, verb) +} + +func (e *authEndpointer) handleEndpoint(path string, f handlerFunc, v interface{}, verb string) { + endpoint := e.makeEndpoint(f, v) + e.handleHTTPHandler(path, endpoint, verb) +} + func (e *authEndpointer) makeEndpoint(f handlerFunc, v interface{}) http.Handler { next := func(ctx context.Context, request interface{}) (interface{}, error) { return f(ctx, request, e.svc) @@ -446,3 +472,9 @@ func (e *authEndpointer) WithCustomMiddleware(mws ...endpoint.Middleware) *authE ae.customMiddleware = mws return &ae } + +func (e *authEndpointer) UsePathPrefix() *authEndpointer { + ae := *e + ae.usePathPrefix = true + return &ae +} diff --git a/server/service/handler.go b/server/service/handler.go index 5f3c5786c2..4928a3fbae 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -124,15 +124,6 @@ func MakeHandler( r.Use(publicIP) attachFleetAPIRoutes(r, svc, config, logger, limitStore, fleetAPIOptions, eopts) - - // Results endpoint is handled different due to websockets use - - // TODO: this would not work once v1 is deprecated - note that the handler too uses the /v1/ path - // and this routes on path prefix, not exact path (unlike the authendpointer struct). - r.PathPrefix("/api/v1/fleet/results/"). - Handler(makeStreamDistributedQueryCampaignResultsHandler(svc, logger)). - Name("distributed_query_results") - addMetrics(r) return r @@ -446,6 +437,12 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC ne.POST("/api/v1/fleet/sso", initiateSSOEndpoint, initiateSSORequest{}) ne.POST("/api/v1/fleet/sso/callback", makeCallbackSSOEndpoint(config.Server.URLPrefix), callbackSSORequest{}) ne.GET("/api/v1/fleet/sso", settingsSSOEndpoint, nil) + // the websocket distributed query results endpoint is a bit different - the + // provided path is a prefix, not an exact match, and it is not a go-kit + // endpoint but a raw http.Handler. It uses the NoAuthEndpointer because + // authentication is done when the websocket session is established, inside + // the handler. + ne.UsePathPrefix().PathHandler("GET", "/api/_version_/fleet/results/", makeStreamDistributedQueryCampaignResultsHandler(svc, logger)) limiter := ratelimit.NewMiddleware(limitStore) ne. @@ -477,13 +474,16 @@ func WithSetup(svc fleet.Service, logger kitlog.Logger, next http.Handler) http. rxOsquery := regexp.MustCompile(`^/api/[^/]+/osquery`) return func(w http.ResponseWriter, r *http.Request) { configRouter := http.NewServeMux() - // TODO: hard-codes v1 as a path fragment, which would probably not work once we - // deprecate it for newer versions, unless we want to treat the setup differently (not versioned?) - configRouter.Handle("/api/v1/setup", kithttp.NewServer( + srv := kithttp.NewServer( makeSetupEndpoint(svc, logger), decodeSetupRequest, encodeResponse, - )) + ) + // NOTE: support setup on both /v1/ and version-less, in the future /v1/ + // will be dropped. + configRouter.Handle("/api/v1/setup", srv) + configRouter.Handle("/api/setup", srv) + // whitelist osqueryd endpoints if rxOsquery.MatchString(r.URL.Path) { next.ServeHTTP(w, r) diff --git a/server/service/service_campaign_test.go b/server/service/service_campaign_test.go index 286c40d233..000b9c79ff 100644 --- a/server/service/service_campaign_test.go +++ b/server/service/service_campaign_test.go @@ -92,10 +92,11 @@ func TestStreamCampaignResultsClosesReditOnWSClose(t *testing.T) { _, err := svc.NewDistributedQueryCampaign(viewerCtx, q, nil, fleet.HostTargets{HostIDs: []uint{2}, LabelIDs: []uint{1}}) require.NoError(t, err) - s := httptest.NewServer(makeStreamDistributedQueryCampaignResultsHandler(svc, kitlog.NewNopLogger())) + pathHandler := makeStreamDistributedQueryCampaignResultsHandler(svc, kitlog.NewNopLogger()) + s := httptest.NewServer(pathHandler("/api/latest/fleet/results/")) defer s.Close() // Convert http://127.0.0.1 to ws://127.0.0.1 - u := "ws" + strings.TrimPrefix(s.URL, "http") + "/api/v1/fleet/results/websocket" + u := "ws" + strings.TrimPrefix(s.URL, "http") + "/api/latest/fleet/results/websocket" // Connect to the server dialer := &websocket.Dialer{