diff --git a/server/service/carves.go b/server/service/carves.go index e5e659d850..dfc9049e25 100644 --- a/server/service/carves.go +++ b/server/service/carves.go @@ -4,9 +4,12 @@ import ( "context" "errors" "fmt" + "time" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" + hostctx "github.com/fleetdm/fleet/v4/server/contexts/host" "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/google/uuid" ) //////////////////////////////////////////////////////////////////////////////// @@ -131,3 +134,101 @@ func (svc *Service) GetBlock(ctx context.Context, carveId, blockId int64) ([]byt return data, nil } + +//////////////////////////////////////////////////////////////////////////////// +// Begin File Carve +//////////////////////////////////////////////////////////////////////////////// + +type carveBeginRequest struct { + NodeKey string `json:"node_key"` + BlockCount int64 `json:"block_count"` + BlockSize int64 `json:"block_size"` + CarveSize int64 `json:"carve_size"` + CarveId string `json:"carve_id"` + RequestId string `json:"request_id"` +} + +type carveBeginResponse struct { + SessionId string `json:"session_id"` + Success bool `json:"success,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r carveBeginResponse) error() error { return r.Err } + +func carveBeginEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*carveBeginRequest) + + payload := fleet.CarveBeginPayload{ + BlockCount: req.BlockCount, + BlockSize: req.BlockSize, + CarveSize: req.CarveSize, + CarveId: req.CarveId, + RequestId: req.RequestId, + } + + carve, err := svc.CarveBegin(ctx, payload) + if err != nil { + return carveBeginResponse{Err: err}, nil + } + + return carveBeginResponse{SessionId: carve.SessionId, Success: true}, nil +} + +const ( + maxCarveSize = 8 * 1024 * 1024 * 1024 // 8GB + maxBlockSize = 256 * 1024 * 1024 // 256MB +) + +func (svc *Service) CarveBegin(ctx context.Context, payload fleet.CarveBeginPayload) (*fleet.CarveMetadata, error) { + // skipauth: Authorization is currently for user endpoints only. + svc.authz.SkipAuthorization(ctx) + + host, ok := hostctx.FromContext(ctx) + if !ok { + return nil, osqueryError{message: "internal error: missing host from request context"} + } + + if payload.CarveSize == 0 { + return nil, osqueryError{message: "carve_size must be greater than 0"} + } + + if payload.BlockSize > maxBlockSize { + return nil, osqueryError{message: "block_size exceeds max"} + } + if payload.CarveSize > maxCarveSize { + return nil, osqueryError{message: "carve_size exceeds max"} + } + + // The carve should have a total size that fits appropriately into the + // number of blocks of the specified size. + if payload.CarveSize <= (payload.BlockCount-1)*payload.BlockSize || + payload.CarveSize > payload.BlockCount*payload.BlockSize { + return nil, osqueryError{message: "carve_size does not match block_size and block_count"} + } + + sessionId, err := uuid.NewRandom() + if err != nil { + return nil, osqueryError{message: "internal error: generate session ID for carve: " + err.Error()} + } + + now := time.Now().UTC() + carve := &fleet.CarveMetadata{ + Name: fmt.Sprintf("%s-%s-%s", host.Hostname, now.Format(time.RFC3339), payload.RequestId), + HostId: host.ID, + BlockCount: payload.BlockCount, + BlockSize: payload.BlockSize, + CarveSize: payload.CarveSize, + CarveId: payload.CarveId, + RequestId: payload.RequestId, + SessionId: sessionId.String(), + CreatedAt: now, + } + + carve, err = svc.carveStore.NewCarve(ctx, carve) + if err != nil { + return nil, osqueryError{message: "internal error: new carve: " + err.Error()} + } + + return carve, nil +} diff --git a/server/service/endpoint_carves.go b/server/service/endpoint_carves.go index ec94fc1694..46f58c90a7 100644 --- a/server/service/endpoint_carves.go +++ b/server/service/endpoint_carves.go @@ -7,48 +7,6 @@ import ( "github.com/go-kit/kit/endpoint" ) -//////////////////////////////////////////////////////////////////////////////// -// Begin File Carve -//////////////////////////////////////////////////////////////////////////////// - -type carveBeginRequest struct { - NodeKey string `json:"node_key"` - BlockCount int64 `json:"block_count"` - BlockSize int64 `json:"block_size"` - CarveSize int64 `json:"carve_size"` - CarveId string `json:"carve_id"` - RequestId string `json:"request_id"` -} - -type carveBeginResponse struct { - SessionId string `json:"session_id"` - Success bool `json:"success,omitempty"` - Err error `json:"error,omitempty"` -} - -func (r carveBeginResponse) error() error { return r.Err } - -func makeCarveBeginEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(carveBeginRequest) - - payload := fleet.CarveBeginPayload{ - BlockCount: req.BlockCount, - BlockSize: req.BlockSize, - CarveSize: req.CarveSize, - CarveId: req.CarveId, - RequestId: req.RequestId, - } - - carve, err := svc.CarveBegin(ctx, payload) - if err != nil { - return carveBeginResponse{Err: err}, nil - } - - return carveBeginResponse{SessionId: carve.SessionId, Success: true}, nil - } -} - //////////////////////////////////////////////////////////////////////////////// // Receive Block for File Carve //////////////////////////////////////////////////////////////////////////////// diff --git a/server/service/endpoint_middleware.go b/server/service/endpoint_middleware.go index 428339d3c1..bdf7d50790 100644 --- a/server/service/endpoint_middleware.go +++ b/server/service/endpoint_middleware.go @@ -78,6 +78,9 @@ func getNodeKey(r interface{}) (string, error) { // Retrieve node key by reflection (note that our options here // are limited by the fact that request is an interface{}) v := reflect.ValueOf(r) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } if v.Kind() != reflect.Struct { return "", osqueryError{ message: "request type is not struct. This is likely a Fleet programmer error.", diff --git a/server/service/endpoint_osquery.go b/server/service/endpoint_osquery.go index 7b78372fca..c0ed9a9211 100644 --- a/server/service/endpoint_osquery.go +++ b/server/service/endpoint_osquery.go @@ -2,7 +2,6 @@ package service import ( "context" - "encoding/json" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/go-kit/kit/endpoint" @@ -35,138 +34,3 @@ func makeEnrollAgentEndpoint(svc fleet.Service) endpoint.Endpoint { return enrollAgentResponse{NodeKey: nodeKey}, nil } } - -//////////////////////////////////////////////////////////////////////////////// -// Get Client Config -//////////////////////////////////////////////////////////////////////////////// - -type getClientConfigRequest struct { - NodeKey string `json:"node_key"` -} - -type getClientConfigResponse struct { - Config map[string]interface{} - Err error `json:"error,omitempty"` -} - -func (r getClientConfigResponse) error() error { return r.Err } - -func makeGetClientConfigEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - config, err := svc.GetClientConfig(ctx) - if err != nil { - return getClientConfigResponse{Err: err}, nil - } - - // We return the config here explicitly because osquery exepects the - // response for configs to be at the top-level of the JSON response - return config, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Get Distributed Queries -//////////////////////////////////////////////////////////////////////////////// - -type getDistributedQueriesRequest struct { - NodeKey string `json:"node_key"` -} - -type getDistributedQueriesResponse struct { - Queries map[string]string `json:"queries"` - Accelerate uint `json:"accelerate,omitempty"` - Err error `json:"error,omitempty"` -} - -func (r getDistributedQueriesResponse) error() error { return r.Err } - -func makeGetDistributedQueriesEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - queries, accelerate, err := svc.GetDistributedQueries(ctx) - if err != nil { - return getDistributedQueriesResponse{Err: err}, nil - } - return getDistributedQueriesResponse{Queries: queries, Accelerate: accelerate}, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Write Distributed Query Results -//////////////////////////////////////////////////////////////////////////////// - -type SubmitDistributedQueryResultsRequest struct { - NodeKey string `json:"node_key"` - Results fleet.OsqueryDistributedQueryResults `json:"queries"` - Statuses map[string]fleet.OsqueryStatus `json:"statuses"` - Messages map[string]string `json:"messages"` -} - -type submitDistributedQueryResultsResponse struct { - Err error `json:"error,omitempty"` -} - -func (r submitDistributedQueryResultsResponse) error() error { return r.Err } - -func makeSubmitDistributedQueryResultsEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(SubmitDistributedQueryResultsRequest) - err := svc.SubmitDistributedQueryResults(ctx, req.Results, req.Statuses, req.Messages) - if err != nil { - return submitDistributedQueryResultsResponse{Err: err}, nil - } - return submitDistributedQueryResultsResponse{}, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Submit Logs -//////////////////////////////////////////////////////////////////////////////// - -type submitLogsRequest struct { - NodeKey string `json:"node_key"` - LogType string `json:"log_type"` - Data json.RawMessage `json:"data"` -} - -type submitLogsResponse struct { - Err error `json:"error,omitempty"` -} - -func (r submitLogsResponse) error() error { return r.Err } - -func makeSubmitLogsEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(submitLogsRequest) - - var err error - switch req.LogType { - case "status": - var statuses []json.RawMessage - if err := json.Unmarshal(req.Data, &statuses); err != nil { - err = osqueryError{message: "unmarshalling status logs: " + err.Error()} - break - } - - err = svc.SubmitStatusLogs(ctx, statuses) - if err != nil { - break - } - - case "result": - var results []json.RawMessage - if err := json.Unmarshal(req.Data, &results); err != nil { - err = osqueryError{message: "unmarshalling result logs: " + err.Error()} - break - } - err = svc.SubmitResultLogs(ctx, results) - if err != nil { - break - } - - default: - err = osqueryError{message: "unknown log type: " + req.LogType} - } - - return submitLogsResponse{Err: err}, nil - } -} diff --git a/server/service/endpoint_utils.go b/server/service/endpoint_utils.go index 9b7db84781..88ffb079da 100644 --- a/server/service/endpoint_utils.go +++ b/server/service/endpoint_utils.go @@ -2,6 +2,7 @@ package service import ( "bufio" + "compress/gzip" "context" "encoding/json" "errors" @@ -13,6 +14,8 @@ import ( "strings" "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" kithttp "github.com/go-kit/kit/transport/http" "github.com/gorilla/mux" ) @@ -95,8 +98,18 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc { if _, err := buf.Peek(1); err == io.EOF { nilBody = true } else { + var body io.Reader = buf + if r.Header.Get("content-encoding") == "gzip" { + gzr, err := gzip.NewReader(buf) + if err != nil { + return nil, err + } + defer gzr.Close() + body = gzr + } + req := v.Interface() - if err := json.NewDecoder(buf).Decode(req); err != nil { + if err := json.NewDecoder(body).Decode(req); err != nil { return nil, err } v = reflect.ValueOf(req) @@ -250,18 +263,38 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc { } } -type UserAuthEndpointer struct { +type authEndpointer struct { svc fleet.Service opts []kithttp.ServerOption r *mux.Router + authFunc func(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpoint versions []string startingAtVersion string endingAtVersion string alternativePaths []string } -func NewUserAuthenticatedEndpointer(svc fleet.Service, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *UserAuthEndpointer { - return &UserAuthEndpointer{svc: svc, opts: opts, r: r, versions: versions} +func newUserAuthenticatedEndpointer(svc fleet.Service, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer { + return &authEndpointer{ + svc: svc, + opts: opts, + r: r, + authFunc: authenticatedUser, + versions: versions, + } +} + +func newHostAuthenticatedEndpointer(svc fleet.Service, logger log.Logger, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer { + authFunc := func(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpoint { + return authenticatedHost(svc, logger, next) + } + return &authEndpointer{ + svc: svc, + opts: opts, + r: r, + authFunc: authFunc, + versions: versions, + } } var pathReplacer = strings.NewReplacer( @@ -272,26 +305,26 @@ var pathReplacer = strings.NewReplacer( func getNameFromPathAndVerb(verb, path string) string { return strings.ToLower(verb) + "_" + - pathReplacer.Replace(strings.TrimPrefix(strings.TrimRight(path, "/"), "/api/v1/fleet/")) + pathReplacer.Replace(strings.TrimPrefix(strings.TrimRight(path, "/"), "/api/_version_/fleet/")) } -func (e *UserAuthEndpointer) POST(path string, f handlerFunc, v interface{}) { +func (e *authEndpointer) POST(path string, f handlerFunc, v interface{}) { e.handle(path, f, v, "POST") } -func (e *UserAuthEndpointer) GET(path string, f handlerFunc, v interface{}) { +func (e *authEndpointer) GET(path string, f handlerFunc, v interface{}) { e.handle(path, f, v, "GET") } -func (e *UserAuthEndpointer) PATCH(path string, f handlerFunc, v interface{}) { +func (e *authEndpointer) PATCH(path string, f handlerFunc, v interface{}) { e.handle(path, f, v, "PATCH") } -func (e *UserAuthEndpointer) DELETE(path string, f handlerFunc, v interface{}) { +func (e *authEndpointer) DELETE(path string, f handlerFunc, v interface{}) { e.handle(path, f, v, "DELETE") } -func (e *UserAuthEndpointer) handle(path string, f handlerFunc, v interface{}, verb string) { +func (e *authEndpointer) handle(path string, f handlerFunc, v interface{}, verb string) { versions := e.versions if e.startingAtVersion != "" { startIndex := -1 @@ -337,50 +370,27 @@ func (e *UserAuthEndpointer) handle(path string, f handlerFunc, v interface{}, v } } -func (e *UserAuthEndpointer) makeEndpoint(f handlerFunc, v interface{}) http.Handler { - return newServer( - authenticatedUser( - e.svc, - func(ctx context.Context, request interface{}) (interface{}, error) { - return f(ctx, request, e.svc) - }), - makeDecoder(v), - e.opts, - ) +func (e *authEndpointer) makeEndpoint(f handlerFunc, v interface{}) http.Handler { + next := func(ctx context.Context, request interface{}) (interface{}, error) { + return f(ctx, request, e.svc) + } + return newServer(e.authFunc(e.svc, next), makeDecoder(v), e.opts) } -func (e *UserAuthEndpointer) StartingAtVersion(version string) *UserAuthEndpointer { - return &UserAuthEndpointer{ - svc: e.svc, - opts: e.opts, - r: e.r, - versions: e.versions, - startingAtVersion: version, - endingAtVersion: e.endingAtVersion, - alternativePaths: e.alternativePaths, - } +func (e *authEndpointer) StartingAtVersion(version string) *authEndpointer { + ae := *e + ae.startingAtVersion = version + return &ae } -func (e *UserAuthEndpointer) EndingAtVersion(version string) *UserAuthEndpointer { - return &UserAuthEndpointer{ - svc: e.svc, - opts: e.opts, - r: e.r, - versions: e.versions, - startingAtVersion: e.startingAtVersion, - endingAtVersion: version, - alternativePaths: e.alternativePaths, - } +func (e *authEndpointer) EndingAtVersion(version string) *authEndpointer { + ae := *e + ae.endingAtVersion = version + return &ae } -func (e *UserAuthEndpointer) WithAltPaths(paths ...string) *UserAuthEndpointer { - return &UserAuthEndpointer{ - svc: e.svc, - opts: e.opts, - r: e.r, - versions: e.versions, - startingAtVersion: e.startingAtVersion, - endingAtVersion: e.endingAtVersion, - alternativePaths: paths, - } +func (e *authEndpointer) WithAltPaths(paths ...string) *authEndpointer { + ae := *e + ae.alternativePaths = paths + return &ae } diff --git a/server/service/endpoint_utils_test.go b/server/service/endpoint_utils_test.go index af8d55acec..e925a38e98 100644 --- a/server/service/endpoint_utils_test.go +++ b/server/service/endpoint_utils_test.go @@ -291,7 +291,7 @@ func TestEndpointer(t *testing.T) { ), } - e := NewUserAuthenticatedEndpointer(svc, fleetAPIOptions, r, "v1", "2021-11") + e := newUserAuthenticatedEndpointer(svc, fleetAPIOptions, r, "v1", "2021-11") nopHandler := func(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { if authctx, ok := authz_ctx.FromContext(ctx); ok { authctx.SetChecked() diff --git a/server/service/handler.go b/server/service/handler.go index f0d7b3a55f..869caa4cab 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -24,23 +24,18 @@ import ( // FleetEndpoints is a collection of RPC endpoints implemented by the Fleet API. type FleetEndpoints struct { - Login endpoint.Endpoint - Logout endpoint.Endpoint - ForgotPassword endpoint.Endpoint - ResetPassword endpoint.Endpoint - CreateUserWithInvite endpoint.Endpoint - PerformRequiredPasswordReset endpoint.Endpoint - VerifyInvite endpoint.Endpoint - EnrollAgent endpoint.Endpoint - GetClientConfig endpoint.Endpoint - GetDistributedQueries endpoint.Endpoint - SubmitDistributedQueryResults endpoint.Endpoint - SubmitLogs endpoint.Endpoint - CarveBegin endpoint.Endpoint - CarveBlock endpoint.Endpoint - InitiateSSO endpoint.Endpoint - CallbackSSO endpoint.Endpoint - SSOSettings endpoint.Endpoint + Login endpoint.Endpoint + Logout endpoint.Endpoint + ForgotPassword endpoint.Endpoint + ResetPassword endpoint.Endpoint + CreateUserWithInvite endpoint.Endpoint + PerformRequiredPasswordReset endpoint.Endpoint + VerifyInvite endpoint.Endpoint + EnrollAgent endpoint.Endpoint + CarveBlock endpoint.Endpoint + InitiateSSO endpoint.Endpoint + CallbackSSO endpoint.Endpoint + SSOSettings endpoint.Endpoint } // MakeFleetServerEndpoints creates the Fleet API endpoints. @@ -70,12 +65,6 @@ func MakeFleetServerEndpoints(svc fleet.Service, urlPrefix string, limitStore th // Osquery endpoints EnrollAgent: logged(makeEnrollAgentEndpoint(svc)), - // Authenticated osquery endpoints - GetClientConfig: authenticatedHost(svc, logger, makeGetClientConfigEndpoint(svc)), - GetDistributedQueries: authenticatedHost(svc, logger, makeGetDistributedQueriesEndpoint(svc)), - SubmitDistributedQueryResults: authenticatedHost(svc, logger, makeSubmitDistributedQueryResultsEndpoint(svc)), - SubmitLogs: authenticatedHost(svc, logger, makeSubmitLogsEndpoint(svc)), - CarveBegin: authenticatedHost(svc, logger, makeCarveBeginEndpoint(svc)), // For some reason osquery does not provide a node key with the block // data. Instead the carve session ID should be verified in the service // method. @@ -84,23 +73,18 @@ func MakeFleetServerEndpoints(svc fleet.Service, urlPrefix string, limitStore th } type fleetHandlers struct { - Login http.Handler - Logout http.Handler - ForgotPassword http.Handler - ResetPassword http.Handler - CreateUserWithInvite http.Handler - PerformRequiredPasswordReset http.Handler - VerifyInvite http.Handler - EnrollAgent http.Handler - GetClientConfig http.Handler - GetDistributedQueries http.Handler - SubmitDistributedQueryResults http.Handler - SubmitLogs http.Handler - CarveBegin http.Handler - CarveBlock http.Handler - InitiateSSO http.Handler - CallbackSSO http.Handler - SettingsSSO http.Handler + Login http.Handler + Logout http.Handler + ForgotPassword http.Handler + ResetPassword http.Handler + CreateUserWithInvite http.Handler + PerformRequiredPasswordReset http.Handler + VerifyInvite http.Handler + EnrollAgent http.Handler + CarveBlock http.Handler + InitiateSSO http.Handler + CallbackSSO http.Handler + SettingsSSO http.Handler } func makeKitHandlers(e FleetEndpoints, opts []kithttp.ServerOption) *fleetHandlers { @@ -109,23 +93,18 @@ func makeKitHandlers(e FleetEndpoints, opts []kithttp.ServerOption) *fleetHandle return kithttp.NewServer(e, decodeFn, encodeResponse, opts...) } return &fleetHandlers{ - Login: newServer(e.Login, decodeLoginRequest), - Logout: newServer(e.Logout, decodeNoParamsRequest), - ForgotPassword: newServer(e.ForgotPassword, decodeForgotPasswordRequest), - ResetPassword: newServer(e.ResetPassword, decodeResetPasswordRequest), - CreateUserWithInvite: newServer(e.CreateUserWithInvite, decodeCreateUserRequest), - PerformRequiredPasswordReset: newServer(e.PerformRequiredPasswordReset, decodePerformRequiredPasswordResetRequest), - VerifyInvite: newServer(e.VerifyInvite, decodeVerifyInviteRequest), - EnrollAgent: newServer(e.EnrollAgent, decodeEnrollAgentRequest), - GetClientConfig: newServer(e.GetClientConfig, decodeGetClientConfigRequest), - GetDistributedQueries: newServer(e.GetDistributedQueries, decodeGetDistributedQueriesRequest), - SubmitDistributedQueryResults: newServer(e.SubmitDistributedQueryResults, decodeSubmitDistributedQueryResultsRequest), - SubmitLogs: newServer(e.SubmitLogs, decodeSubmitLogsRequest), - CarveBegin: newServer(e.CarveBegin, decodeCarveBeginRequest), - CarveBlock: newServer(e.CarveBlock, decodeCarveBlockRequest), - InitiateSSO: newServer(e.InitiateSSO, decodeInitiateSSORequest), - CallbackSSO: newServer(e.CallbackSSO, decodeCallbackSSORequest), - SettingsSSO: newServer(e.SSOSettings, decodeNoParamsRequest), + Login: newServer(e.Login, decodeLoginRequest), + Logout: newServer(e.Logout, decodeNoParamsRequest), + ForgotPassword: newServer(e.ForgotPassword, decodeForgotPasswordRequest), + ResetPassword: newServer(e.ResetPassword, decodeResetPasswordRequest), + CreateUserWithInvite: newServer(e.CreateUserWithInvite, decodeCreateUserRequest), + PerformRequiredPasswordReset: newServer(e.PerformRequiredPasswordReset, decodePerformRequiredPasswordResetRequest), + VerifyInvite: newServer(e.VerifyInvite, decodeVerifyInviteRequest), + EnrollAgent: newServer(e.EnrollAgent, decodeEnrollAgentRequest), + CarveBlock: newServer(e.CarveBlock, decodeCarveBlockRequest), + InitiateSSO: newServer(e.InitiateSSO, decodeInitiateSSORequest), + CallbackSSO: newServer(e.CallbackSSO, decodeCallbackSSORequest), + SettingsSSO: newServer(e.SSOSettings, decodeNoParamsRequest), } } @@ -206,7 +185,7 @@ func MakeHandler(svc fleet.Service, config config.FleetConfig, logger kitlog.Log } attachFleetAPIRoutes(r, fleetHandlers) - attachNewStyleFleetAPIRoutes(r, svc, fleetAPIOptions) + attachNewStyleFleetAPIRoutes(r, svc, logger, fleetAPIOptions) // Results endpoint is handled different due to websockets use r.PathPrefix("/api/v1/fleet/results/"). @@ -310,154 +289,158 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) { r.Handle("/api/v1/fleet/users", h.CreateUserWithInvite).Methods("POST").Name("create_user_with_invite") r.Handle("/api/v1/fleet/invites/{token}", h.VerifyInvite).Methods("GET").Name("verify_invite") r.Handle("/api/v1/osquery/enroll", h.EnrollAgent).Methods("POST").Name("enroll_agent") - r.Handle("/api/v1/osquery/config", h.GetClientConfig).Methods("POST").Name("get_client_config") - r.Handle("/api/v1/osquery/distributed/read", h.GetDistributedQueries).Methods("POST").Name("get_distributed_queries") - r.Handle("/api/v1/osquery/distributed/write", h.SubmitDistributedQueryResults).Methods("POST").Name("submit_distributed_query_results") - r.Handle("/api/v1/osquery/log", h.SubmitLogs).Methods("POST").Name("submit_logs") - r.Handle("/api/v1/osquery/carve/begin", h.CarveBegin).Methods("POST").Name("carve_begin") r.Handle("/api/v1/osquery/carve/block", h.CarveBlock).Methods("POST").Name("carve_block") } -func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, opts []kithttp.ServerOption) { - e := NewUserAuthenticatedEndpointer(svc, opts, r, "v1") +func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, logger kitlog.Logger, opts []kithttp.ServerOption) { + // user-authenticated endpoints + ue := newUserAuthenticatedEndpointer(svc, opts, r, "v1") - e.GET("/api/_version_/fleet/me", meEndpoint, nil) - e.GET("/api/_version_/fleet/sessions/{id:[0-9]+}", getInfoAboutSessionEndpoint, getInfoAboutSessionRequest{}) - e.DELETE("/api/_version_/fleet/sessions/{id:[0-9]+}", deleteSessionEndpoint, deleteSessionRequest{}) + ue.GET("/api/_version_/fleet/me", meEndpoint, nil) + ue.GET("/api/_version_/fleet/sessions/{id:[0-9]+}", getInfoAboutSessionEndpoint, getInfoAboutSessionRequest{}) + ue.DELETE("/api/_version_/fleet/sessions/{id:[0-9]+}", deleteSessionEndpoint, deleteSessionRequest{}) - e.GET("/api/_version_/fleet/config/certificate", getCertificateEndpoint, nil) - e.GET("/api/_version_/fleet/config", getAppConfigEndpoint, nil) - e.PATCH("/api/_version_/fleet/config", modifyAppConfigEndpoint, modifyAppConfigRequest{}) - e.POST("/api/_version_/fleet/spec/enroll_secret", applyEnrollSecretSpecEndpoint, applyEnrollSecretSpecRequest{}) - e.GET("/api/_version_/fleet/spec/enroll_secret", getEnrollSecretSpecEndpoint, nil) - e.GET("/api/_version_/fleet/version", versionEndpoint, nil) + ue.GET("/api/_version_/fleet/config/certificate", getCertificateEndpoint, nil) + ue.GET("/api/_version_/fleet/config", getAppConfigEndpoint, nil) + ue.PATCH("/api/_version_/fleet/config", modifyAppConfigEndpoint, modifyAppConfigRequest{}) + ue.POST("/api/_version_/fleet/spec/enroll_secret", applyEnrollSecretSpecEndpoint, applyEnrollSecretSpecRequest{}) + ue.GET("/api/_version_/fleet/spec/enroll_secret", getEnrollSecretSpecEndpoint, nil) + ue.GET("/api/_version_/fleet/version", versionEndpoint, nil) - e.POST("/api/_version_/fleet/users/roles/spec", applyUserRoleSpecsEndpoint, applyUserRoleSpecsRequest{}) - e.POST("/api/_version_/fleet/translate", translatorEndpoint, translatorRequest{}) - e.POST("/api/_version_/fleet/spec/teams", applyTeamSpecsEndpoint, applyTeamSpecsRequest{}) - e.PATCH("/api/_version_/fleet/teams/{team_id:[0-9]+}/secrets", modifyTeamEnrollSecretsEndpoint, modifyTeamEnrollSecretsRequest{}) - e.POST("/api/_version_/fleet/teams", createTeamEndpoint, createTeamRequest{}) - e.GET("/api/_version_/fleet/teams", listTeamsEndpoint, listTeamsRequest{}) - e.GET("/api/_version_/fleet/teams/{id:[0-9]+}", getTeamEndpoint, getTeamRequest{}) - e.PATCH("/api/_version_/fleet/teams/{id:[0-9]+}", modifyTeamEndpoint, modifyTeamRequest{}) - e.DELETE("/api/_version_/fleet/teams/{id:[0-9]+}", deleteTeamEndpoint, deleteTeamRequest{}) - e.POST("/api/_version_/fleet/teams/{id:[0-9]+}/agent_options", modifyTeamAgentOptionsEndpoint, modifyTeamAgentOptionsRequest{}) - e.GET("/api/_version_/fleet/teams/{id:[0-9]+}/users", listTeamUsersEndpoint, listTeamUsersRequest{}) - e.PATCH("/api/_version_/fleet/teams/{id:[0-9]+}/users", addTeamUsersEndpoint, modifyTeamUsersRequest{}) - e.DELETE("/api/_version_/fleet/teams/{id:[0-9]+}/users", deleteTeamUsersEndpoint, modifyTeamUsersRequest{}) - e.GET("/api/_version_/fleet/teams/{id:[0-9]+}/secrets", teamEnrollSecretsEndpoint, teamEnrollSecretsRequest{}) + ue.POST("/api/_version_/fleet/users/roles/spec", applyUserRoleSpecsEndpoint, applyUserRoleSpecsRequest{}) + ue.POST("/api/_version_/fleet/translate", translatorEndpoint, translatorRequest{}) + ue.POST("/api/_version_/fleet/spec/teams", applyTeamSpecsEndpoint, applyTeamSpecsRequest{}) + ue.PATCH("/api/_version_/fleet/teams/{team_id:[0-9]+}/secrets", modifyTeamEnrollSecretsEndpoint, modifyTeamEnrollSecretsRequest{}) + ue.POST("/api/_version_/fleet/teams", createTeamEndpoint, createTeamRequest{}) + ue.GET("/api/_version_/fleet/teams", listTeamsEndpoint, listTeamsRequest{}) + ue.GET("/api/_version_/fleet/teams/{id:[0-9]+}", getTeamEndpoint, getTeamRequest{}) + ue.PATCH("/api/_version_/fleet/teams/{id:[0-9]+}", modifyTeamEndpoint, modifyTeamRequest{}) + ue.DELETE("/api/_version_/fleet/teams/{id:[0-9]+}", deleteTeamEndpoint, deleteTeamRequest{}) + ue.POST("/api/_version_/fleet/teams/{id:[0-9]+}/agent_options", modifyTeamAgentOptionsEndpoint, modifyTeamAgentOptionsRequest{}) + ue.GET("/api/_version_/fleet/teams/{id:[0-9]+}/users", listTeamUsersEndpoint, listTeamUsersRequest{}) + ue.PATCH("/api/_version_/fleet/teams/{id:[0-9]+}/users", addTeamUsersEndpoint, modifyTeamUsersRequest{}) + ue.DELETE("/api/_version_/fleet/teams/{id:[0-9]+}/users", deleteTeamUsersEndpoint, modifyTeamUsersRequest{}) + ue.GET("/api/_version_/fleet/teams/{id:[0-9]+}/secrets", teamEnrollSecretsEndpoint, teamEnrollSecretsRequest{}) // Alias /api/_version_/fleet/team/ -> /api/_version_/fleet/teams/ - e.WithAltPaths("/api/_version_/fleet/team/{team_id}/schedule").GET("/api/_version_/fleet/teams/{team_id}/schedule", getTeamScheduleEndpoint, getTeamScheduleRequest{}) - e.WithAltPaths("/api/_version_/fleet/team/{team_id}/schedule").POST("/api/_version_/fleet/teams/{team_id}/schedule", teamScheduleQueryEndpoint, teamScheduleQueryRequest{}) - e.WithAltPaths("/api/_version_/fleet/team/{team_id}/schedule/{scheduled_query_id}").PATCH("/api/_version_/fleet/teams/{team_id}/schedule/{scheduled_query_id}", modifyTeamScheduleEndpoint, modifyTeamScheduleRequest{}) - e.WithAltPaths("/api/_version_/fleet/team/{team_id}/schedule/{scheduled_query_id}").DELETE("/api/_version_/fleet/teams/{team_id}/schedule/{scheduled_query_id}", deleteTeamScheduleEndpoint, deleteTeamScheduleRequest{}) + ue.WithAltPaths("/api/_version_/fleet/team/{team_id}/schedule").GET("/api/_version_/fleet/teams/{team_id}/schedule", getTeamScheduleEndpoint, getTeamScheduleRequest{}) + ue.WithAltPaths("/api/_version_/fleet/team/{team_id}/schedule").POST("/api/_version_/fleet/teams/{team_id}/schedule", teamScheduleQueryEndpoint, teamScheduleQueryRequest{}) + ue.WithAltPaths("/api/_version_/fleet/team/{team_id}/schedule/{scheduled_query_id}").PATCH("/api/_version_/fleet/teams/{team_id}/schedule/{scheduled_query_id}", modifyTeamScheduleEndpoint, modifyTeamScheduleRequest{}) + ue.WithAltPaths("/api/_version_/fleet/team/{team_id}/schedule/{scheduled_query_id}").DELETE("/api/_version_/fleet/teams/{team_id}/schedule/{scheduled_query_id}", deleteTeamScheduleEndpoint, deleteTeamScheduleRequest{}) - e.GET("/api/_version_/fleet/users", listUsersEndpoint, listUsersRequest{}) - e.POST("/api/_version_/fleet/users/admin", createUserEndpoint, createUserRequest{}) - e.GET("/api/_version_/fleet/users/{id:[0-9]+}", getUserEndpoint, getUserRequest{}) - e.PATCH("/api/_version_/fleet/users/{id:[0-9]+}", modifyUserEndpoint, modifyUserRequest{}) - e.DELETE("/api/_version_/fleet/users/{id:[0-9]+}", deleteUserEndpoint, deleteUserRequest{}) - e.POST("/api/_version_/fleet/users/{id:[0-9]+}/require_password_reset", requirePasswordResetEndpoint, requirePasswordResetRequest{}) - e.GET("/api/_version_/fleet/users/{id:[0-9]+}/sessions", getInfoAboutSessionsForUserEndpoint, getInfoAboutSessionsForUserRequest{}) - e.DELETE("/api/_version_/fleet/users/{id:[0-9]+}/sessions", deleteSessionsForUserEndpoint, deleteSessionsForUserRequest{}) - e.POST("/api/_version_/fleet/change_password", changePasswordEndpoint, changePasswordRequest{}) + ue.GET("/api/_version_/fleet/users", listUsersEndpoint, listUsersRequest{}) + ue.POST("/api/_version_/fleet/users/admin", createUserEndpoint, createUserRequest{}) + ue.GET("/api/_version_/fleet/users/{id:[0-9]+}", getUserEndpoint, getUserRequest{}) + ue.PATCH("/api/_version_/fleet/users/{id:[0-9]+}", modifyUserEndpoint, modifyUserRequest{}) + ue.DELETE("/api/_version_/fleet/users/{id:[0-9]+}", deleteUserEndpoint, deleteUserRequest{}) + ue.POST("/api/_version_/fleet/users/{id:[0-9]+}/require_password_reset", requirePasswordResetEndpoint, requirePasswordResetRequest{}) + ue.GET("/api/_version_/fleet/users/{id:[0-9]+}/sessions", getInfoAboutSessionsForUserEndpoint, getInfoAboutSessionsForUserRequest{}) + ue.DELETE("/api/_version_/fleet/users/{id:[0-9]+}/sessions", deleteSessionsForUserEndpoint, deleteSessionsForUserRequest{}) + ue.POST("/api/_version_/fleet/change_password", changePasswordEndpoint, changePasswordRequest{}) - e.GET("/api/_version_/fleet/email/change/{token}", changeEmailEndpoint, changeEmailRequest{}) - e.POST("/api/_version_/fleet/targets", searchTargetsEndpoint, searchTargetsRequest{}) + ue.GET("/api/_version_/fleet/email/change/{token}", changeEmailEndpoint, changeEmailRequest{}) + ue.POST("/api/_version_/fleet/targets", searchTargetsEndpoint, searchTargetsRequest{}) - e.POST("/api/_version_/fleet/invites", createInviteEndpoint, createInviteRequest{}) - e.GET("/api/_version_/fleet/invites", listInvitesEndpoint, listInvitesRequest{}) - e.DELETE("/api/_version_/fleet/invites/{id:[0-9]+}", deleteInviteEndpoint, deleteInviteRequest{}) - e.PATCH("/api/_version_/fleet/invites/{id:[0-9]+}", updateInviteEndpoint, updateInviteRequest{}) + ue.POST("/api/_version_/fleet/invites", createInviteEndpoint, createInviteRequest{}) + ue.GET("/api/_version_/fleet/invites", listInvitesEndpoint, listInvitesRequest{}) + ue.DELETE("/api/_version_/fleet/invites/{id:[0-9]+}", deleteInviteEndpoint, deleteInviteRequest{}) + ue.PATCH("/api/_version_/fleet/invites/{id:[0-9]+}", updateInviteEndpoint, updateInviteRequest{}) - e.POST("/api/_version_/fleet/global/policies", globalPolicyEndpoint, globalPolicyRequest{}) - e.GET("/api/_version_/fleet/global/policies", listGlobalPoliciesEndpoint, nil) - e.GET("/api/_version_/fleet/global/policies/{policy_id}", getPolicyByIDEndpoint, getPolicyByIDRequest{}) - e.POST("/api/_version_/fleet/global/policies/delete", deleteGlobalPoliciesEndpoint, deleteGlobalPoliciesRequest{}) - e.PATCH("/api/_version_/fleet/global/policies/{policy_id}", modifyGlobalPolicyEndpoint, modifyGlobalPolicyRequest{}) + ue.POST("/api/_version_/fleet/global/policies", globalPolicyEndpoint, globalPolicyRequest{}) + ue.GET("/api/_version_/fleet/global/policies", listGlobalPoliciesEndpoint, nil) + ue.GET("/api/_version_/fleet/global/policies/{policy_id}", getPolicyByIDEndpoint, getPolicyByIDRequest{}) + ue.POST("/api/_version_/fleet/global/policies/delete", deleteGlobalPoliciesEndpoint, deleteGlobalPoliciesRequest{}) + ue.PATCH("/api/_version_/fleet/global/policies/{policy_id}", modifyGlobalPolicyEndpoint, modifyGlobalPolicyRequest{}) // Alias /api/_version_/fleet/team/ -> /api/_version_/fleet/teams/ - e.WithAltPaths("/api/_version_/fleet/team/{team_id}/policies").POST("/api/_version_/fleet/teams/{team_id}/policies", teamPolicyEndpoint, teamPolicyRequest{}) - e.WithAltPaths("/api/_version_/fleet/team/{team_id}/policies").GET("/api/_version_/fleet/teams/{team_id}/policies", listTeamPoliciesEndpoint, listTeamPoliciesRequest{}) - e.WithAltPaths("/api/_version_/fleet/team/{team_id}/policies/{policy_id}").GET("/api/_version_/fleet/teams/{team_id}/policies/{policy_id}", getTeamPolicyByIDEndpoint, getTeamPolicyByIDRequest{}) - e.WithAltPaths("/api/_version_/fleet/team/{team_id}/policies/delete").POST("/api/_version_/fleet/teams/{team_id}/policies/delete", deleteTeamPoliciesEndpoint, deleteTeamPoliciesRequest{}) - e.PATCH("/api/_version_/fleet/teams/{team_id}/policies/{policy_id}", modifyTeamPolicyEndpoint, modifyTeamPolicyRequest{}) - e.POST("/api/_version_/fleet/spec/policies", applyPolicySpecsEndpoint, applyPolicySpecsRequest{}) + ue.WithAltPaths("/api/_version_/fleet/team/{team_id}/policies").POST("/api/_version_/fleet/teams/{team_id}/policies", teamPolicyEndpoint, teamPolicyRequest{}) + ue.WithAltPaths("/api/_version_/fleet/team/{team_id}/policies").GET("/api/_version_/fleet/teams/{team_id}/policies", listTeamPoliciesEndpoint, listTeamPoliciesRequest{}) + ue.WithAltPaths("/api/_version_/fleet/team/{team_id}/policies/{policy_id}").GET("/api/_version_/fleet/teams/{team_id}/policies/{policy_id}", getTeamPolicyByIDEndpoint, getTeamPolicyByIDRequest{}) + ue.WithAltPaths("/api/_version_/fleet/team/{team_id}/policies/delete").POST("/api/_version_/fleet/teams/{team_id}/policies/delete", deleteTeamPoliciesEndpoint, deleteTeamPoliciesRequest{}) + ue.PATCH("/api/_version_/fleet/teams/{team_id}/policies/{policy_id}", modifyTeamPolicyEndpoint, modifyTeamPolicyRequest{}) + ue.POST("/api/_version_/fleet/spec/policies", applyPolicySpecsEndpoint, applyPolicySpecsRequest{}) - e.GET("/api/_version_/fleet/queries/{id:[0-9]+}", getQueryEndpoint, getQueryRequest{}) - e.GET("/api/_version_/fleet/queries", listQueriesEndpoint, listQueriesRequest{}) - e.POST("/api/_version_/fleet/queries", createQueryEndpoint, createQueryRequest{}) - e.PATCH("/api/_version_/fleet/queries/{id:[0-9]+}", modifyQueryEndpoint, modifyQueryRequest{}) - e.DELETE("/api/_version_/fleet/queries/{name}", deleteQueryEndpoint, deleteQueryRequest{}) - e.DELETE("/api/_version_/fleet/queries/id/{id:[0-9]+}", deleteQueryByIDEndpoint, deleteQueryByIDRequest{}) - e.POST("/api/_version_/fleet/queries/delete", deleteQueriesEndpoint, deleteQueriesRequest{}) - e.POST("/api/_version_/fleet/spec/queries", applyQuerySpecsEndpoint, applyQuerySpecsRequest{}) - e.GET("/api/_version_/fleet/spec/queries", getQuerySpecsEndpoint, nil) - e.GET("/api/_version_/fleet/spec/queries/{name}", getQuerySpecEndpoint, getGenericSpecRequest{}) + ue.GET("/api/_version_/fleet/queries/{id:[0-9]+}", getQueryEndpoint, getQueryRequest{}) + ue.GET("/api/_version_/fleet/queries", listQueriesEndpoint, listQueriesRequest{}) + ue.POST("/api/_version_/fleet/queries", createQueryEndpoint, createQueryRequest{}) + ue.PATCH("/api/_version_/fleet/queries/{id:[0-9]+}", modifyQueryEndpoint, modifyQueryRequest{}) + ue.DELETE("/api/_version_/fleet/queries/{name}", deleteQueryEndpoint, deleteQueryRequest{}) + ue.DELETE("/api/_version_/fleet/queries/id/{id:[0-9]+}", deleteQueryByIDEndpoint, deleteQueryByIDRequest{}) + ue.POST("/api/_version_/fleet/queries/delete", deleteQueriesEndpoint, deleteQueriesRequest{}) + ue.POST("/api/_version_/fleet/spec/queries", applyQuerySpecsEndpoint, applyQuerySpecsRequest{}) + ue.GET("/api/_version_/fleet/spec/queries", getQuerySpecsEndpoint, nil) + ue.GET("/api/_version_/fleet/spec/queries/{name}", getQuerySpecEndpoint, getGenericSpecRequest{}) - e.GET("/api/_version_/fleet/packs/{id:[0-9]+}/scheduled", getScheduledQueriesInPackEndpoint, getScheduledQueriesInPackRequest{}) - e.POST("/api/_version_/fleet/schedule", scheduleQueryEndpoint, scheduleQueryRequest{}) - e.GET("/api/_version_/fleet/schedule/{id:[0-9]+}", getScheduledQueryEndpoint, getScheduledQueryRequest{}) - e.PATCH("/api/_version_/fleet/schedule/{id:[0-9]+}", modifyScheduledQueryEndpoint, modifyScheduledQueryRequest{}) - e.DELETE("/api/_version_/fleet/schedule/{id:[0-9]+}", deleteScheduledQueryEndpoint, deleteScheduledQueryRequest{}) + ue.GET("/api/_version_/fleet/packs/{id:[0-9]+}/scheduled", getScheduledQueriesInPackEndpoint, getScheduledQueriesInPackRequest{}) + ue.POST("/api/_version_/fleet/schedule", scheduleQueryEndpoint, scheduleQueryRequest{}) + ue.GET("/api/_version_/fleet/schedule/{id:[0-9]+}", getScheduledQueryEndpoint, getScheduledQueryRequest{}) + ue.PATCH("/api/_version_/fleet/schedule/{id:[0-9]+}", modifyScheduledQueryEndpoint, modifyScheduledQueryRequest{}) + ue.DELETE("/api/_version_/fleet/schedule/{id:[0-9]+}", deleteScheduledQueryEndpoint, deleteScheduledQueryRequest{}) - e.GET("/api/_version_/fleet/packs/{id:[0-9]+}", getPackEndpoint, getPackRequest{}) - e.POST("/api/_version_/fleet/packs", createPackEndpoint, createPackRequest{}) - e.PATCH("/api/_version_/fleet/packs/{id:[0-9]+}", modifyPackEndpoint, modifyPackRequest{}) - e.GET("/api/_version_/fleet/packs", listPacksEndpoint, listPacksRequest{}) - e.DELETE("/api/_version_/fleet/packs/{name}", deletePackEndpoint, deletePackRequest{}) - e.DELETE("/api/_version_/fleet/packs/id/{id:[0-9]+}", deletePackByIDEndpoint, deletePackByIDRequest{}) - e.POST("/api/_version_/fleet/spec/packs", applyPackSpecsEndpoint, applyPackSpecsRequest{}) - e.GET("/api/_version_/fleet/spec/packs", getPackSpecsEndpoint, nil) - e.GET("/api/_version_/fleet/spec/packs/{name}", getPackSpecEndpoint, getGenericSpecRequest{}) + ue.GET("/api/_version_/fleet/packs/{id:[0-9]+}", getPackEndpoint, getPackRequest{}) + ue.POST("/api/_version_/fleet/packs", createPackEndpoint, createPackRequest{}) + ue.PATCH("/api/_version_/fleet/packs/{id:[0-9]+}", modifyPackEndpoint, modifyPackRequest{}) + ue.GET("/api/_version_/fleet/packs", listPacksEndpoint, listPacksRequest{}) + ue.DELETE("/api/_version_/fleet/packs/{name}", deletePackEndpoint, deletePackRequest{}) + ue.DELETE("/api/_version_/fleet/packs/id/{id:[0-9]+}", deletePackByIDEndpoint, deletePackByIDRequest{}) + ue.POST("/api/_version_/fleet/spec/packs", applyPackSpecsEndpoint, applyPackSpecsRequest{}) + ue.GET("/api/_version_/fleet/spec/packs", getPackSpecsEndpoint, nil) + ue.GET("/api/_version_/fleet/spec/packs/{name}", getPackSpecEndpoint, getGenericSpecRequest{}) - e.GET("/api/_version_/fleet/software", listSoftwareEndpoint, listSoftwareRequest{}) - e.GET("/api/_version_/fleet/software/count", countSoftwareEndpoint, countSoftwareRequest{}) + ue.GET("/api/_version_/fleet/software", listSoftwareEndpoint, listSoftwareRequest{}) + ue.GET("/api/_version_/fleet/software/count", countSoftwareEndpoint, countSoftwareRequest{}) - e.GET("/api/_version_/fleet/host_summary", getHostSummaryEndpoint, getHostSummaryRequest{}) - e.GET("/api/_version_/fleet/hosts", listHostsEndpoint, listHostsRequest{}) - e.POST("/api/_version_/fleet/hosts/delete", deleteHostsEndpoint, deleteHostsRequest{}) - e.GET("/api/_version_/fleet/hosts/{id:[0-9]+}", getHostEndpoint, getHostRequest{}) - e.GET("/api/_version_/fleet/hosts/count", countHostsEndpoint, countHostsRequest{}) - e.GET("/api/_version_/fleet/hosts/identifier/{identifier}", hostByIdentifierEndpoint, hostByIdentifierRequest{}) - e.DELETE("/api/_version_/fleet/hosts/{id:[0-9]+}", deleteHostEndpoint, deleteHostRequest{}) - e.POST("/api/_version_/fleet/hosts/transfer", addHostsToTeamEndpoint, addHostsToTeamRequest{}) - e.POST("/api/_version_/fleet/hosts/transfer/filter", addHostsToTeamByFilterEndpoint, addHostsToTeamByFilterRequest{}) - e.POST("/api/_version_/fleet/hosts/{id:[0-9]+}/refetch", refetchHostEndpoint, refetchHostRequest{}) - e.GET("/api/_version_/fleet/hosts/{id:[0-9]+}/device_mapping", listHostDeviceMappingEndpoint, listHostDeviceMappingRequest{}) + ue.GET("/api/_version_/fleet/host_summary", getHostSummaryEndpoint, getHostSummaryRequest{}) + ue.GET("/api/_version_/fleet/hosts", listHostsEndpoint, listHostsRequest{}) + ue.POST("/api/_version_/fleet/hosts/delete", deleteHostsEndpoint, deleteHostsRequest{}) + ue.GET("/api/_version_/fleet/hosts/{id:[0-9]+}", getHostEndpoint, getHostRequest{}) + ue.GET("/api/_version_/fleet/hosts/count", countHostsEndpoint, countHostsRequest{}) + ue.GET("/api/_version_/fleet/hosts/identifier/{identifier}", hostByIdentifierEndpoint, hostByIdentifierRequest{}) + ue.DELETE("/api/_version_/fleet/hosts/{id:[0-9]+}", deleteHostEndpoint, deleteHostRequest{}) + ue.POST("/api/_version_/fleet/hosts/transfer", addHostsToTeamEndpoint, addHostsToTeamRequest{}) + ue.POST("/api/_version_/fleet/hosts/transfer/filter", addHostsToTeamByFilterEndpoint, addHostsToTeamByFilterRequest{}) + ue.POST("/api/_version_/fleet/hosts/{id:[0-9]+}/refetch", refetchHostEndpoint, refetchHostRequest{}) + ue.GET("/api/_version_/fleet/hosts/{id:[0-9]+}/device_mapping", listHostDeviceMappingEndpoint, listHostDeviceMappingRequest{}) - e.POST("/api/_version_/fleet/labels", createLabelEndpoint, createLabelRequest{}) - e.PATCH("/api/_version_/fleet/labels/{id:[0-9]+}", modifyLabelEndpoint, modifyLabelRequest{}) - e.GET("/api/_version_/fleet/labels/{id:[0-9]+}", getLabelEndpoint, getLabelRequest{}) - e.GET("/api/_version_/fleet/labels", listLabelsEndpoint, listLabelsRequest{}) - e.GET("/api/_version_/fleet/labels/{id:[0-9]+}/hosts", listHostsInLabelEndpoint, listHostsInLabelRequest{}) - e.DELETE("/api/_version_/fleet/labels/{name}", deleteLabelEndpoint, deleteLabelRequest{}) - e.DELETE("/api/_version_/fleet/labels/id/{id:[0-9]+}", deleteLabelByIDEndpoint, deleteLabelByIDRequest{}) - e.POST("/api/_version_/fleet/spec/labels", applyLabelSpecsEndpoint, applyLabelSpecsRequest{}) - e.GET("/api/_version_/fleet/spec/labels", getLabelSpecsEndpoint, nil) - e.GET("/api/_version_/fleet/spec/labels/{name}", getLabelSpecEndpoint, getGenericSpecRequest{}) + ue.POST("/api/_version_/fleet/labels", createLabelEndpoint, createLabelRequest{}) + ue.PATCH("/api/_version_/fleet/labels/{id:[0-9]+}", modifyLabelEndpoint, modifyLabelRequest{}) + ue.GET("/api/_version_/fleet/labels/{id:[0-9]+}", getLabelEndpoint, getLabelRequest{}) + ue.GET("/api/_version_/fleet/labels", listLabelsEndpoint, listLabelsRequest{}) + ue.GET("/api/_version_/fleet/labels/{id:[0-9]+}/hosts", listHostsInLabelEndpoint, listHostsInLabelRequest{}) + ue.DELETE("/api/_version_/fleet/labels/{name}", deleteLabelEndpoint, deleteLabelRequest{}) + ue.DELETE("/api/_version_/fleet/labels/id/{id:[0-9]+}", deleteLabelByIDEndpoint, deleteLabelByIDRequest{}) + ue.POST("/api/_version_/fleet/spec/labels", applyLabelSpecsEndpoint, applyLabelSpecsRequest{}) + ue.GET("/api/_version_/fleet/spec/labels", getLabelSpecsEndpoint, nil) + ue.GET("/api/_version_/fleet/spec/labels/{name}", getLabelSpecEndpoint, getGenericSpecRequest{}) - e.GET("/api/_version_/fleet/queries/run", runLiveQueryEndpoint, runLiveQueryRequest{}) - e.POST("/api/_version_/fleet/queries/run", createDistributedQueryCampaignEndpoint, createDistributedQueryCampaignRequest{}) - e.POST("/api/_version_/fleet/queries/run_by_names", createDistributedQueryCampaignByNamesEndpoint, createDistributedQueryCampaignByNamesRequest{}) + ue.GET("/api/_version_/fleet/queries/run", runLiveQueryEndpoint, runLiveQueryRequest{}) + ue.POST("/api/_version_/fleet/queries/run", createDistributedQueryCampaignEndpoint, createDistributedQueryCampaignRequest{}) + ue.POST("/api/_version_/fleet/queries/run_by_names", createDistributedQueryCampaignByNamesEndpoint, createDistributedQueryCampaignByNamesRequest{}) - e.GET("/api/_version_/fleet/activities", listActivitiesEndpoint, listActivitiesRequest{}) + ue.GET("/api/_version_/fleet/activities", listActivitiesEndpoint, listActivitiesRequest{}) - e.GET("/api/_version_/fleet/global/schedule", getGlobalScheduleEndpoint, getGlobalScheduleRequest{}) - e.POST("/api/_version_/fleet/global/schedule", globalScheduleQueryEndpoint, globalScheduleQueryRequest{}) - e.PATCH("/api/_version_/fleet/global/schedule/{id:[0-9]+}", modifyGlobalScheduleEndpoint, modifyGlobalScheduleRequest{}) - e.DELETE("/api/_version_/fleet/global/schedule/{id:[0-9]+}", deleteGlobalScheduleEndpoint, deleteGlobalScheduleRequest{}) + ue.GET("/api/_version_/fleet/global/schedule", getGlobalScheduleEndpoint, getGlobalScheduleRequest{}) + ue.POST("/api/_version_/fleet/global/schedule", globalScheduleQueryEndpoint, globalScheduleQueryRequest{}) + ue.PATCH("/api/_version_/fleet/global/schedule/{id:[0-9]+}", modifyGlobalScheduleEndpoint, modifyGlobalScheduleRequest{}) + ue.DELETE("/api/_version_/fleet/global/schedule/{id:[0-9]+}", deleteGlobalScheduleEndpoint, deleteGlobalScheduleRequest{}) - e.GET("/api/_version_/fleet/carves", listCarvesEndpoint, listCarvesRequest{}) - e.GET("/api/_version_/fleet/carves/{id:[0-9]+}", getCarveEndpoint, getCarveRequest{}) - e.GET("/api/_version_/fleet/carves/{id:[0-9]+}/block/{block_id}", getCarveBlockEndpoint, getCarveBlockRequest{}) + ue.GET("/api/_version_/fleet/carves", listCarvesEndpoint, listCarvesRequest{}) + ue.GET("/api/_version_/fleet/carves/{id:[0-9]+}", getCarveEndpoint, getCarveRequest{}) + ue.GET("/api/_version_/fleet/carves/{id:[0-9]+}/block/{block_id}", getCarveBlockEndpoint, getCarveBlockRequest{}) - e.GET("/api/_version_/fleet/hosts/{id:[0-9]+}/macadmins", getMacadminsDataEndpoint, getMacadminsDataRequest{}) - e.GET("/api/_version_/fleet/macadmins", getAggregatedMacadminsDataEndpoint, getAggregatedMacadminsDataRequest{}) + ue.GET("/api/_version_/fleet/hosts/{id:[0-9]+}/macadmins", getMacadminsDataEndpoint, getMacadminsDataRequest{}) + ue.GET("/api/_version_/fleet/macadmins", getAggregatedMacadminsDataEndpoint, getAggregatedMacadminsDataRequest{}) - e.GET("/api/_version_/fleet/status/result_store", statusResultStoreEndpoint, nil) - e.GET("/api/_version_/fleet/status/live_query", statusLiveQueryEndpoint, nil) + ue.GET("/api/_version_/fleet/status/result_store", statusResultStoreEndpoint, nil) + ue.GET("/api/_version_/fleet/status/live_query", statusLiveQueryEndpoint, nil) + + // host-authenticated endpoints + he := newHostAuthenticatedEndpointer(svc, logger, opts, r, "v1") + he.POST("/api/_version_/osquery/config", getClientConfigEndpoint, getClientConfigRequest{}) + he.POST("/api/_version_/osquery/distributed/read", getDistributedQueriesEndpoint, getDistributedQueriesRequest{}) + he.POST("/api/_version_/osquery/distributed/write", submitDistributedQueryResultsEndpoint, submitDistributedQueryResultsRequestShim{}) + he.POST("/api/_version_/osquery/carve/begin", carveBeginEndpoint, carveBeginRequest{}) + he.POST("/api/_version_/osquery/log", submitLogsEndpoint, submitLogsRequest{}) } // TODO: this duplicates the one in makeKitHandler diff --git a/server/service/handler_test.go b/server/service/handler_test.go index 3faed0c01f..045a6a9619 100644 --- a/server/service/handler_test.go +++ b/server/service/handler_test.go @@ -58,22 +58,6 @@ func TestAPIRoutes(t *testing.T) { verb: "POST", uri: "/api/v1/osquery/enroll", }, - { - verb: "POST", - uri: "/api/v1/osquery/config", - }, - { - verb: "POST", - uri: "/api/v1/osquery/distributed/read", - }, - { - verb: "POST", - uri: "/api/v1/osquery/distributed/write", - }, - { - verb: "POST", - uri: "/api/v1/osquery/log", - }, } for _, route := range routes { diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index 693c1249a3..34102912aa 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -3079,6 +3079,21 @@ func (s *integrationTestSuite) TestStatus() { s.DoJSON("GET", "/api/v1/fleet/status/live_query", nil, http.StatusOK, &statusResp) } +func (s *integrationTestSuite) TestOsqueryConfig() { + t := s.T() + + hosts := s.createHosts(t) + req := getClientConfigRequest{NodeKey: hosts[0].NodeKey} + var resp getClientConfigResponse + s.DoJSON("POST", "/api/v1/osquery/config", req, http.StatusOK, &resp) + + // test with invalid node key + var errRes map[string]interface{} + req.NodeKey += "zzzz" + s.DoJSON("POST", "/api/v1/osquery/config", req, http.StatusUnauthorized, &errRes) + assert.Contains(t, errRes["error"], "invalid node key") +} + // creates a session and returns it, its key is to be passed as authorization header. func createSession(t *testing.T, uid uint, ds fleet.Datastore) *fleet.Session { key := make([]byte, 64) diff --git a/server/service/integration_live_queries_test.go b/server/service/integration_live_queries_test.go index 9878838a97..b263a6bd15 100644 --- a/server/service/integration_live_queries_test.go +++ b/server/service/integration_live_queries_test.go @@ -2,6 +2,7 @@ package service import ( "context" + "encoding/json" "fmt" "net/http" "os" @@ -65,6 +66,11 @@ func (s *liveQueriesTestSuite) SetupSuite() { } } +func (s *liveQueriesTestSuite) TearDownTest() { + // reset the mock + s.lq.Mock = mock.Mock{} +} + func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { t := s.T() @@ -96,13 +102,16 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { cid := getCIDForQ(s, q1) - distributedReq := SubmitDistributedQueryResultsRequest{ + distributedReq := submitDistributedQueryResultsRequestShim{ NodeKey: host.NodeKey, - Results: map[string][]map[string]string{ - hostDistributedQueryPrefix + cid: {{"col1": "a", "col2": "b"}}, + Results: map[string]json.RawMessage{ + hostDistributedQueryPrefix + cid: json.RawMessage(`[{"col1": "a", "col2": "b"}]`), + hostDistributedQueryPrefix + "invalidcid": json.RawMessage(`""`), // empty string is sometimes sent for no results + hostDistributedQueryPrefix + "9999": json.RawMessage(`""`), }, - Statuses: map[string]fleet.OsqueryStatus{ - hostDistributedQueryPrefix + cid: 0, + Statuses: map[string]interface{}{ + hostDistributedQueryPrefix + cid: 0, + hostDistributedQueryPrefix + "9999": "0", }, Messages: map[string]string{ hostDistributedQueryPrefix + cid: "some msg", @@ -353,13 +362,13 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsOnSomeHost() { // Give the above call a couple of seconds to create the campaign time.Sleep(2 * time.Second) cid1 := getCIDForQ(s, q1) - distributedReq := SubmitDistributedQueryResultsRequest{ + distributedReq := submitDistributedQueryResultsRequestShim{ NodeKey: h1.NodeKey, - Results: map[string][]map[string]string{ - hostDistributedQueryPrefix + cid1: {{"col1": "a", "col2": "b"}}, + Results: map[string]json.RawMessage{ + hostDistributedQueryPrefix + cid1: json.RawMessage(`[{"col1": "a", "col2": "b"}]`), }, - Statuses: map[string]fleet.OsqueryStatus{ - hostDistributedQueryPrefix + cid1: 0, + Statuses: map[string]interface{}{ + hostDistributedQueryPrefix + cid1: "0", }, Messages: map[string]string{ hostDistributedQueryPrefix + cid1: "some msg", @@ -368,12 +377,12 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsOnSomeHost() { distributedResp := submitDistributedQueryResultsResponse{} s.DoJSON("POST", "/api/v1/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp) - distributedReq = SubmitDistributedQueryResultsRequest{ + distributedReq = submitDistributedQueryResultsRequestShim{ NodeKey: h2.NodeKey, - Results: map[string][]map[string]string{ - hostDistributedQueryPrefix + cid1: {}, + Results: map[string]json.RawMessage{ + hostDistributedQueryPrefix + cid1: json.RawMessage(`""`), }, - Statuses: map[string]fleet.OsqueryStatus{ + Statuses: map[string]interface{}{ hostDistributedQueryPrefix + cid1: 123, }, Messages: map[string]string{ @@ -449,3 +458,21 @@ func (s *liveQueriesTestSuite) TestCreateDistributedQueryCampaign() { QuerySQL: "SELECT 3", Selected: distributedQueryCampaignTargetsByNames{Hosts: []string{h1.Hostname + "ZZZZZ"}}}, http.StatusOK, &createResp) } + +func (s *liveQueriesTestSuite) TestOsqueryDistributedRead() { + t := s.T() + + hostID := s.hosts[1].ID + s.lq.On("QueriesForHost", hostID).Return(map[string]string{fmt.Sprintf("%d", hostID): "select 1 from osquery;"}, nil) + + req := getDistributedQueriesRequest{NodeKey: s.hosts[1].NodeKey} + var resp getDistributedQueriesResponse + s.DoJSON("POST", "/api/v1/osquery/distributed/read", req, http.StatusOK, &resp) + assert.Contains(t, resp.Queries, hostDistributedQueryPrefix+fmt.Sprintf("%d", hostID)) + + // test with invalid node key + var errRes map[string]interface{} + req.NodeKey += "zzzz" + s.DoJSON("POST", "/api/v1/osquery/distributed/read", req, http.StatusUnauthorized, &errRes) + assert.Contains(t, errRes["error"], "invalid node key") +} diff --git a/server/service/integration_logger_test.go b/server/service/integration_logger_test.go index 948c5c9b94..63486ffd71 100644 --- a/server/service/integration_logger_test.go +++ b/server/service/integration_logger_test.go @@ -2,8 +2,10 @@ package service import ( "bytes" + "compress/gzip" "context" "encoding/json" + "fmt" "io" "net/http" "strings" @@ -121,15 +123,14 @@ func (s *integrationLoggerTestSuite) TestOsqueryEndpointsLogErrors() { require.Nil(t, err) logString := s.buf.String() - assert.Contains(t, logString, `{"err":"decoding JSON:`) assert.Contains(t, logString, `invalid character '}' looking for beginning of value","level":"info","path":"/api/v1/osquery/log"} `, logString) } -func (s *integrationLoggerTestSuite) TestSubmitStatusLog() { +func (s *integrationLoggerTestSuite) TestSubmitLog() { t := s.T() - _, err := s.ds.NewHost(context.Background(), &fleet.Host{ + h, err := s.ds.NewHost(context.Background(), &fleet.Host{ DetailUpdatedAt: time.Now(), LabelUpdatedAt: time.Now(), PolicyUpdatedAt: time.Now(), @@ -143,8 +144,9 @@ func (s *integrationLoggerTestSuite) TestSubmitStatusLog() { }) require.NoError(t, err) + // submit status logs req := submitLogsRequest{ - NodeKey: "1234", + NodeKey: h.NodeKey, LogType: "status", Data: nil, } @@ -152,8 +154,53 @@ func (s *integrationLoggerTestSuite) TestSubmitStatusLog() { s.DoJSON("POST", "/api/v1/osquery/log", req, http.StatusOK, &res) logString := s.buf.String() - assert.Equal(t, 1, strings.Count(logString, "\"ip_addr\"")) + assert.Equal(t, 1, strings.Count(logString, `"ip_addr"`)) assert.Equal(t, 1, strings.Count(logString, "x_for_ip_addr")) + s.buf.Reset() + + // submit results logs + req = submitLogsRequest{ + NodeKey: h.NodeKey, + LogType: "result", + Data: nil, + } + res = submitLogsResponse{} + s.DoJSON("POST", "/api/v1/osquery/log", req, http.StatusOK, &res) + + logString = s.buf.String() + assert.Equal(t, 1, strings.Count(logString, `"ip_addr"`)) + assert.Equal(t, 1, strings.Count(logString, "x_for_ip_addr")) + s.buf.Reset() + + // submit invalid type logs + req = submitLogsRequest{ + NodeKey: h.NodeKey, + LogType: "unknown", + Data: nil, + } + var errRes map[string]string + s.DoJSON("POST", "/api/v1/osquery/log", req, http.StatusInternalServerError, &errRes) + assert.Contains(t, errRes["error"], "unknown log type") + s.buf.Reset() + + // submit gzip-encoded request + var body bytes.Buffer + gw := gzip.NewWriter(&body) + _, err = fmt.Fprintf(gw, `{ + "node_key": %q, + "log_type": "status", + "data": null + }`, h.NodeKey) + require.NoError(t, err) + require.NoError(t, gw.Close()) + + s.DoRawWithHeaders("POST", "/api/v1/osquery/log", body.Bytes(), http.StatusOK, map[string]string{"Content-Encoding": "gzip"}) + logString = s.buf.String() + assert.Equal(t, 1, strings.Count(logString, `"ip_addr"`)) + assert.Equal(t, 1, strings.Count(logString, "x_for_ip_addr")) + + // submit same payload without specifying gzip encoding fails + s.DoRawWithHeaders("POST", "/api/v1/osquery/log", body.Bytes(), http.StatusInternalServerError, nil) } func (s *integrationLoggerTestSuite) TestEnrollAgentLogsErrors() { diff --git a/server/service/osquery.go b/server/service/osquery.go new file mode 100644 index 0000000000..22db162ee5 --- /dev/null +++ b/server/service/osquery.go @@ -0,0 +1,869 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "regexp" + "strconv" + "strings" + "time" + + "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" + hostctx "github.com/fleetdm/fleet/v4/server/contexts/host" + "github.com/fleetdm/fleet/v4/server/contexts/logging" + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/ptr" + "github.com/fleetdm/fleet/v4/server/pubsub" + "github.com/fleetdm/fleet/v4/server/service/osquery_utils" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/log/level" + "github.com/spf13/cast" +) + +//////////////////////////////////////////////////////////////////////////////// +// Get Client Config +//////////////////////////////////////////////////////////////////////////////// + +type getClientConfigRequest struct { + NodeKey string `json:"node_key"` +} + +type getClientConfigResponse struct { + Config map[string]interface{} + Err error `json:"error,omitempty"` +} + +func (r getClientConfigResponse) error() error { return r.Err } + +func getClientConfigEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + config, err := svc.GetClientConfig(ctx) + if err != nil { + return getClientConfigResponse{Err: err}, nil + } + + // We return the config here explicitly because osquery exepects the + // response for configs to be at the top-level of the JSON response + return config, nil +} + +func (svc *Service) GetClientConfig(ctx context.Context) (map[string]interface{}, error) { + // skipauth: Authorization is currently for user endpoints only. + svc.authz.SkipAuthorization(ctx) + + host, ok := hostctx.FromContext(ctx) + if !ok { + return nil, osqueryError{message: "internal error: missing host from request context"} + } + + baseConfig, err := svc.AgentOptionsForHost(ctx, host.TeamID, host.Platform) + if err != nil { + return nil, osqueryError{message: "internal error: fetch base config: " + err.Error()} + } + + config := make(map[string]interface{}) + if baseConfig != nil { + err = json.Unmarshal(baseConfig, &config) + if err != nil { + return nil, osqueryError{message: "internal error: parse base configuration: " + err.Error()} + } + } + + packs, err := svc.ds.ListPacksForHost(ctx, host.ID) + if err != nil { + return nil, osqueryError{message: "database error: " + err.Error()} + } + + packConfig := fleet.Packs{} + for _, pack := range packs { + // first, we must figure out what queries are in this pack + queries, err := svc.ds.ListScheduledQueriesInPack(ctx, pack.ID) + if err != nil { + return nil, osqueryError{message: "database error: " + err.Error()} + } + + // the serializable osquery config struct expects content in a + // particular format, so we do the conversion here + configQueries := fleet.Queries{} + for _, query := range queries { + queryContent := fleet.QueryContent{ + Query: query.Query, + Interval: query.Interval, + Platform: query.Platform, + Version: query.Version, + Removed: query.Removed, + Shard: query.Shard, + Denylist: query.Denylist, + } + + if query.Removed != nil { + queryContent.Removed = query.Removed + } + + if query.Snapshot != nil && *query.Snapshot { + queryContent.Snapshot = query.Snapshot + } + + configQueries[query.Name] = queryContent + } + + // finally, we add the pack to the client config struct with all of + // the pack's queries + packConfig[pack.Name] = fleet.PackContent{ + Platform: pack.Platform, + Queries: configQueries, + } + } + + if len(packConfig) > 0 { + packJSON, err := json.Marshal(packConfig) + if err != nil { + return nil, osqueryError{message: "internal error: marshal pack JSON: " + err.Error()} + } + config["packs"] = json.RawMessage(packJSON) + } + + // Save interval values if they have been updated. + intervalsModified := false + intervals := fleet.HostOsqueryIntervals{ + DistributedInterval: host.DistributedInterval, + ConfigTLSRefresh: host.ConfigTLSRefresh, + LoggerTLSPeriod: host.LoggerTLSPeriod, + } + if options, ok := config["options"].(map[string]interface{}); ok { + distributedIntervalVal, ok := options["distributed_interval"] + distributedInterval, err := cast.ToUintE(distributedIntervalVal) + if ok && err == nil && intervals.DistributedInterval != distributedInterval { + intervals.DistributedInterval = distributedInterval + intervalsModified = true + } + + loggerTLSPeriodVal, ok := options["logger_tls_period"] + loggerTLSPeriod, err := cast.ToUintE(loggerTLSPeriodVal) + if ok && err == nil && intervals.LoggerTLSPeriod != loggerTLSPeriod { + intervals.LoggerTLSPeriod = loggerTLSPeriod + intervalsModified = true + } + + // Note config_tls_refresh can only be set in the osquery flags (and has + // also been deprecated in osquery for quite some time) so is ignored + // here. + configRefreshVal, ok := options["config_refresh"] + configRefresh, err := cast.ToUintE(configRefreshVal) + if ok && err == nil && intervals.ConfigTLSRefresh != configRefresh { + intervals.ConfigTLSRefresh = configRefresh + intervalsModified = true + } + } + + // We are not doing deferred update host like in other places because the intervals + // are not modified often. + if intervalsModified { + if err := svc.ds.UpdateHostOsqueryIntervals(ctx, host.ID, intervals); err != nil { + return nil, osqueryError{message: "internal error: update host intervals: " + err.Error()} + } + } + + return config, nil +} + +// AgentOptionsForHost gets the agent options for the provided host. +// The host information should be used for filtering based on team, platform, etc. +func (svc *Service) AgentOptionsForHost(ctx context.Context, hostTeamID *uint, hostPlatform string) (json.RawMessage, error) { + // Team agent options have priority over global options. + if hostTeamID != nil { + teamAgentOptions, err := svc.ds.TeamAgentOptions(ctx, *hostTeamID) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "load team agent options for host") + } + + if teamAgentOptions != nil && len(*teamAgentOptions) > 0 { + var options fleet.AgentOptions + if err := json.Unmarshal(*teamAgentOptions, &options); err != nil { + return nil, ctxerr.Wrap(ctx, err, "unmarshal team agent options") + } + return options.ForPlatform(hostPlatform), nil + } + } + // Otherwise return the appropriate override for global options. + appConfig, err := svc.ds.AppConfig(ctx) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "load global agent options") + } + var options fleet.AgentOptions + if appConfig.AgentOptions != nil { + if err := json.Unmarshal(*appConfig.AgentOptions, &options); err != nil { + return nil, ctxerr.Wrap(ctx, err, "unmarshal global agent options") + } + } + return options.ForPlatform(hostPlatform), nil +} + +//////////////////////////////////////////////////////////////////////////////// +// Get Distributed Queries +//////////////////////////////////////////////////////////////////////////////// + +type getDistributedQueriesRequest struct { + NodeKey string `json:"node_key"` +} + +type getDistributedQueriesResponse struct { + Queries map[string]string `json:"queries"` + Accelerate uint `json:"accelerate,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r getDistributedQueriesResponse) error() error { return r.Err } + +func getDistributedQueriesEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + queries, accelerate, err := svc.GetDistributedQueries(ctx) + if err != nil { + return getDistributedQueriesResponse{Err: err}, nil + } + return getDistributedQueriesResponse{Queries: queries, Accelerate: accelerate}, nil +} + +func (svc *Service) GetDistributedQueries(ctx context.Context) (map[string]string, uint, error) { + // skipauth: Authorization is currently for user endpoints only. + svc.authz.SkipAuthorization(ctx) + + host, ok := hostctx.FromContext(ctx) + if !ok { + return nil, 0, osqueryError{message: "internal error: missing host from request context"} + } + + queries := make(map[string]string) + + detailQueries, err := svc.detailQueriesForHost(ctx, host) + if err != nil { + return nil, 0, osqueryError{message: err.Error()} + } + for name, query := range detailQueries { + queries[name] = query + } + + labelQueries, err := svc.labelQueriesForHost(ctx, host) + if err != nil { + return nil, 0, osqueryError{message: err.Error()} + } + for name, query := range labelQueries { + queries[hostLabelQueryPrefix+name] = query + } + + if liveQueries, err := svc.liveQueryStore.QueriesForHost(host.ID); err != nil { + // If the live query store fails to fetch queries we still want the hosts + // to receive all the other queries (details, policies, labels, etc.), + // thus we just log the error. + level.Error(svc.logger).Log("op", "QueriesForHost", "err", err) + } else { + for name, query := range liveQueries { + queries[hostDistributedQueryPrefix+name] = query + } + } + + policyQueries, err := svc.policyQueriesForHost(ctx, host) + if err != nil { + return nil, 0, osqueryError{message: err.Error()} + } + for name, query := range policyQueries { + queries[hostPolicyQueryPrefix+name] = query + } + + accelerate := uint(0) + if host.Hostname == "" || host.Platform == "" { + // Assume this host is just enrolling, and accelerate checkins + // (to allow for platform restricted labels to run quickly + // after platform is retrieved from details) + accelerate = 10 + } + + return queries, accelerate, nil +} + +// detailQueriesForHost returns the map of detail+additional queries that should be executed by +// osqueryd to fill in the host details. +func (svc *Service) detailQueriesForHost(ctx context.Context, host *fleet.Host) (map[string]string, error) { + if !svc.shouldUpdate(host.DetailUpdatedAt, svc.config.Osquery.DetailUpdateInterval, host.ID) && !host.RefetchRequested { + return nil, nil + } + + config, err := svc.ds.AppConfig(ctx) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "read app config") + } + + queries := make(map[string]string) + detailQueries := osquery_utils.GetDetailQueries(config, svc.config) + for name, query := range detailQueries { + if query.RunsForPlatform(host.Platform) { + queries[hostDetailQueryPrefix+name] = query.Query + } + } + + if config.HostSettings.AdditionalQueries == nil { + // No additional queries set + return queries, nil + } + + var additionalQueries map[string]string + if err := json.Unmarshal(*config.HostSettings.AdditionalQueries, &additionalQueries); err != nil { + return nil, ctxerr.Wrap(ctx, err, "unmarshal additional queries") + } + + for name, query := range additionalQueries { + queries[hostAdditionalQueryPrefix+name] = query + } + + return queries, nil +} + +func (svc *Service) shouldUpdate(lastUpdated time.Time, interval time.Duration, hostID uint) bool { + svc.jitterMu.Lock() + defer svc.jitterMu.Unlock() + + if svc.jitterH[interval] == nil { + svc.jitterH[interval] = newJitterHashTable(int(int64(svc.config.Osquery.MaxJitterPercent) * int64(interval.Minutes()) / 100.0)) + level.Debug(svc.logger).Log("jitter", "created", "bucketCount", svc.jitterH[interval].bucketCount) + } + + jitter := svc.jitterH[interval].jitterForHost(hostID) + cutoff := svc.clock.Now().Add(-(interval + jitter)) + return lastUpdated.Before(cutoff) +} + +func (svc *Service) labelQueriesForHost(ctx context.Context, host *fleet.Host) (map[string]string, error) { + labelReportedAt := svc.task.GetHostLabelReportedAt(ctx, host) + if !svc.shouldUpdate(labelReportedAt, svc.config.Osquery.LabelUpdateInterval, host.ID) && !host.RefetchRequested { + return nil, nil + } + labelQueries, err := svc.ds.LabelQueriesForHost(ctx, host) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "retrieve label queries") + } + return labelQueries, nil +} + +func (svc *Service) policyQueriesForHost(ctx context.Context, host *fleet.Host) (map[string]string, error) { + policyReportedAt := svc.task.GetHostPolicyReportedAt(ctx, host) + if !svc.shouldUpdate(policyReportedAt, svc.config.Osquery.PolicyUpdateInterval, host.ID) && !host.RefetchRequested { + return nil, nil + } + policyQueries, err := svc.ds.PolicyQueriesForHost(ctx, host) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "retrieve policy queries") + } + return policyQueries, nil +} + +//////////////////////////////////////////////////////////////////////////////// +// Write Distributed Query Results +//////////////////////////////////////////////////////////////////////////////// + +// When a distributed query has no results, the JSON schema is +// inconsistent, so we use this shim and massage into a consistent +// schema. For example (simplified from actual osqueryd 1.8.2 output): +// { +// "queries": { +// "query_with_no_results": "", // <- Note string instead of array +// "query_with_results": [{"foo":"bar","baz":"bang"}] +// }, +// "node_key":"IGXCXknWQ1baTa8TZ6rF3kAPZ4\/aTsui" +// } +type submitDistributedQueryResultsRequestShim struct { + NodeKey string `json:"node_key"` + Results map[string]json.RawMessage `json:"queries"` + Statuses map[string]interface{} `json:"statuses"` + Messages map[string]string `json:"messages"` +} + +func (shim *submitDistributedQueryResultsRequestShim) toRequest(ctx context.Context) (*SubmitDistributedQueryResultsRequest, error) { + results := fleet.OsqueryDistributedQueryResults{} + for query, raw := range shim.Results { + queryResults := []map[string]string{} + // No need to handle error because the empty array is what we + // want if there was an error parsing the JSON (the error + // indicates that osquery sent us incosistently schemaed JSON) + _ = json.Unmarshal(raw, &queryResults) + results[query] = queryResults + } + + // Statuses were represented by strings in osquery < 3.0 and now + // integers in osquery > 3.0. Massage to string for compatibility with + // the service definition. + statuses := map[string]fleet.OsqueryStatus{} + for query, status := range shim.Statuses { + switch s := status.(type) { + case string: + sint, err := strconv.Atoi(s) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "parse status to int") + } + statuses[query] = fleet.OsqueryStatus(sint) + case float64: + statuses[query] = fleet.OsqueryStatus(s) + default: + return nil, ctxerr.Errorf(ctx, "query status should be string or number, got %T", s) + } + } + + return &SubmitDistributedQueryResultsRequest{ + NodeKey: shim.NodeKey, + Results: results, + Statuses: statuses, + Messages: shim.Messages, + }, nil +} + +type SubmitDistributedQueryResultsRequest struct { + NodeKey string `json:"node_key"` + Results fleet.OsqueryDistributedQueryResults `json:"queries"` + Statuses map[string]fleet.OsqueryStatus `json:"statuses"` + Messages map[string]string `json:"messages"` +} + +type submitDistributedQueryResultsResponse struct { + Err error `json:"error,omitempty"` +} + +func (r submitDistributedQueryResultsResponse) error() error { return r.Err } + +func submitDistributedQueryResultsEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + shim := request.(*submitDistributedQueryResultsRequestShim) + req, err := shim.toRequest(ctx) + if err != nil { + return submitDistributedQueryResultsResponse{Err: err}, nil + } + + err = svc.SubmitDistributedQueryResults(ctx, req.Results, req.Statuses, req.Messages) + if err != nil { + return submitDistributedQueryResultsResponse{Err: err}, nil + } + return submitDistributedQueryResultsResponse{}, nil +} + +const ( + // hostLabelQueryPrefix is appended before the query name when a query is + // provided as a label query. This allows the results to be retrieved when + // osqueryd writes the distributed query results. + hostLabelQueryPrefix = "fleet_label_query_" + + // hostDetailQueryPrefix is appended before the query name when a query is + // provided as a detail query. + hostDetailQueryPrefix = "fleet_detail_query_" + + // hostAdditionalQueryPrefix is appended before the query name when a query is + // provided as an additional query (additional info for hosts to retrieve). + hostAdditionalQueryPrefix = "fleet_additional_query_" + + // hostPolicyQueryPrefix is appended before the query name when a query is + // provided as a policy query. This allows the results to be retrieved when + // osqueryd writes the distributed query results. + hostPolicyQueryPrefix = "fleet_policy_query_" + + // hostDistributedQueryPrefix is appended before the query name when a query is + // run from a distributed query campaign + hostDistributedQueryPrefix = "fleet_distributed_query_" +) + +func (svc *Service) SubmitDistributedQueryResults( + ctx context.Context, + results fleet.OsqueryDistributedQueryResults, + statuses map[string]fleet.OsqueryStatus, + messages map[string]string, +) error { + // skipauth: Authorization is currently for user endpoints only. + svc.authz.SkipAuthorization(ctx) + + host, ok := hostctx.FromContext(ctx) + if !ok { + return osqueryError{message: "internal error: missing host from request context"} + } + + detailUpdated := false + additionalResults := make(fleet.OsqueryDistributedQueryResults) + additionalUpdated := false + labelResults := map[uint]*bool{} + policyResults := map[uint]*bool{} + + svc.maybeDebugHost(ctx, host, results, statuses, messages) + + for query, rows := range results { + // osquery docs say any nonzero (string) value for status indicates a query error + status, ok := statuses[query] + failed := ok && status != fleet.StatusOK + if failed && messages[query] != "" && !noSuchTableRegexp.MatchString(messages[query]) { + level.Debug(svc.logger).Log("query", query, "message", messages[query]) + } + var err error + switch { + case strings.HasPrefix(query, hostDetailQueryPrefix): + trimmedQuery := strings.TrimPrefix(query, hostDetailQueryPrefix) + var ingested bool + ingested, err = svc.directIngestDetailQuery(ctx, host, trimmedQuery, rows, failed) + if !ingested && err == nil { + err = svc.ingestDetailQuery(ctx, host, trimmedQuery, rows) + // No err != nil check here because ingestDetailQuery could have updated + // successfully some values of host. + detailUpdated = true + } + case strings.HasPrefix(query, hostAdditionalQueryPrefix): + name := strings.TrimPrefix(query, hostAdditionalQueryPrefix) + additionalResults[name] = rows + additionalUpdated = true + case strings.HasPrefix(query, hostLabelQueryPrefix): + err = ingestMembershipQuery(hostLabelQueryPrefix, query, rows, labelResults, failed) + case strings.HasPrefix(query, hostPolicyQueryPrefix): + err = ingestMembershipQuery(hostPolicyQueryPrefix, query, rows, policyResults, failed) + case strings.HasPrefix(query, hostDistributedQueryPrefix): + err = svc.ingestDistributedQuery(ctx, *host, query, rows, failed, messages[query]) + default: + err = osqueryError{message: "unknown query prefix: " + query} + } + + if err != nil { + logging.WithErr(ctx, ctxerr.New(ctx, "error in query ingestion")) + logging.WithExtras(ctx, "ingestion-err", err) + } + } + + ac, err := svc.ds.AppConfig(ctx) + if err != nil { + return ctxerr.Wrap(ctx, err, "getting app config") + } + + if len(labelResults) > 0 { + if err := svc.task.RecordLabelQueryExecutions(ctx, host, labelResults, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost); err != nil { + logging.WithErr(ctx, err) + } + } + + if len(policyResults) > 0 { + if ac.WebhookSettings.FailingPoliciesWebhook.Enable { + incomingResults := filterPolicyResults(policyResults, ac.WebhookSettings.FailingPoliciesWebhook.PolicyIDs) + if failingPolicies, passingPolicies, err := svc.ds.FlippingPoliciesForHost(ctx, host.ID, incomingResults); err != nil { + logging.WithErr(ctx, err) + } else { + // Register the flipped policies on a goroutine to not block the hosts on redis requests. + go func() { + if err := svc.registerFlippedPolicies(ctx, host.ID, host.Hostname, failingPolicies, passingPolicies); err != nil { + logging.WithErr(ctx, err) + } + }() + } + } + // NOTE(mna): currently, failing policies webhook wouldn't see the new + // flipped policies on the next run if async processing is enabled and the + // collection has not been done yet (not persisted in mysql). Should + // FlippingPoliciesForHost take pending redis data into consideration, or + // maybe we should impose restrictions between async collection interval + // and policy update interval? + + if err := svc.task.RecordPolicyQueryExecutions(ctx, host, policyResults, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost); err != nil { + logging.WithErr(ctx, err) + } + } + + if additionalUpdated { + additionalJSON, err := json.Marshal(additionalResults) + if err != nil { + logging.WithErr(ctx, err) + } else { + additional := json.RawMessage(additionalJSON) + if err := svc.ds.SaveHostAdditional(ctx, host.ID, &additional); err != nil { + logging.WithErr(ctx, err) + } + } + } + + if detailUpdated { + host.DetailUpdatedAt = svc.clock.Now() + } + + refetchRequested := host.RefetchRequested + if refetchRequested { + host.RefetchRequested = false + } + + if refetchRequested || detailUpdated { + appConfig, err := svc.ds.AppConfig(ctx) + if err != nil { + logging.WithErr(ctx, err) + } else { + if appConfig.ServerSettings.DeferredSaveHost { + go svc.serialUpdateHost(host) + } else { + if err := svc.ds.UpdateHost(ctx, host); err != nil { + logging.WithErr(ctx, err) + } + } + } + } + + return nil +} + +var noSuchTableRegexp = regexp.MustCompile(`^no such table: \S+$`) + +func (svc *Service) directIngestDetailQuery(ctx context.Context, host *fleet.Host, name string, rows []map[string]string, failed bool) (ingested bool, err error) { + config, err := svc.ds.AppConfig(ctx) + if err != nil { + return false, osqueryError{message: "ingest detail query: " + err.Error()} + } + + detailQueries := osquery_utils.GetDetailQueries(config, svc.config) + query, ok := detailQueries[name] + if !ok { + return false, osqueryError{message: "unknown detail query " + name} + } + if query.DirectIngestFunc != nil { + err = query.DirectIngestFunc(ctx, svc.logger, host, svc.ds, rows, failed) + if err != nil { + return false, osqueryError{ + message: fmt.Sprintf("ingesting query %s: %s", name, err.Error()), + } + } + return true, nil + } + return false, nil +} + +// ingestDistributedQuery takes the results of a distributed query and modifies the +// provided fleet.Host appropriately. +func (svc *Service) ingestDistributedQuery(ctx context.Context, host fleet.Host, name string, rows []map[string]string, failed bool, errMsg string) error { + trimmedQuery := strings.TrimPrefix(name, hostDistributedQueryPrefix) + + campaignID, err := strconv.Atoi(osquery_utils.EmptyToZero(trimmedQuery)) + if err != nil { + return osqueryError{message: "unable to parse campaign ID: " + trimmedQuery} + } + + // Write the results to the pubsub store + res := fleet.DistributedQueryResult{ + DistributedQueryCampaignID: uint(campaignID), + Host: host, + Rows: rows, + } + if failed { + res.Error = &errMsg + } + + err = svc.resultStore.WriteResult(res) + if err != nil { + var pse pubsub.Error + ok := errors.As(err, &pse) + if !ok || !pse.NoSubscriber() { + return osqueryError{message: "writing results: " + err.Error()} + } + + // If there are no subscribers, the campaign is "orphaned" + // and should be closed so that we don't continue trying to + // execute that query when we can't write to any subscriber + campaign, err := svc.ds.DistributedQueryCampaign(ctx, uint(campaignID)) + if err != nil { + if err := svc.liveQueryStore.StopQuery(strconv.Itoa(campaignID)); err != nil { + return osqueryError{message: "stop orphaned campaign after load failure: " + err.Error()} + } + return osqueryError{message: "loading orphaned campaign: " + err.Error()} + } + + if campaign.CreatedAt.After(svc.clock.Now().Add(-1 * time.Minute)) { + // Give the client a minute to connect before considering the + // campaign orphaned + return osqueryError{message: "campaign waiting for listener (please retry)"} + } + + if campaign.Status != fleet.QueryComplete { + campaign.Status = fleet.QueryComplete + if err := svc.ds.SaveDistributedQueryCampaign(ctx, campaign); err != nil { + return osqueryError{message: "closing orphaned campaign: " + err.Error()} + } + } + + if err := svc.liveQueryStore.StopQuery(strconv.Itoa(campaignID)); err != nil { + return osqueryError{message: "stopping orphaned campaign: " + err.Error()} + } + + // No need to record query completion in this case + return osqueryError{message: "campaign stopped"} + } + + err = svc.liveQueryStore.QueryCompletedByHost(strconv.Itoa(campaignID), host.ID) + if err != nil { + return osqueryError{message: "record query completion: " + err.Error()} + } + + return nil +} + +// ingestMembershipQuery records the results of label queries run by a host +func ingestMembershipQuery( + prefix string, + query string, + rows []map[string]string, + results map[uint]*bool, + failed bool, +) error { + trimmedQuery := strings.TrimPrefix(query, prefix) + trimmedQueryNum, err := strconv.Atoi(osquery_utils.EmptyToZero(trimmedQuery)) + if err != nil { + return fmt.Errorf("converting query from string to int: %w", err) + } + // A label/policy query matches if there is at least one result for that + // query. We must also store negative results. + if failed { + results[uint(trimmedQueryNum)] = nil + } else { + results[uint(trimmedQueryNum)] = ptr.Bool(len(rows) > 0) + } + + return nil +} + +// ingestDetailQuery takes the results of a detail query and modifies the +// provided fleet.Host appropriately. +func (svc *Service) ingestDetailQuery(ctx context.Context, host *fleet.Host, name string, rows []map[string]string) error { + config, err := svc.ds.AppConfig(ctx) + if err != nil { + return osqueryError{message: "ingest detail query: " + err.Error()} + } + + detailQueries := osquery_utils.GetDetailQueries(config, svc.config) + query, ok := detailQueries[name] + if !ok { + return osqueryError{message: "unknown detail query " + name} + } + + if query.IngestFunc != nil { + err = query.IngestFunc(svc.logger, host, rows) + if err != nil { + return osqueryError{ + message: fmt.Sprintf("ingesting query %s: %s", name, err.Error()), + } + } + } + + return nil +} + +// filterPolicyResults filters out policies that aren't configured for webhook automation. +func filterPolicyResults(incoming map[uint]*bool, webhookPolicies []uint) map[uint]*bool { + wp := make(map[uint]struct{}) + for _, policyID := range webhookPolicies { + wp[policyID] = struct{}{} + } + filtered := make(map[uint]*bool) + for policyID, passes := range incoming { + if _, ok := wp[policyID]; !ok { + continue + } + filtered[policyID] = passes + } + return filtered +} + +func (svc *Service) registerFlippedPolicies(ctx context.Context, hostID uint, hostname string, newFailing, newPassing []uint) error { + host := fleet.PolicySetHost{ + ID: hostID, + Hostname: hostname, + } + for _, policyID := range newFailing { + if err := svc.failingPolicySet.AddHost(policyID, host); err != nil { + return err + } + } + for _, policyID := range newPassing { + if err := svc.failingPolicySet.RemoveHosts(policyID, []fleet.PolicySetHost{host}); err != nil { + return err + } + } + return nil +} + +func (svc *Service) maybeDebugHost( + ctx context.Context, + host *fleet.Host, + results fleet.OsqueryDistributedQueryResults, + statuses map[string]fleet.OsqueryStatus, + messages map[string]string, +) { + if svc.debugEnabledForHost(ctx, host.ID) { + hlogger := log.With(svc.logger, "host-id", host.ID) + + logJSON(hlogger, host, "host") + logJSON(hlogger, results, "results") + logJSON(hlogger, statuses, "statuses") + logJSON(hlogger, messages, "messages") + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Submit Logs +//////////////////////////////////////////////////////////////////////////////// + +type submitLogsRequest struct { + NodeKey string `json:"node_key"` + LogType string `json:"log_type"` + Data json.RawMessage `json:"data"` +} + +type submitLogsResponse struct { + Err error `json:"error,omitempty"` +} + +func (r submitLogsResponse) error() error { return r.Err } + +func submitLogsEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*submitLogsRequest) + + var err error + switch req.LogType { + case "status": + var statuses []json.RawMessage + if err := json.Unmarshal(req.Data, &statuses); err != nil { + err = osqueryError{message: "unmarshalling status logs: " + err.Error()} + break + } + + err = svc.SubmitStatusLogs(ctx, statuses) + if err != nil { + break + } + + case "result": + var results []json.RawMessage + if err := json.Unmarshal(req.Data, &results); err != nil { + err = osqueryError{message: "unmarshalling result logs: " + err.Error()} + break + } + err = svc.SubmitResultLogs(ctx, results) + if err != nil { + break + } + + default: + err = osqueryError{message: "unknown log type: " + req.LogType} + } + + return submitLogsResponse{Err: err}, nil +} + +func (svc *Service) SubmitStatusLogs(ctx context.Context, logs []json.RawMessage) error { + // skipauth: Authorization is currently for user endpoints only. + svc.authz.SkipAuthorization(ctx) + + if err := svc.osqueryLogWriter.Status.Write(ctx, logs); err != nil { + return osqueryError{message: "error writing status logs: " + err.Error()} + } + return nil +} + +func (svc *Service) SubmitResultLogs(ctx context.Context, logs []json.RawMessage) error { + // skipauth: Authorization is currently for user endpoints only. + svc.authz.SkipAuthorization(ctx) + + if err := svc.osqueryLogWriter.Result.Write(ctx, logs); err != nil { + return osqueryError{message: "error writing result logs: " + err.Error()} + } + return nil +} diff --git a/server/service/osquery_test.go b/server/service/osquery_test.go new file mode 100644 index 0000000000..b02a8602ba --- /dev/null +++ b/server/service/osquery_test.go @@ -0,0 +1,163 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "testing" + + hostctx "github.com/fleetdm/fleet/v4/server/contexts/host" + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/mock" + "github.com/fleetdm/fleet/v4/server/ptr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetClientConfig(t *testing.T) { + ds := new(mock.Store) + ds.ListPacksForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Pack, error) { + return []*fleet.Pack{}, nil + } + ds.ListScheduledQueriesInPackFunc = func(ctx context.Context, pid uint) ([]*fleet.ScheduledQuery, error) { + tru := true + fals := false + fortytwo := uint(42) + switch pid { + case 1: + return []*fleet.ScheduledQuery{ + {Name: "time", Query: "select * from time", Interval: 30, Removed: &fals}, + }, nil + case 4: + return []*fleet.ScheduledQuery{ + {Name: "foobar", Query: "select 3", Interval: 20, Shard: &fortytwo}, + {Name: "froobing", Query: "select 'guacamole'", Interval: 60, Snapshot: &tru}, + }, nil + default: + return []*fleet.ScheduledQuery{}, nil + } + } + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{AgentOptions: ptr.RawMessage(json.RawMessage(`{"config":{"options":{"baz":"bar"}}}`))}, nil + } + ds.UpdateHostFunc = func(ctx context.Context, host *fleet.Host) error { + return nil + } + ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { + if id != 1 && id != 2 { + return nil, errors.New("not found") + } + return &fleet.Host{ID: id}, nil + } + + svc := newTestService(ds, nil, nil) + + ctx1 := hostctx.NewContext(context.Background(), &fleet.Host{ID: 1}) + ctx2 := hostctx.NewContext(context.Background(), &fleet.Host{ID: 2}) + + expectedOptions := map[string]interface{}{ + "baz": "bar", + } + + expectedConfig := map[string]interface{}{ + "options": expectedOptions, + } + + // No packs loaded yet + conf, err := svc.GetClientConfig(ctx1) + require.NoError(t, err) + assert.Equal(t, expectedConfig, conf) + + conf, err = svc.GetClientConfig(ctx2) + require.NoError(t, err) + assert.Equal(t, expectedConfig, conf) + + // Now add packs + ds.ListPacksForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Pack, error) { + switch hid { + case 1: + return []*fleet.Pack{ + {ID: 1, Name: "pack_by_label"}, + {ID: 4, Name: "pack_by_other_label"}, + }, nil + + case 2: + return []*fleet.Pack{ + {ID: 1, Name: "pack_by_label"}, + }, nil + } + return []*fleet.Pack{}, nil + } + + conf, err = svc.GetClientConfig(ctx1) + require.NoError(t, err) + assert.Equal(t, expectedOptions, conf["options"]) + assert.JSONEq(t, `{ + "pack_by_other_label": { + "queries": { + "foobar":{"query":"select 3","interval":20,"shard":42}, + "froobing":{"query":"select 'guacamole'","interval":60,"snapshot":true} + } + }, + "pack_by_label": { + "queries":{ + "time":{"query":"select * from time","interval":30,"removed":false} + } + } + }`, + string(conf["packs"].(json.RawMessage)), + ) + + conf, err = svc.GetClientConfig(ctx2) + require.NoError(t, err) + assert.Equal(t, expectedOptions, conf["options"]) + assert.JSONEq(t, `{ + "pack_by_label": { + "queries":{ + "time":{"query":"select * from time","interval":30,"removed":false} + } + } + }`, + string(conf["packs"].(json.RawMessage)), + ) +} + +func TestAgentOptionsForHost(t *testing.T) { + ds := new(mock.Store) + svc := newTestService(ds, nil, nil) + + teamID := uint(1) + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{ + AgentOptions: ptr.RawMessage(json.RawMessage(`{"config":{"baz":"bar"},"overrides":{"platforms":{"darwin":{"foo":"override2"}}}}`)), + }, nil + } + ds.TeamAgentOptionsFunc = func(ctx context.Context, id uint) (*json.RawMessage, error) { + return ptr.RawMessage(json.RawMessage(`{"config":{"foo":"bar"},"overrides":{"platforms":{"darwin":{"foo":"override"}}}}`)), nil + } + + host := &fleet.Host{ + TeamID: &teamID, + Platform: "darwin", + } + + opt, err := svc.AgentOptionsForHost(context.Background(), host.TeamID, host.Platform) + require.NoError(t, err) + assert.JSONEq(t, `{"foo":"override"}`, string(opt)) + + host.Platform = "windows" + opt, err = svc.AgentOptionsForHost(context.Background(), host.TeamID, host.Platform) + require.NoError(t, err) + assert.JSONEq(t, `{"foo":"bar"}`, string(opt)) + + // Should take gobal option with no team + host.TeamID = nil + opt, err = svc.AgentOptionsForHost(context.Background(), host.TeamID, host.Platform) + require.NoError(t, err) + assert.JSONEq(t, `{"baz":"bar"}`, string(opt)) + + host.Platform = "darwin" + opt, err = svc.AgentOptionsForHost(context.Background(), host.TeamID, host.Platform) + require.NoError(t, err) + assert.JSONEq(t, `{"foo":"override2"}`, string(opt)) +} diff --git a/server/service/service_agent_options.go b/server/service/service_agent_options.go deleted file mode 100644 index 04668b48d1..0000000000 --- a/server/service/service_agent_options.go +++ /dev/null @@ -1,41 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - - "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" - "github.com/fleetdm/fleet/v4/server/fleet" -) - -// AgentOptionsForHost gets the agent options for the provided host. -// The host information should be used for filtering based on team, platform, etc. -func (svc *Service) AgentOptionsForHost(ctx context.Context, hostTeamID *uint, hostPlatform string) (json.RawMessage, error) { - // Team agent options have priority over global options. - if hostTeamID != nil { - teamAgentOptions, err := svc.ds.TeamAgentOptions(ctx, *hostTeamID) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "load team agent options for host") - } - - if teamAgentOptions != nil && len(*teamAgentOptions) > 0 { - var options fleet.AgentOptions - if err := json.Unmarshal(*teamAgentOptions, &options); err != nil { - return nil, ctxerr.Wrap(ctx, err, "unmarshal team agent options") - } - return options.ForPlatform(hostPlatform), nil - } - } - // Otherwise return the appropriate override for global options. - appConfig, err := svc.ds.AppConfig(ctx) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "load global agent options") - } - var options fleet.AgentOptions - if appConfig.AgentOptions != nil { - if err := json.Unmarshal(*appConfig.AgentOptions, &options); err != nil { - return nil, ctxerr.Wrap(ctx, err, "unmarshal global agent options") - } - } - return options.ForPlatform(hostPlatform), nil -} diff --git a/server/service/service_agent_options_test.go b/server/service/service_agent_options_test.go deleted file mode 100644 index 3d054ed6c9..0000000000 --- a/server/service/service_agent_options_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - "testing" - - "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/fleetdm/fleet/v4/server/mock" - "github.com/fleetdm/fleet/v4/server/ptr" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestAgentOptionsForHost(t *testing.T) { - ds := new(mock.Store) - svc := newTestService(ds, nil, nil) - - teamID := uint(1) - ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return &fleet.AppConfig{ - AgentOptions: ptr.RawMessage(json.RawMessage(`{"config":{"baz":"bar"},"overrides":{"platforms":{"darwin":{"foo":"override2"}}}}`)), - }, nil - } - ds.TeamAgentOptionsFunc = func(ctx context.Context, id uint) (*json.RawMessage, error) { - return ptr.RawMessage(json.RawMessage(`{"config":{"foo":"bar"},"overrides":{"platforms":{"darwin":{"foo":"override"}}}}`)), nil - } - - host := &fleet.Host{ - TeamID: &teamID, - Platform: "darwin", - } - - opt, err := svc.AgentOptionsForHost(context.Background(), host.TeamID, host.Platform) - require.NoError(t, err) - assert.JSONEq(t, `{"foo":"override"}`, string(opt)) - - host.Platform = "windows" - opt, err = svc.AgentOptionsForHost(context.Background(), host.TeamID, host.Platform) - require.NoError(t, err) - assert.JSONEq(t, `{"foo":"bar"}`, string(opt)) - - // Should take gobal option with no team - host.TeamID = nil - opt, err = svc.AgentOptionsForHost(context.Background(), host.TeamID, host.Platform) - require.NoError(t, err) - assert.JSONEq(t, `{"baz":"bar"}`, string(opt)) - - host.Platform = "darwin" - opt, err = svc.AgentOptionsForHost(context.Background(), host.TeamID, host.Platform) - require.NoError(t, err) - assert.JSONEq(t, `{"foo":"override2"}`, string(opt)) -} diff --git a/server/service/service_carves.go b/server/service/service_carves.go index 1d53b70390..5da8ba10fd 100644 --- a/server/service/service_carves.go +++ b/server/service/service_carves.go @@ -4,72 +4,11 @@ import ( "context" "errors" "fmt" - "time" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" - hostctx "github.com/fleetdm/fleet/v4/server/contexts/host" "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/google/uuid" ) -const ( - maxCarveSize = 8 * 1024 * 1024 * 1024 // 8GB - maxBlockSize = 256 * 1024 * 1024 // 256MB -) - -func (svc *Service) CarveBegin(ctx context.Context, payload fleet.CarveBeginPayload) (*fleet.CarveMetadata, error) { - // skipauth: Authorization is currently for user endpoints only. - svc.authz.SkipAuthorization(ctx) - - host, ok := hostctx.FromContext(ctx) - if !ok { - return nil, osqueryError{message: "internal error: missing host from request context"} - } - - if payload.CarveSize == 0 { - return nil, osqueryError{message: "carve_size must be greater than 0"} - } - - if payload.BlockSize > maxBlockSize { - return nil, osqueryError{message: "block_size exceeds max"} - } - if payload.CarveSize > maxCarveSize { - return nil, osqueryError{message: "carve_size exceeds max"} - } - - // The carve should have a total size that fits appropriately into the - // number of blocks of the specified size. - if payload.CarveSize <= (payload.BlockCount-1)*payload.BlockSize || - payload.CarveSize > payload.BlockCount*payload.BlockSize { - return nil, osqueryError{message: "carve_size does not match block_size and block_count"} - } - - sessionId, err := uuid.NewRandom() - if err != nil { - return nil, osqueryError{message: "internal error: generate session ID for carve: " + err.Error()} - } - - now := time.Now().UTC() - carve := &fleet.CarveMetadata{ - Name: fmt.Sprintf("%s-%s-%s", host.Hostname, now.Format(time.RFC3339), payload.RequestId), - HostId: host.ID, - BlockCount: payload.BlockCount, - BlockSize: payload.BlockSize, - CarveSize: payload.CarveSize, - CarveId: payload.CarveId, - RequestId: payload.RequestId, - SessionId: sessionId.String(), - CreatedAt: now, - } - - carve, err = svc.carveStore.NewCarve(ctx, carve) - if err != nil { - return nil, osqueryError{message: "internal error: new carve: " + err.Error()} - } - - return carve, nil -} - func (svc *Service) CarveBlock(ctx context.Context, payload fleet.CarveBlockPayload) error { // skipauth: Authorization is currently for user endpoints only. svc.authz.SkipAuthorization(ctx) diff --git a/server/service/service_osquery.go b/server/service/service_osquery.go index a885d3f125..00f005c847 100644 --- a/server/service/service_osquery.go +++ b/server/service/service_osquery.go @@ -2,27 +2,17 @@ package service import ( "context" - "encoding/json" - "errors" - "fmt" - "regexp" - "strconv" - "strings" "sync" "sync/atomic" "time" "github.com/fleetdm/fleet/v4/server" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" - hostctx "github.com/fleetdm/fleet/v4/server/contexts/host" "github.com/fleetdm/fleet/v4/server/contexts/logging" "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/fleetdm/fleet/v4/server/ptr" - "github.com/fleetdm/fleet/v4/server/pubsub" "github.com/fleetdm/fleet/v4/server/service/osquery_utils" "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" - "github.com/spf13/cast" ) type osqueryError struct { @@ -248,219 +238,6 @@ func getHostIdentifier(logger log.Logger, identifierOption, providedIdentifier s return providedIdentifier } -func (svc *Service) GetClientConfig(ctx context.Context) (map[string]interface{}, error) { - // skipauth: Authorization is currently for user endpoints only. - svc.authz.SkipAuthorization(ctx) - - host, ok := hostctx.FromContext(ctx) - if !ok { - return nil, osqueryError{message: "internal error: missing host from request context"} - } - - baseConfig, err := svc.AgentOptionsForHost(ctx, host.TeamID, host.Platform) - if err != nil { - return nil, osqueryError{message: "internal error: fetch base config: " + err.Error()} - } - - config := make(map[string]interface{}) - if baseConfig != nil { - err = json.Unmarshal(baseConfig, &config) - if err != nil { - return nil, osqueryError{message: "internal error: parse base configuration: " + err.Error()} - } - } - - packs, err := svc.ds.ListPacksForHost(ctx, host.ID) - if err != nil { - return nil, osqueryError{message: "database error: " + err.Error()} - } - - packConfig := fleet.Packs{} - for _, pack := range packs { - // first, we must figure out what queries are in this pack - queries, err := svc.ds.ListScheduledQueriesInPack(ctx, pack.ID) - if err != nil { - return nil, osqueryError{message: "database error: " + err.Error()} - } - - // the serializable osquery config struct expects content in a - // particular format, so we do the conversion here - configQueries := fleet.Queries{} - for _, query := range queries { - queryContent := fleet.QueryContent{ - Query: query.Query, - Interval: query.Interval, - Platform: query.Platform, - Version: query.Version, - Removed: query.Removed, - Shard: query.Shard, - Denylist: query.Denylist, - } - - if query.Removed != nil { - queryContent.Removed = query.Removed - } - - if query.Snapshot != nil && *query.Snapshot { - queryContent.Snapshot = query.Snapshot - } - - configQueries[query.Name] = queryContent - } - - // finally, we add the pack to the client config struct with all of - // the pack's queries - packConfig[pack.Name] = fleet.PackContent{ - Platform: pack.Platform, - Queries: configQueries, - } - } - - if len(packConfig) > 0 { - packJSON, err := json.Marshal(packConfig) - if err != nil { - return nil, osqueryError{message: "internal error: marshal pack JSON: " + err.Error()} - } - config["packs"] = json.RawMessage(packJSON) - } - - // Save interval values if they have been updated. - intervalsModified := false - intervals := fleet.HostOsqueryIntervals{ - DistributedInterval: host.DistributedInterval, - ConfigTLSRefresh: host.ConfigTLSRefresh, - LoggerTLSPeriod: host.LoggerTLSPeriod, - } - if options, ok := config["options"].(map[string]interface{}); ok { - distributedIntervalVal, ok := options["distributed_interval"] - distributedInterval, err := cast.ToUintE(distributedIntervalVal) - if ok && err == nil && intervals.DistributedInterval != distributedInterval { - intervals.DistributedInterval = distributedInterval - intervalsModified = true - } - - loggerTLSPeriodVal, ok := options["logger_tls_period"] - loggerTLSPeriod, err := cast.ToUintE(loggerTLSPeriodVal) - if ok && err == nil && intervals.LoggerTLSPeriod != loggerTLSPeriod { - intervals.LoggerTLSPeriod = loggerTLSPeriod - intervalsModified = true - } - - // Note config_tls_refresh can only be set in the osquery flags (and has - // also been deprecated in osquery for quite some time) so is ignored - // here. - configRefreshVal, ok := options["config_refresh"] - configRefresh, err := cast.ToUintE(configRefreshVal) - if ok && err == nil && intervals.ConfigTLSRefresh != configRefresh { - intervals.ConfigTLSRefresh = configRefresh - intervalsModified = true - } - } - - // We are not doing deferred update host like in other places because the intervals - // are not modified often. - if intervalsModified { - if err := svc.ds.UpdateHostOsqueryIntervals(ctx, host.ID, intervals); err != nil { - return nil, osqueryError{message: "internal error: update host intervals: " + err.Error()} - } - } - - return config, nil -} - -func (svc *Service) SubmitStatusLogs(ctx context.Context, logs []json.RawMessage) error { - // skipauth: Authorization is currently for user endpoints only. - svc.authz.SkipAuthorization(ctx) - - if err := svc.osqueryLogWriter.Status.Write(ctx, logs); err != nil { - return osqueryError{message: "error writing status logs: " + err.Error()} - } - return nil -} - -func (svc *Service) SubmitResultLogs(ctx context.Context, logs []json.RawMessage) error { - // skipauth: Authorization is currently for user endpoints only. - svc.authz.SkipAuthorization(ctx) - - if err := svc.osqueryLogWriter.Result.Write(ctx, logs); err != nil { - return osqueryError{message: "error writing result logs: " + err.Error()} - } - return nil -} - -// hostLabelQueryPrefix is appended before the query name when a query is -// provided as a label query. This allows the results to be retrieved when -// osqueryd writes the distributed query results. -const hostLabelQueryPrefix = "fleet_label_query_" - -// hostDetailQueryPrefix is appended before the query name when a query is -// provided as a detail query. -const hostDetailQueryPrefix = "fleet_detail_query_" - -// hostAdditionalQueryPrefix is appended before the query name when a query is -// provided as an additional query (additional info for hosts to retrieve). -const hostAdditionalQueryPrefix = "fleet_additional_query_" - -// hostPolicyQueryPrefix is appended before the query name when a query is -// provided as a policy query. This allows the results to be retrieved when -// osqueryd writes the distributed query results. -const hostPolicyQueryPrefix = "fleet_policy_query_" - -// hostDistributedQueryPrefix is appended before the query name when a query is -// run from a distributed query campaign -const hostDistributedQueryPrefix = "fleet_distributed_query_" - -// detailQueriesForHost returns the map of detail+additional queries that should be executed by -// osqueryd to fill in the host details. -func (svc *Service) detailQueriesForHost(ctx context.Context, host *fleet.Host) (map[string]string, error) { - if !svc.shouldUpdate(host.DetailUpdatedAt, svc.config.Osquery.DetailUpdateInterval, host.ID) && !host.RefetchRequested { - return nil, nil - } - - config, err := svc.ds.AppConfig(ctx) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "read app config") - } - - queries := make(map[string]string) - detailQueries := osquery_utils.GetDetailQueries(config, svc.config) - for name, query := range detailQueries { - if query.RunsForPlatform(host.Platform) { - queries[hostDetailQueryPrefix+name] = query.Query - } - } - - if config.HostSettings.AdditionalQueries == nil { - // No additional queries set - return queries, nil - } - - var additionalQueries map[string]string - if err := json.Unmarshal(*config.HostSettings.AdditionalQueries, &additionalQueries); err != nil { - return nil, ctxerr.Wrap(ctx, err, "unmarshal additional queries") - } - - for name, query := range additionalQueries { - queries[hostAdditionalQueryPrefix+name] = query - } - - return queries, nil -} - -func (svc *Service) shouldUpdate(lastUpdated time.Time, interval time.Duration, hostID uint) bool { - svc.jitterMu.Lock() - defer svc.jitterMu.Unlock() - - if svc.jitterH[interval] == nil { - svc.jitterH[interval] = newJitterHashTable(int(int64(svc.config.Osquery.MaxJitterPercent) * int64(interval.Minutes()) / 100.0)) - level.Debug(svc.logger).Log("jitter", "created", "bucketCount", svc.jitterH[interval].bucketCount) - } - - jitter := svc.jitterH[interval].jitterForHost(hostID) - cutoff := svc.clock.Now().Add(-(interval + jitter)) - return lastUpdated.Before(cutoff) -} - // jitterHashTable implements a data structure that allows a fleet to generate a static jitter value // that is properly balanced. Balance in this context means that hosts would be distributed uniformly // across the total jitter time so there are no spikes. @@ -538,415 +315,3 @@ func (jh *jitterHashTable) jitterForHost(hostID uint) time.Duration { jh.mu.Unlock() return jh.jitterForHost(hostID) } - -func (svc *Service) labelQueriesForHost(ctx context.Context, host *fleet.Host) (map[string]string, error) { - labelReportedAt := svc.task.GetHostLabelReportedAt(ctx, host) - if !svc.shouldUpdate(labelReportedAt, svc.config.Osquery.LabelUpdateInterval, host.ID) && !host.RefetchRequested { - return nil, nil - } - labelQueries, err := svc.ds.LabelQueriesForHost(ctx, host) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "retrieve label queries") - } - return labelQueries, nil -} - -func (svc *Service) policyQueriesForHost(ctx context.Context, host *fleet.Host) (map[string]string, error) { - policyReportedAt := svc.task.GetHostPolicyReportedAt(ctx, host) - if !svc.shouldUpdate(policyReportedAt, svc.config.Osquery.PolicyUpdateInterval, host.ID) && !host.RefetchRequested { - return nil, nil - } - policyQueries, err := svc.ds.PolicyQueriesForHost(ctx, host) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "retrieve policy queries") - } - return policyQueries, nil -} - -func (svc *Service) GetDistributedQueries(ctx context.Context) (map[string]string, uint, error) { - // skipauth: Authorization is currently for user endpoints only. - svc.authz.SkipAuthorization(ctx) - - host, ok := hostctx.FromContext(ctx) - if !ok { - return nil, 0, osqueryError{message: "internal error: missing host from request context"} - } - - queries := make(map[string]string) - - detailQueries, err := svc.detailQueriesForHost(ctx, host) - if err != nil { - return nil, 0, osqueryError{message: err.Error()} - } - for name, query := range detailQueries { - queries[name] = query - } - - labelQueries, err := svc.labelQueriesForHost(ctx, host) - if err != nil { - return nil, 0, osqueryError{message: err.Error()} - } - for name, query := range labelQueries { - queries[hostLabelQueryPrefix+name] = query - } - - if liveQueries, err := svc.liveQueryStore.QueriesForHost(host.ID); err != nil { - // If the live query store fails to fetch queries we still want the hosts - // to receive all the other queries (details, policies, labels, etc.), - // thus we just log the error. - level.Error(svc.logger).Log("op", "QueriesForHost", "err", err) - } else { - for name, query := range liveQueries { - queries[hostDistributedQueryPrefix+name] = query - } - } - - policyQueries, err := svc.policyQueriesForHost(ctx, host) - if err != nil { - return nil, 0, osqueryError{message: err.Error()} - } - for name, query := range policyQueries { - queries[hostPolicyQueryPrefix+name] = query - } - - accelerate := uint(0) - if host.Hostname == "" || host.Platform == "" { - // Assume this host is just enrolling, and accelerate checkins - // (to allow for platform restricted labels to run quickly - // after platform is retrieved from details) - accelerate = 10 - } - - return queries, accelerate, nil -} - -// ingestDetailQuery takes the results of a detail query and modifies the -// provided fleet.Host appropriately. -func (svc *Service) ingestDetailQuery(ctx context.Context, host *fleet.Host, name string, rows []map[string]string) error { - config, err := svc.ds.AppConfig(ctx) - if err != nil { - return osqueryError{message: "ingest detail query: " + err.Error()} - } - - detailQueries := osquery_utils.GetDetailQueries(config, svc.config) - query, ok := detailQueries[name] - if !ok { - return osqueryError{message: "unknown detail query " + name} - } - - if query.IngestFunc != nil { - err = query.IngestFunc(svc.logger, host, rows) - if err != nil { - return osqueryError{ - message: fmt.Sprintf("ingesting query %s: %s", name, err.Error()), - } - } - } - - return nil -} - -// ingestMembershipQuery records the results of label queries run by a host -func ingestMembershipQuery( - prefix string, - query string, - rows []map[string]string, - results map[uint]*bool, - failed bool, -) error { - trimmedQuery := strings.TrimPrefix(query, prefix) - trimmedQueryNum, err := strconv.Atoi(osquery_utils.EmptyToZero(trimmedQuery)) - if err != nil { - return fmt.Errorf("converting query from string to int: %w", err) - } - // A label/policy query matches if there is at least one result for that - // query. We must also store negative results. - if failed { - results[uint(trimmedQueryNum)] = nil - } else { - results[uint(trimmedQueryNum)] = ptr.Bool(len(rows) > 0) - } - - return nil -} - -// ingestDistributedQuery takes the results of a distributed query and modifies the -// provided fleet.Host appropriately. -func (svc *Service) ingestDistributedQuery(ctx context.Context, host fleet.Host, name string, rows []map[string]string, failed bool, errMsg string) error { - trimmedQuery := strings.TrimPrefix(name, hostDistributedQueryPrefix) - - campaignID, err := strconv.Atoi(osquery_utils.EmptyToZero(trimmedQuery)) - if err != nil { - return osqueryError{message: "unable to parse campaign ID: " + trimmedQuery} - } - - // Write the results to the pubsub store - res := fleet.DistributedQueryResult{ - DistributedQueryCampaignID: uint(campaignID), - Host: host, - Rows: rows, - } - if failed { - res.Error = &errMsg - } - - err = svc.resultStore.WriteResult(res) - if err != nil { - var pse pubsub.Error - ok := errors.As(err, &pse) - if !ok || !pse.NoSubscriber() { - return osqueryError{message: "writing results: " + err.Error()} - } - - // If there are no subscribers, the campaign is "orphaned" - // and should be closed so that we don't continue trying to - // execute that query when we can't write to any subscriber - campaign, err := svc.ds.DistributedQueryCampaign(ctx, uint(campaignID)) - if err != nil { - if err := svc.liveQueryStore.StopQuery(strconv.Itoa(campaignID)); err != nil { - return osqueryError{message: "stop orphaned campaign after load failure: " + err.Error()} - } - return osqueryError{message: "loading orphaned campaign: " + err.Error()} - } - - if campaign.CreatedAt.After(svc.clock.Now().Add(-1 * time.Minute)) { - // Give the client a minute to connect before considering the - // campaign orphaned - return osqueryError{message: "campaign waiting for listener (please retry)"} - } - - if campaign.Status != fleet.QueryComplete { - campaign.Status = fleet.QueryComplete - if err := svc.ds.SaveDistributedQueryCampaign(ctx, campaign); err != nil { - return osqueryError{message: "closing orphaned campaign: " + err.Error()} - } - } - - if err := svc.liveQueryStore.StopQuery(strconv.Itoa(campaignID)); err != nil { - return osqueryError{message: "stopping orphaned campaign: " + err.Error()} - } - - // No need to record query completion in this case - return osqueryError{message: "campaign stopped"} - } - - err = svc.liveQueryStore.QueryCompletedByHost(strconv.Itoa(campaignID), host.ID) - if err != nil { - return osqueryError{message: "record query completion: " + err.Error()} - } - - return nil -} - -func (svc *Service) directIngestDetailQuery(ctx context.Context, host *fleet.Host, name string, rows []map[string]string, failed bool) (ingested bool, err error) { - config, err := svc.ds.AppConfig(ctx) - if err != nil { - return false, osqueryError{message: "ingest detail query: " + err.Error()} - } - - detailQueries := osquery_utils.GetDetailQueries(config, svc.config) - query, ok := detailQueries[name] - if !ok { - return false, osqueryError{message: "unknown detail query " + name} - } - if query.DirectIngestFunc != nil { - err = query.DirectIngestFunc(ctx, svc.logger, host, svc.ds, rows, failed) - if err != nil { - return false, osqueryError{ - message: fmt.Sprintf("ingesting query %s: %s", name, err.Error()), - } - } - return true, nil - } - return false, nil -} - -var noSuchTableRegexp = regexp.MustCompile(`^no such table: \S+$`) - -func (svc *Service) SubmitDistributedQueryResults( - ctx context.Context, - results fleet.OsqueryDistributedQueryResults, - statuses map[string]fleet.OsqueryStatus, - messages map[string]string, -) error { - // skipauth: Authorization is currently for user endpoints only. - svc.authz.SkipAuthorization(ctx) - - host, ok := hostctx.FromContext(ctx) - if !ok { - return osqueryError{message: "internal error: missing host from request context"} - } - - detailUpdated := false - additionalResults := make(fleet.OsqueryDistributedQueryResults) - additionalUpdated := false - labelResults := map[uint]*bool{} - policyResults := map[uint]*bool{} - - svc.maybeDebugHost(ctx, host, results, statuses, messages) - - for query, rows := range results { - // osquery docs say any nonzero (string) value for status indicates a query error - status, ok := statuses[query] - failed := ok && status != fleet.StatusOK - if failed && messages[query] != "" && !noSuchTableRegexp.MatchString(messages[query]) { - level.Debug(svc.logger).Log("query", query, "message", messages[query]) - } - var err error - switch { - case strings.HasPrefix(query, hostDetailQueryPrefix): - trimmedQuery := strings.TrimPrefix(query, hostDetailQueryPrefix) - var ingested bool - ingested, err = svc.directIngestDetailQuery(ctx, host, trimmedQuery, rows, failed) - if !ingested && err == nil { - err = svc.ingestDetailQuery(ctx, host, trimmedQuery, rows) - // No err != nil check here because ingestDetailQuery could have updated - // successfully some values of host. - detailUpdated = true - } - case strings.HasPrefix(query, hostAdditionalQueryPrefix): - name := strings.TrimPrefix(query, hostAdditionalQueryPrefix) - additionalResults[name] = rows - additionalUpdated = true - case strings.HasPrefix(query, hostLabelQueryPrefix): - err = ingestMembershipQuery(hostLabelQueryPrefix, query, rows, labelResults, failed) - case strings.HasPrefix(query, hostPolicyQueryPrefix): - err = ingestMembershipQuery(hostPolicyQueryPrefix, query, rows, policyResults, failed) - case strings.HasPrefix(query, hostDistributedQueryPrefix): - err = svc.ingestDistributedQuery(ctx, *host, query, rows, failed, messages[query]) - default: - err = osqueryError{message: "unknown query prefix: " + query} - } - - if err != nil { - logging.WithErr(ctx, ctxerr.New(ctx, "error in query ingestion")) - logging.WithExtras(ctx, "ingestion-err", err) - } - } - - ac, err := svc.ds.AppConfig(ctx) - if err != nil { - return ctxerr.Wrap(ctx, err, "getting app config") - } - - if len(labelResults) > 0 { - if err := svc.task.RecordLabelQueryExecutions(ctx, host, labelResults, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost); err != nil { - logging.WithErr(ctx, err) - } - } - - if len(policyResults) > 0 { - if ac.WebhookSettings.FailingPoliciesWebhook.Enable { - incomingResults := filterPolicyResults(policyResults, ac.WebhookSettings.FailingPoliciesWebhook.PolicyIDs) - if failingPolicies, passingPolicies, err := svc.ds.FlippingPoliciesForHost(ctx, host.ID, incomingResults); err != nil { - logging.WithErr(ctx, err) - } else { - // Register the flipped policies on a goroutine to not block the hosts on redis requests. - go func() { - if err := svc.registerFlippedPolicies(ctx, host.ID, host.Hostname, failingPolicies, passingPolicies); err != nil { - logging.WithErr(ctx, err) - } - }() - } - } - // NOTE(mna): currently, failing policies webhook wouldn't see the new - // flipped policies on the next run if async processing is enabled and the - // collection has not been done yet (not persisted in mysql). Should - // FlippingPoliciesForHost take pending redis data into consideration, or - // maybe we should impose restrictions between async collection interval - // and policy update interval? - - if err := svc.task.RecordPolicyQueryExecutions(ctx, host, policyResults, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost); err != nil { - logging.WithErr(ctx, err) - } - } - - if additionalUpdated { - additionalJSON, err := json.Marshal(additionalResults) - if err != nil { - logging.WithErr(ctx, err) - } else { - additional := json.RawMessage(additionalJSON) - if err := svc.ds.SaveHostAdditional(ctx, host.ID, &additional); err != nil { - logging.WithErr(ctx, err) - } - } - } - - if detailUpdated { - host.DetailUpdatedAt = svc.clock.Now() - } - - refetchRequested := host.RefetchRequested - if refetchRequested { - host.RefetchRequested = false - } - - if refetchRequested || detailUpdated { - appConfig, err := svc.ds.AppConfig(ctx) - if err != nil { - logging.WithErr(ctx, err) - } else { - if appConfig.ServerSettings.DeferredSaveHost { - go svc.serialUpdateHost(host) - } else { - if err := svc.ds.UpdateHost(ctx, host); err != nil { - logging.WithErr(ctx, err) - } - } - } - } - - return nil -} - -// filterPolicyResults filters out policies that aren't configured for webhook automation. -func filterPolicyResults(incoming map[uint]*bool, webhookPolicies []uint) map[uint]*bool { - wp := make(map[uint]struct{}) - for _, policyID := range webhookPolicies { - wp[policyID] = struct{}{} - } - filtered := make(map[uint]*bool) - for policyID, passes := range incoming { - if _, ok := wp[policyID]; !ok { - continue - } - filtered[policyID] = passes - } - return filtered -} - -func (svc *Service) registerFlippedPolicies(ctx context.Context, hostID uint, hostname string, newFailing, newPassing []uint) error { - host := fleet.PolicySetHost{ - ID: hostID, - Hostname: hostname, - } - for _, policyID := range newFailing { - if err := svc.failingPolicySet.AddHost(policyID, host); err != nil { - return err - } - } - for _, policyID := range newPassing { - if err := svc.failingPolicySet.RemoveHosts(policyID, []fleet.PolicySetHost{host}); err != nil { - return err - } - } - return nil -} - -func (svc *Service) maybeDebugHost( - ctx context.Context, - host *fleet.Host, - results fleet.OsqueryDistributedQueryResults, - statuses map[string]fleet.OsqueryStatus, - messages map[string]string, -) { - if svc.debugEnabledForHost(ctx, host.ID) { - hlogger := log.With(svc.logger, "host-id", host.ID) - - logJSON(hlogger, host, "host") - logJSON(hlogger, results, "results") - logJSON(hlogger, statuses, "statuses") - logJSON(hlogger, messages, "messages") - } -} diff --git a/server/service/service_osquery_test.go b/server/service/service_osquery_test.go index 30910abdad..244fceb4b8 100644 --- a/server/service/service_osquery_test.go +++ b/server/service/service_osquery_test.go @@ -481,114 +481,6 @@ func TestLabelQueries(t *testing.T) { assert.Zero(t, acc) } -func TestGetClientConfig(t *testing.T) { - ds := new(mock.Store) - ds.ListPacksForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Pack, error) { - return []*fleet.Pack{}, nil - } - ds.ListScheduledQueriesInPackFunc = func(ctx context.Context, pid uint) ([]*fleet.ScheduledQuery, error) { - tru := true - fals := false - fortytwo := uint(42) - switch pid { - case 1: - return []*fleet.ScheduledQuery{ - {Name: "time", Query: "select * from time", Interval: 30, Removed: &fals}, - }, nil - case 4: - return []*fleet.ScheduledQuery{ - {Name: "foobar", Query: "select 3", Interval: 20, Shard: &fortytwo}, - {Name: "froobing", Query: "select 'guacamole'", Interval: 60, Snapshot: &tru}, - }, nil - default: - return []*fleet.ScheduledQuery{}, nil - } - } - ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return &fleet.AppConfig{AgentOptions: ptr.RawMessage(json.RawMessage(`{"config":{"options":{"baz":"bar"}}}`))}, nil - } - ds.UpdateHostFunc = func(ctx context.Context, host *fleet.Host) error { - return nil - } - ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { - if id != 1 && id != 2 { - return nil, errors.New("not found") - } - return &fleet.Host{ID: id}, nil - } - - svc := newTestService(ds, nil, nil) - - ctx1 := hostctx.NewContext(context.Background(), &fleet.Host{ID: 1}) - ctx2 := hostctx.NewContext(context.Background(), &fleet.Host{ID: 2}) - - expectedOptions := map[string]interface{}{ - "baz": "bar", - } - - expectedConfig := map[string]interface{}{ - "options": expectedOptions, - } - - // No packs loaded yet - conf, err := svc.GetClientConfig(ctx1) - require.NoError(t, err) - assert.Equal(t, expectedConfig, conf) - - conf, err = svc.GetClientConfig(ctx2) - require.NoError(t, err) - assert.Equal(t, expectedConfig, conf) - - // Now add packs - ds.ListPacksForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Pack, error) { - switch hid { - case 1: - return []*fleet.Pack{ - {ID: 1, Name: "pack_by_label"}, - {ID: 4, Name: "pack_by_other_label"}, - }, nil - - case 2: - return []*fleet.Pack{ - {ID: 1, Name: "pack_by_label"}, - }, nil - } - return []*fleet.Pack{}, nil - } - - conf, err = svc.GetClientConfig(ctx1) - require.NoError(t, err) - assert.Equal(t, expectedOptions, conf["options"]) - assert.JSONEq(t, `{ - "pack_by_other_label": { - "queries": { - "foobar":{"query":"select 3","interval":20,"shard":42}, - "froobing":{"query":"select 'guacamole'","interval":60,"snapshot":true} - } - }, - "pack_by_label": { - "queries":{ - "time":{"query":"select * from time","interval":30,"removed":false} - } - } - }`, - string(conf["packs"].(json.RawMessage)), - ) - - conf, err = svc.GetClientConfig(ctx2) - require.NoError(t, err) - assert.Equal(t, expectedOptions, conf["options"]) - assert.JSONEq(t, `{ - "pack_by_label": { - "queries":{ - "time":{"query":"select * from time","interval":30,"removed":false} - } - } - }`, - string(conf["packs"].(json.RawMessage)), - ) -} - func TestDetailQueriesWithEmptyStrings(t *testing.T) { ds := new(mock.Store) mockClock := clock.NewMockClock() diff --git a/server/service/transport_carves.go b/server/service/transport_carves.go index 1629fd73a4..aa9e46f7c0 100644 --- a/server/service/transport_carves.go +++ b/server/service/transport_carves.go @@ -8,17 +8,6 @@ import ( "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" ) -func decodeCarveBeginRequest(ctx context.Context, r *http.Request) (interface{}, error) { - defer r.Body.Close() - - var req carveBeginRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, ctxerr.Wrap(ctx, err, "decoding JSON") - } - - return req, nil -} - func decodeCarveBlockRequest(ctx context.Context, r *http.Request) (interface{}, error) { defer r.Body.Close() diff --git a/server/service/transport_error.go b/server/service/transport_error.go index 666048148d..bbf406caad 100644 --- a/server/service/transport_error.go +++ b/server/service/transport_error.go @@ -119,6 +119,10 @@ func encodeError(ctx context.Context, err error, w http.ResponseWriter) { w.WriteHeader(http.StatusUnauthorized) errMap["node_invalid"] = true } else { + // TODO: osqueryError is not always the result of an internal error on + // our side, it is also used to represent a client error (invalid data, + // e.g. malformed json, carve too large, etc., so 4xx), are we returning + // a 500 because of some osquery-specific requirement? w.WriteHeader(http.StatusInternalServerError) } diff --git a/server/service/transport_osquery.go b/server/service/transport_osquery.go index c5958e1887..3ee0aba0c8 100644 --- a/server/service/transport_osquery.go +++ b/server/service/transport_osquery.go @@ -1,14 +1,9 @@ package service import ( - "compress/gzip" "context" "encoding/json" "net/http" - "strconv" - - "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" - "github.com/fleetdm/fleet/v4/server/fleet" ) func decodeEnrollAgentRequest(ctx context.Context, r *http.Request) (interface{}, error) { @@ -20,107 +15,3 @@ func decodeEnrollAgentRequest(ctx context.Context, r *http.Request) (interface{} return req, nil } - -func decodeGetClientConfigRequest(ctx context.Context, r *http.Request) (interface{}, error) { - var req getClientConfigRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, err - } - defer r.Body.Close() - - return req, nil -} - -func decodeGetDistributedQueriesRequest(ctx context.Context, r *http.Request) (interface{}, error) { - var req getDistributedQueriesRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, err - } - defer r.Body.Close() - - return req, nil -} - -func decodeSubmitDistributedQueryResultsRequest(ctx context.Context, r *http.Request) (interface{}, error) { - // When a distributed query has no results, the JSON schema is - // inconsistent, so we use this shim and massage into a consistent - // schema. For example (simplified from actual osqueryd 1.8.2 output): - // { - // "queries": { - // "query_with_no_results": "", // <- Note string instead of array - // "query_with_results": [{"foo":"bar","baz":"bang"}] - // }, - // "node_key":"IGXCXknWQ1baTa8TZ6rF3kAPZ4\/aTsui" - // } - - type distributedQueryResultsShim struct { - NodeKey string `json:"node_key"` - Results map[string]json.RawMessage `json:"queries"` - Statuses map[string]interface{} `json:"statuses"` - Messages map[string]string `json:"messages"` - } - - var shim distributedQueryResultsShim - if err := json.NewDecoder(r.Body).Decode(&shim); err != nil { - return nil, err - } - defer r.Body.Close() - - results := fleet.OsqueryDistributedQueryResults{} - for query, raw := range shim.Results { - queryResults := []map[string]string{} - // No need to handle error because the empty array is what we - // want if there was an error parsing the JSON (the error - // indicates that osquery sent us incosistently schemaed JSON) - _ = json.Unmarshal(raw, &queryResults) - results[query] = queryResults - } - - // Statuses were represented by strings in osquery < 3.0 and now - // integers in osquery > 3.0. Massage to string for compatibility with - // the service definition. - statuses := map[string]fleet.OsqueryStatus{} - for query, status := range shim.Statuses { - switch s := status.(type) { - case string: - sint, err := strconv.Atoi(s) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "parse status to int") - } - statuses[query] = fleet.OsqueryStatus(sint) - case float64: - statuses[query] = fleet.OsqueryStatus(s) - default: - return nil, ctxerr.Errorf(ctx, "query status should be string or number, got %T", s) - } - } - - req := SubmitDistributedQueryResultsRequest{ - NodeKey: shim.NodeKey, - Results: results, - Statuses: statuses, - Messages: shim.Messages, - } - - return req, nil -} - -func decodeSubmitLogsRequest(ctx context.Context, r *http.Request) (interface{}, error) { - var err error - body := r.Body - if r.Header.Get("content-encoding") == "gzip" { - body, err = gzip.NewReader(body) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "decoding gzip") - } - defer body.Close() - } - - var req submitLogsRequest - if err = json.NewDecoder(body).Decode(&req); err != nil { - return nil, ctxerr.Wrap(ctx, err, "decoding JSON") - } - defer r.Body.Close() - - return req, nil -} diff --git a/server/service/transport_osquery_test.go b/server/service/transport_osquery_test.go index 108ee9202b..68503ca205 100644 --- a/server/service/transport_osquery_test.go +++ b/server/service/transport_osquery_test.go @@ -2,13 +2,11 @@ package service import ( "bytes" - "compress/gzip" "context" "net/http" "net/http/httptest" "testing" - "github.com/fleetdm/fleet/v4/server/fleet" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -36,146 +34,3 @@ func TestDecodeEnrollAgentRequest(t *testing.T) { httptest.NewRequest("POST", "/", &body), ) } - -func TestDecodeGetClientConfigRequest(t *testing.T) { - router := mux.NewRouter() - router.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { - r, err := decodeGetClientConfigRequest(context.Background(), request) - require.Nil(t, err) - - params := r.(getClientConfigRequest) - assert.Equal(t, "key", params.NodeKey) - }).Methods("POST") - - var body bytes.Buffer - body.Write([]byte(`{ - "node_key": "key" - }`)) - - router.ServeHTTP( - httptest.NewRecorder(), - httptest.NewRequest("POST", "/", &body), - ) -} - -func TestDecodeGetDistributedQueriesRequest(t *testing.T) { - router := mux.NewRouter() - router.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { - r, err := decodeGetDistributedQueriesRequest(context.Background(), request) - require.Nil(t, err) - - params := r.(getDistributedQueriesRequest) - assert.Equal(t, "key", params.NodeKey) - }).Methods("POST") - - var body bytes.Buffer - body.Write([]byte(`{ - "node_key": "key" - }`)) - - router.ServeHTTP( - httptest.NewRecorder(), - httptest.NewRequest("POST", "/", &body), - ) -} - -func TestDecodeSubmitDistributedQueryResultsRequest(t *testing.T) { - router := mux.NewRouter() - router.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { - r, err := decodeSubmitDistributedQueryResultsRequest(context.Background(), request) - require.Nil(t, err) - - params := r.(SubmitDistributedQueryResultsRequest) - assert.Equal(t, "key", params.NodeKey) - assert.Equal(t, fleet.OsqueryDistributedQueryResults{ - "id1": { - {"col1": "val1", "col2": "val2"}, - {"col1": "val3", "col2": "val4"}, - }, - "id2": { - {"col3": "val5", "col4": "val6"}, - }, - "id3": {}, - }, params.Results) - assert.Equal(t, map[string]fleet.OsqueryStatus{"id1": 0, "id3": 1}, params.Statuses) - }).Methods("POST") - - // Note we explicitly test the case that requires using the shim - // because of the inconsistent JSON schema - var body bytes.Buffer - body.Write([]byte(`{ - "node_key": "key", - "queries": { - "id1": [ - {"col1": "val1", "col2": "val2"}, - {"col1": "val3", "col2": "val4"} - ], - "id2": [ - {"col3": "val5", "col4": "val6"} - ], - "id3": "" - }, - "statuses": {"id1": 0, "id3": "1"} - }`)) - - router.ServeHTTP( - httptest.NewRecorder(), - httptest.NewRequest("POST", "/", &body), - ) -} - -func TestDecodeSubmitLogsRequest(t *testing.T) { - router := mux.NewRouter() - router.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { - r, err := decodeSubmitLogsRequest(context.Background(), request) - require.Nil(t, err) - - params := r.(submitLogsRequest) - assert.Equal(t, "xOCmmaTJJvGRi8prh4kdjkFMyh7K1bXb", params.NodeKey) - assert.Equal(t, "status", params.LogType) - }).Methods("POST") - - bodyJSON := []byte(` - { - "node_key":"xOCmmaTJJvGRi8prh4kdjkFMyh7K1bXb", - "log_type":"status", - "data":[ - { - "severity":"0", - "filename":"tls.cpp", - "line":"205", - "message":"TLS\/HTTPS POST request to URI: https:\/\/dockerhost:8080\/api\/v1\/osquery\/log", - "version":"2.3.2", - "decorations":{ - "host_uuid":"EB714C9D-C1F8-A436-B6DA-3F853C5502EA", - "hostname":"9bed9dc098d9" - } - } - ] - } -`) - - body := new(bytes.Buffer) - _, err := body.Write(bodyJSON) - require.Nil(t, err) - - router.ServeHTTP( - httptest.NewRecorder(), - httptest.NewRequest("POST", "/", body), - ) - - // Now try gzipped - body.Reset() - gzWriter := gzip.NewWriter(body) - _, err = gzWriter.Write(bodyJSON) - require.Nil(t, err) - require.Nil(t, gzWriter.Close()) - - req := httptest.NewRequest("POST", "/", body) - req.Header.Add("Content-Encoding", "gzip") - - router.ServeHTTP( - httptest.NewRecorder(), - req, - ) -}