Unversion the /setup endpoint, version the websocket endpoint (#5104)

This commit is contained in:
Martin Angers 2022-04-20 15:57:26 -04:00 committed by GitHub
parent 6a5f7172ef
commit b3fc0cd844
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 145 additions and 78 deletions

View file

@ -975,7 +975,7 @@ Note that live queries are automatically cancelled if this method is not called
#### Example script to handle request and response
```
const socket = new WebSocket('wss://<your-base-url>/api/v1/fleet/results/websocket');
const socket = new WebSocket('wss://<your-base-url>/api/v1/fleet/results/websockets');
socket.onopen = () => {
socket.send(JSON.stringify({ type: 'auth', data: { token: <auth-token> } }));

View file

@ -90,7 +90,7 @@ func (c *Client) LiveQueryWithContext(ctx context.Context, query string, labels
if flag.Lookup("test.v") != nil {
wssURL.Scheme = "ws"
}
wssURL.Path = c.urlPrefix + "/api/v1/fleet/results/websocket"
wssURL.Path = c.urlPrefix + "/api/latest/fleet/results/websocket"
conn, _, err := dialer.Dial(wssURL.String(), nil)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "upgrade live query result websocket")

View file

@ -42,7 +42,7 @@ func TestLiveQueryWithContext(t *testing.T) {
}
err := json.NewEncoder(w).Encode(resp)
assert.NoError(t, err)
case "/api/v1/fleet/results/websocket":
case "/api/latest/fleet/results/websocket":
ws, _ := upgrader.Upgrade(w, r, nil)
defer ws.Close()

View file

@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"net/http"
"regexp"
"strings"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/fleet"
@ -16,64 +18,96 @@ import (
// Stream Distributed Query Campaign Results and Metadata
////////////////////////////////////////////////////////////////////////////////
func makeStreamDistributedQueryCampaignResultsHandler(svc fleet.Service, logger kitlog.Logger) http.Handler {
var reVersion = regexp.MustCompile(`\{fleetversion:\(\?:([^\}\)]+)\)\}`)
func makeStreamDistributedQueryCampaignResultsHandler(svc fleet.Service, logger kitlog.Logger) func(string) http.Handler {
opt := sockjs.DefaultOptions
opt.Websocket = true
opt.RawWebsocket = true
return sockjs.NewHandler("/api/v1/fleet/results", opt, func(session sockjs.Session) {
conn := &websocket.Conn{Session: session}
defer func() {
if p := recover(); p != nil {
logger.Log("err", p, "msg", "panic in result handler")
conn.WriteJSONError("panic in result handler")
return func(path string) http.Handler {
// expand the path's versions (with regex) to all literal paths (no regex),
// because sockjs requires the (static, literal) path prefix as argument to
// create the handler so that it can trim it from the request's URL to get
// the special path values (such as the session id).
matches := reVersion.FindStringSubmatch(path)
if len(matches) == 0 {
panic("unexpected path, could not expand fleetversion: " + path)
}
versions := strings.Split(matches[1], "|")
literalPaths := make([]string, len(versions))
for i, ver := range versions {
lp := reVersion.ReplaceAllStringFunc(path, func(_ string) string { return ver })
literalPaths[i] = lp
}
sockHandler := func(session sockjs.Session) {
conn := &websocket.Conn{Session: session}
defer func() {
if p := recover(); p != nil {
logger.Log("err", p, "msg", "panic in result handler")
conn.WriteJSONError("panic in result handler")
}
session.Close(0, "none")
}()
// Receive the auth bearer token
token, err := conn.ReadAuthToken()
if err != nil {
logger.Log("err", err, "msg", "failed to read auth token")
return
}
session.Close(0, "none")
}()
// Receive the auth bearer token
token, err := conn.ReadAuthToken()
if err != nil {
logger.Log("err", err, "msg", "failed to read auth token")
return
// Authenticate with the token
vc, err := authViewer(context.Background(), string(token), svc)
if err != nil || !vc.CanPerformActions() {
logger.Log("err", err, "msg", "unauthorized viewer")
conn.WriteJSONError("unauthorized")
return
}
ctx := viewer.NewContext(context.Background(), *vc)
msg, err := conn.ReadJSONMessage()
if err != nil {
logger.Log("err", err, "msg", "reading select_campaign JSON")
conn.WriteJSONError("error reading select_campaign")
return
}
if msg.Type != "select_campaign" {
logger.Log("err", "unexpected msg type, expected select_campaign", "msg-type", msg.Type)
conn.WriteJSONError("expected select_campaign")
return
}
var info struct {
CampaignID uint `json:"campaign_id"`
}
err = json.Unmarshal(*(msg.Data.(*json.RawMessage)), &info)
if err != nil {
logger.Log("err", err, "msg", "unmarshaling select_campaign data")
conn.WriteJSONError("error unmarshaling select_campaign data")
return
}
if info.CampaignID == 0 {
logger.Log("err", "campaign ID not set")
conn.WriteJSONError("0 is not a valid campaign ID")
return
}
svc.StreamCampaignResults(ctx, conn, info.CampaignID)
}
// Authenticate with the token
vc, err := authViewer(context.Background(), string(token), svc)
if err != nil || !vc.CanPerformActions() {
logger.Log("err", err, "msg", "unauthorized viewer")
conn.WriteJSONError("unauthorized")
return
// multiplex the requests to each literal path that this endpoint support,
// with the corresponding sockjs handler to handle that specific path.
mux := http.NewServeMux()
for _, lp := range literalPaths {
// important: sockjs' path must not have the trailing path, but the mux
// needs it in order to match it as a path prefix (subtree).
sockPath := strings.TrimSuffix(lp, "/")
mux.Handle(lp, sockjs.NewHandler(sockPath, opt, sockHandler))
}
ctx := viewer.NewContext(context.Background(), *vc)
msg, err := conn.ReadJSONMessage()
if err != nil {
logger.Log("err", err, "msg", "reading select_campaign JSON")
conn.WriteJSONError("error reading select_campaign")
return
}
if msg.Type != "select_campaign" {
logger.Log("err", "unexpected msg type, expected select_campaign", "msg-type", msg.Type)
conn.WriteJSONError("expected select_campaign")
return
}
var info struct {
CampaignID uint `json:"campaign_id"`
}
err = json.Unmarshal(*(msg.Data.(*json.RawMessage)), &info)
if err != nil {
logger.Log("err", err, "msg", "unmarshaling select_campaign data")
conn.WriteJSONError("error unmarshaling select_campaign data")
return
}
if info.CampaignID == 0 {
logger.Log("err", "campaign ID not set")
conn.WriteJSONError("0 is not a valid campaign ID")
return
}
svc.StreamCampaignResults(ctx, conn, info.CampaignID)
})
return mux
}
}

View file

@ -287,6 +287,7 @@ type authEndpointer struct {
endingAtVersion string
alternativePaths []string
customMiddleware []endpoint.Middleware
usePathPrefix bool
}
func newDeviceAuthenticatedEndpointer(svc fleet.Service, logger log.Logger, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer {
@ -347,22 +348,30 @@ func getNameFromPathAndVerb(verb, path string) string {
}
func (e *authEndpointer) POST(path string, f handlerFunc, v interface{}) {
e.handle(path, f, v, "POST")
e.handleEndpoint(path, f, v, "POST")
}
func (e *authEndpointer) GET(path string, f handlerFunc, v interface{}) {
e.handle(path, f, v, "GET")
e.handleEndpoint(path, f, v, "GET")
}
func (e *authEndpointer) PATCH(path string, f handlerFunc, v interface{}) {
e.handle(path, f, v, "PATCH")
e.handleEndpoint(path, f, v, "PATCH")
}
func (e *authEndpointer) DELETE(path string, f handlerFunc, v interface{}) {
e.handle(path, f, v, "DELETE")
e.handleEndpoint(path, f, v, "DELETE")
}
func (e *authEndpointer) handle(path string, f handlerFunc, v interface{}, verb string) {
// PathHandler registers a handler for the verb and path. The pathHandler is
// a function that receives the actual path to which it will be mounted, and
// returns the actual http.Handler that will handle this endpoint. This is for
// when the handler needs to know on which path it was called.
func (e *authEndpointer) PathHandler(verb, path string, pathHandler func(path string) http.Handler) {
e.handlePathHandler(path, pathHandler, verb)
}
func (e *authEndpointer) handlePathHandler(path string, pathHandler func(path string) http.Handler, verb string) {
versions := e.versions
if e.startingAtVersion != "" {
startIndex := -1
@ -399,15 +408,32 @@ func (e *authEndpointer) handle(path string, f handlerFunc, v interface{}, verb
versionedPath := strings.Replace(path, "/_version_/", fmt.Sprintf("/{fleetversion:(?:%s)}/", strings.Join(versions, "|")), 1)
nameAndVerb := getNameFromPathAndVerb(verb, path)
endpoint := e.makeEndpoint(f, v)
e.r.Handle(versionedPath, endpoint).Name(nameAndVerb).Methods(verb)
if e.usePathPrefix {
e.r.PathPrefix(versionedPath).Handler(pathHandler(versionedPath)).Name(nameAndVerb).Methods(verb)
} else {
e.r.Handle(versionedPath, pathHandler(versionedPath)).Name(nameAndVerb).Methods(verb)
}
for _, alias := range e.alternativePaths {
nameAndVerb := getNameFromPathAndVerb(verb, alias)
versionedPath := strings.Replace(alias, "/_version_/", fmt.Sprintf("/{fleetversion:(?:%s)}/", strings.Join(versions, "|")), 1)
e.r.Handle(versionedPath, endpoint).Name(nameAndVerb).Methods(verb)
if e.usePathPrefix {
e.r.PathPrefix(versionedPath).Handler(pathHandler(versionedPath)).Name(nameAndVerb).Methods(verb)
} else {
e.r.Handle(versionedPath, pathHandler(versionedPath)).Name(nameAndVerb).Methods(verb)
}
}
}
func (e *authEndpointer) handleHTTPHandler(path string, h http.Handler, verb string) {
self := func(_ string) http.Handler { return h }
e.handlePathHandler(path, self, verb)
}
func (e *authEndpointer) handleEndpoint(path string, f handlerFunc, v interface{}, verb string) {
endpoint := e.makeEndpoint(f, v)
e.handleHTTPHandler(path, endpoint, verb)
}
func (e *authEndpointer) makeEndpoint(f handlerFunc, v interface{}) http.Handler {
next := func(ctx context.Context, request interface{}) (interface{}, error) {
return f(ctx, request, e.svc)
@ -446,3 +472,9 @@ func (e *authEndpointer) WithCustomMiddleware(mws ...endpoint.Middleware) *authE
ae.customMiddleware = mws
return &ae
}
func (e *authEndpointer) UsePathPrefix() *authEndpointer {
ae := *e
ae.usePathPrefix = true
return &ae
}

View file

@ -124,15 +124,6 @@ func MakeHandler(
r.Use(publicIP)
attachFleetAPIRoutes(r, svc, config, logger, limitStore, fleetAPIOptions, eopts)
// Results endpoint is handled different due to websockets use
// TODO: this would not work once v1 is deprecated - note that the handler too uses the /v1/ path
// and this routes on path prefix, not exact path (unlike the authendpointer struct).
r.PathPrefix("/api/v1/fleet/results/").
Handler(makeStreamDistributedQueryCampaignResultsHandler(svc, logger)).
Name("distributed_query_results")
addMetrics(r)
return r
@ -446,6 +437,12 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
ne.POST("/api/v1/fleet/sso", initiateSSOEndpoint, initiateSSORequest{})
ne.POST("/api/v1/fleet/sso/callback", makeCallbackSSOEndpoint(config.Server.URLPrefix), callbackSSORequest{})
ne.GET("/api/v1/fleet/sso", settingsSSOEndpoint, nil)
// the websocket distributed query results endpoint is a bit different - the
// provided path is a prefix, not an exact match, and it is not a go-kit
// endpoint but a raw http.Handler. It uses the NoAuthEndpointer because
// authentication is done when the websocket session is established, inside
// the handler.
ne.UsePathPrefix().PathHandler("GET", "/api/_version_/fleet/results/", makeStreamDistributedQueryCampaignResultsHandler(svc, logger))
limiter := ratelimit.NewMiddleware(limitStore)
ne.
@ -477,13 +474,16 @@ func WithSetup(svc fleet.Service, logger kitlog.Logger, next http.Handler) http.
rxOsquery := regexp.MustCompile(`^/api/[^/]+/osquery`)
return func(w http.ResponseWriter, r *http.Request) {
configRouter := http.NewServeMux()
// TODO: hard-codes v1 as a path fragment, which would probably not work once we
// deprecate it for newer versions, unless we want to treat the setup differently (not versioned?)
configRouter.Handle("/api/v1/setup", kithttp.NewServer(
srv := kithttp.NewServer(
makeSetupEndpoint(svc, logger),
decodeSetupRequest,
encodeResponse,
))
)
// NOTE: support setup on both /v1/ and version-less, in the future /v1/
// will be dropped.
configRouter.Handle("/api/v1/setup", srv)
configRouter.Handle("/api/setup", srv)
// whitelist osqueryd endpoints
if rxOsquery.MatchString(r.URL.Path) {
next.ServeHTTP(w, r)

View file

@ -92,10 +92,11 @@ func TestStreamCampaignResultsClosesReditOnWSClose(t *testing.T) {
_, err := svc.NewDistributedQueryCampaign(viewerCtx, q, nil, fleet.HostTargets{HostIDs: []uint{2}, LabelIDs: []uint{1}})
require.NoError(t, err)
s := httptest.NewServer(makeStreamDistributedQueryCampaignResultsHandler(svc, kitlog.NewNopLogger()))
pathHandler := makeStreamDistributedQueryCampaignResultsHandler(svc, kitlog.NewNopLogger())
s := httptest.NewServer(pathHandler("/api/latest/fleet/results/"))
defer s.Close()
// Convert http://127.0.0.1 to ws://127.0.0.1
u := "ws" + strings.TrimPrefix(s.URL, "http") + "/api/v1/fleet/results/websocket"
u := "ws" + strings.TrimPrefix(s.URL, "http") + "/api/latest/fleet/results/websocket"
// Connect to the server
dialer := &websocket.Dialer{