diff --git a/server/service/endpoint_scheduled_queries.go b/server/service/endpoint_scheduled_queries.go deleted file mode 100644 index 82a18d6983..0000000000 --- a/server/service/endpoint_scheduled_queries.go +++ /dev/null @@ -1,185 +0,0 @@ -package service - -import ( - "context" - - "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/go-kit/kit/endpoint" -) - -//////////////////////////////////////////////////////////////////////////////// -// Get Scheduled Queries In Pack -//////////////////////////////////////////////////////////////////////////////// - -type getScheduledQueriesInPackRequest struct { - ID uint - ListOptions fleet.ListOptions -} - -type scheduledQueryResponse struct { - fleet.ScheduledQuery -} - -type getScheduledQueriesInPackResponse struct { - Scheduled []scheduledQueryResponse `json:"scheduled"` - Err error `json:"error,omitempty"` -} - -func (r getScheduledQueriesInPackResponse) error() error { return r.Err } - -func makeGetScheduledQueriesInPackEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(getScheduledQueriesInPackRequest) - resp := getScheduledQueriesInPackResponse{Scheduled: []scheduledQueryResponse{}} - - queries, err := svc.GetScheduledQueriesInPack(ctx, req.ID, req.ListOptions) - if err != nil { - return getScheduledQueriesInPackResponse{Err: err}, nil - } - - for _, q := range queries { - resp.Scheduled = append(resp.Scheduled, scheduledQueryResponse{ - ScheduledQuery: *q, - }) - } - - return resp, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Get Scheduled Query -//////////////////////////////////////////////////////////////////////////////// - -type getScheduledQueryRequest struct { - ID uint -} - -type getScheduledQueryResponse struct { - Scheduled *scheduledQueryResponse `json:"scheduled,omitempty"` - Err error `json:"error,omitempty"` -} - -func (r getScheduledQueryResponse) error() error { return r.Err } - -func makeGetScheduledQueryEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(getScheduledQueryRequest) - - sq, err := svc.GetScheduledQuery(ctx, req.ID) - if err != nil { - return getScheduledQueryResponse{Err: err}, nil - } - - return getScheduledQueryResponse{ - Scheduled: &scheduledQueryResponse{ - ScheduledQuery: *sq, - }, - }, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Schedule Query -//////////////////////////////////////////////////////////////////////////////// - -type scheduleQueryRequest struct { - PackID uint `json:"pack_id"` - QueryID uint `json:"query_id"` - Interval uint `json:"interval"` - Snapshot *bool `json:"snapshot"` - Removed *bool `json:"removed"` - Platform *string `json:"platform"` - Version *string `json:"version"` - Shard *uint `json:"shard"` -} - -type scheduleQueryResponse struct { - Scheduled *scheduledQueryResponse `json:"scheduled,omitempty"` - Err error `json:"error,omitempty"` -} - -func (r scheduleQueryResponse) error() error { return r.Err } - -func makeScheduleQueryEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(scheduleQueryRequest) - - scheduled, err := svc.ScheduleQuery(ctx, &fleet.ScheduledQuery{ - PackID: req.PackID, - QueryID: req.QueryID, - Interval: req.Interval, - Snapshot: req.Snapshot, - Removed: req.Removed, - Platform: req.Platform, - Version: req.Version, - Shard: req.Shard, - }) - if err != nil { - return scheduleQueryResponse{Err: err}, nil - } - return scheduleQueryResponse{Scheduled: &scheduledQueryResponse{ - ScheduledQuery: *scheduled, - }}, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Modify Scheduled Query -//////////////////////////////////////////////////////////////////////////////// - -type modifyScheduledQueryRequest struct { - ID uint - payload fleet.ScheduledQueryPayload -} - -type modifyScheduledQueryResponse struct { - Scheduled *scheduledQueryResponse `json:"scheduled,omitempty"` - Err error `json:"error,omitempty"` -} - -func (r modifyScheduledQueryResponse) error() error { return r.Err } - -func makeModifyScheduledQueryEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(modifyScheduledQueryRequest) - - sq, err := svc.ModifyScheduledQuery(ctx, req.ID, req.payload) - if err != nil { - return modifyScheduledQueryResponse{Err: err}, nil - } - - return modifyScheduledQueryResponse{ - Scheduled: &scheduledQueryResponse{ - ScheduledQuery: *sq, - }, - }, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Delete Scheduled Query -//////////////////////////////////////////////////////////////////////////////// - -type deleteScheduledQueryRequest struct { - ID uint -} - -type deleteScheduledQueryResponse struct { - Err error `json:"error,omitempty"` -} - -func (r deleteScheduledQueryResponse) error() error { return r.Err } - -func makeDeleteScheduledQueryEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(deleteScheduledQueryRequest) - - err := svc.DeleteScheduledQuery(ctx, req.ID) - if err != nil { - return deleteScheduledQueryResponse{Err: err}, nil - } - - return deleteScheduledQueryResponse{}, nil - } -} diff --git a/server/service/handler.go b/server/service/handler.go index 1149574dc6..d645a8e6ef 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -60,11 +60,6 @@ type FleetEndpoints struct { GetQuerySpec endpoint.Endpoint CreateDistributedQueryCampaign endpoint.Endpoint CreateDistributedQueryCampaignByNames endpoint.Endpoint - GetScheduledQueriesInPack endpoint.Endpoint - ScheduleQuery endpoint.Endpoint - GetScheduledQuery endpoint.Endpoint - ModifyScheduledQuery endpoint.Endpoint - DeleteScheduledQuery endpoint.Endpoint EnrollAgent endpoint.Endpoint GetClientConfig endpoint.Endpoint GetDistributedQueries endpoint.Endpoint @@ -159,11 +154,6 @@ func MakeFleetServerEndpoints(svc fleet.Service, urlPrefix string, limitStore th GetQuerySpec: authenticatedUser(svc, makeGetQuerySpecEndpoint(svc)), CreateDistributedQueryCampaign: authenticatedUser(svc, makeCreateDistributedQueryCampaignEndpoint(svc)), CreateDistributedQueryCampaignByNames: authenticatedUser(svc, makeCreateDistributedQueryCampaignByNamesEndpoint(svc)), - GetScheduledQueriesInPack: authenticatedUser(svc, makeGetScheduledQueriesInPackEndpoint(svc)), - ScheduleQuery: authenticatedUser(svc, makeScheduleQueryEndpoint(svc)), - GetScheduledQuery: authenticatedUser(svc, makeGetScheduledQueryEndpoint(svc)), - ModifyScheduledQuery: authenticatedUser(svc, makeModifyScheduledQueryEndpoint(svc)), - DeleteScheduledQuery: authenticatedUser(svc, makeDeleteScheduledQueryEndpoint(svc)), CreateLabel: authenticatedUser(svc, makeCreateLabelEndpoint(svc)), ModifyLabel: authenticatedUser(svc, makeModifyLabelEndpoint(svc)), GetLabel: authenticatedUser(svc, makeGetLabelEndpoint(svc)), @@ -246,11 +236,6 @@ type fleetHandlers struct { GetQuerySpec http.Handler CreateDistributedQueryCampaign http.Handler CreateDistributedQueryCampaignByNames http.Handler - GetScheduledQueriesInPack http.Handler - ScheduleQuery http.Handler - GetScheduledQuery http.Handler - ModifyScheduledQuery http.Handler - DeleteScheduledQuery http.Handler EnrollAgent http.Handler GetClientConfig http.Handler GetDistributedQueries http.Handler @@ -332,11 +317,6 @@ func makeKitHandlers(e FleetEndpoints, opts []kithttp.ServerOption) *fleetHandle GetQuerySpec: newServer(e.GetQuerySpec, decodeGetGenericSpecRequest), CreateDistributedQueryCampaign: newServer(e.CreateDistributedQueryCampaign, decodeCreateDistributedQueryCampaignRequest), CreateDistributedQueryCampaignByNames: newServer(e.CreateDistributedQueryCampaignByNames, decodeCreateDistributedQueryCampaignByNamesRequest), - GetScheduledQueriesInPack: newServer(e.GetScheduledQueriesInPack, decodeGetScheduledQueriesInPackRequest), - ScheduleQuery: newServer(e.ScheduleQuery, decodeScheduleQueryRequest), - GetScheduledQuery: newServer(e.GetScheduledQuery, decodeGetScheduledQueryRequest), - ModifyScheduledQuery: newServer(e.ModifyScheduledQuery, decodeModifyScheduledQueryRequest), - DeleteScheduledQuery: newServer(e.DeleteScheduledQuery, decodeDeleteScheduledQueryRequest), EnrollAgent: newServer(e.EnrollAgent, decodeEnrollAgentRequest), GetClientConfig: newServer(e.GetClientConfig, decodeGetClientConfigRequest), GetDistributedQueries: newServer(e.GetDistributedQueries, decodeGetDistributedQueriesRequest), @@ -519,12 +499,6 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) { r.Handle("/api/v1/fleet/queries/run", h.CreateDistributedQueryCampaign).Methods("POST").Name("create_distributed_query_campaign") r.Handle("/api/v1/fleet/queries/run_by_names", h.CreateDistributedQueryCampaignByNames).Methods("POST").Name("create_distributed_query_campaign_by_names") - r.Handle("/api/v1/fleet/packs/{id:[0-9]+}/scheduled", h.GetScheduledQueriesInPack).Methods("GET").Name("get_scheduled_queries_in_pack") - r.Handle("/api/v1/fleet/schedule", h.ScheduleQuery).Methods("POST").Name("schedule_query") - r.Handle("/api/v1/fleet/schedule/{id:[0-9]+}", h.GetScheduledQuery).Methods("GET").Name("get_scheduled_query") - r.Handle("/api/v1/fleet/schedule/{id:[0-9]+}", h.ModifyScheduledQuery).Methods("PATCH").Name("modify_scheduled_query") - r.Handle("/api/v1/fleet/schedule/{id:[0-9]+}", h.DeleteScheduledQuery).Methods("DELETE").Name("delete_scheduled_query") - r.Handle("/api/v1/fleet/labels", h.CreateLabel).Methods("POST").Name("create_label") r.Handle("/api/v1/fleet/labels/{id:[0-9]+}", h.ModifyLabel).Methods("PATCH").Name("modify_label") r.Handle("/api/v1/fleet/labels/{id:[0-9]+}", h.GetLabel).Methods("GET").Name("get_label") @@ -601,6 +575,12 @@ func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, opts []kitht e.POST("/api/v1/fleet/spec/policies", applyPolicySpecsEndpoint, applyPolicySpecsRequest{}) + e.GET("/api/v1/fleet/packs/{id:[0-9]+}/scheduled", getScheduledQueriesInPackEndpoint, getScheduledQueriesInPackRequest{}) + e.POST("/api/v1/fleet/schedule", scheduleQueryEndpoint, scheduleQueryRequest{}) + e.GET("/api/v1/fleet/schedule/{id:[0-9]+}", getScheduledQueryEndpoint, getScheduledQueryRequest{}) + e.PATCH("/api/v1/fleet/schedule/{id:[0-9]+}", modifyScheduledQueryEndpoint, modifyScheduledQueryRequest{}) + e.DELETE("/api/v1/fleet/schedule/{id:[0-9]+}", deleteScheduledQueryEndpoint, deleteScheduledQueryRequest{}) + e.GET("/api/v1/fleet/packs/{id:[0-9]+}", getPackEndpoint, getPackRequest{}) e.POST("/api/v1/fleet/packs", createPackEndpoint, createPackRequest{}) e.PATCH("/api/v1/fleet/packs/{id:[0-9]+}", modifyPackEndpoint, modifyPackRequest{}) diff --git a/server/service/handler_test.go b/server/service/handler_test.go index 31b312bb12..70eed5829b 100644 --- a/server/service/handler_test.go +++ b/server/service/handler_test.go @@ -114,21 +114,6 @@ func TestAPIRoutes(t *testing.T) { uri: "/api/v1/fleet/queries/run", }, { - verb: "GET", - uri: "/api/v1/fleet/packs/1/scheduled", - }, - { - verb: "POST", - uri: "/api/v1/fleet/schedule", - }, - { - verb: "DELETE", - uri: "/api/v1/fleet/schedule/1", - }, - { - verb: "PATCH", - uri: "/api/v1/fleet/schedule/1", - }, { verb: "POST", uri: "/api/v1/osquery/enroll", }, diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index d028408ff3..98be569925 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -1437,3 +1437,106 @@ func (s *integrationTestSuite) TestHostsAddToTeam() { ids := []uint{listResp.Hosts[0].ID, listResp.Hosts[1].ID} require.ElementsMatch(t, ids, []uint{hosts[1].ID, hosts[2].ID}) } + +func (s *integrationTestSuite) TestScheduledQueries() { + t := s.T() + + // create a pack + var createPackResp createPackResponse + reqPack := &createPackRequest{ + PackPayload: fleet.PackPayload{ + Name: ptr.String(strings.ReplaceAll(t.Name(), "/", "_")), + }, + } + s.DoJSON("POST", "/api/v1/fleet/packs", reqPack, http.StatusOK, &createPackResp) + pack := createPackResp.Pack.Pack + + // create a query + var createQueryResp createQueryResponse + reqQuery := &fleet.QueryPayload{ + Name: ptr.String(t.Name()), + Query: ptr.String("select * from time;"), + } + s.DoJSON("POST", "/api/v1/fleet/queries", reqQuery, http.StatusOK, &createQueryResp) + query := createQueryResp.Query + + // list scheduled queries in pack, none yet + var getInPackResp getScheduledQueriesInPackResponse + s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/packs/%d/scheduled", pack.ID), nil, http.StatusOK, &getInPackResp) + assert.Len(t, getInPackResp.Scheduled, 0) + + // list scheduled queries in non-existing pack + s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/packs/%d/scheduled", pack.ID+1), nil, http.StatusOK, &getInPackResp) + assert.Len(t, getInPackResp.Scheduled, 0) + + // create scheduled query + var createResp scheduleQueryResponse + reqSQ := &scheduleQueryRequest{ + PackID: pack.ID, + QueryID: query.ID, + Interval: 1, + } + s.DoJSON("POST", "/api/v1/fleet/schedule", reqSQ, http.StatusOK, &createResp) + sq1 := createResp.Scheduled.ScheduledQuery + assert.NotZero(t, sq1.ID) + assert.Equal(t, uint(1), sq1.Interval) + + // create scheduled query with invalid pack + reqSQ = &scheduleQueryRequest{ + PackID: pack.ID + 1, + QueryID: query.ID, + Interval: 2, + } + s.DoJSON("POST", "/api/v1/fleet/schedule", reqSQ, http.StatusUnprocessableEntity, &createResp) + + // create scheduled query with invalid query + reqSQ = &scheduleQueryRequest{ + PackID: pack.ID, + QueryID: query.ID + 1, + Interval: 3, + } + s.DoJSON("POST", "/api/v1/fleet/schedule", reqSQ, http.StatusInternalServerError, &createResp) + + // list scheduled queries in pack + s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/packs/%d/scheduled", pack.ID), nil, http.StatusOK, &getInPackResp) + require.Len(t, getInPackResp.Scheduled, 1) + assert.Equal(t, sq1.ID, getInPackResp.Scheduled[0].ID) + + // list scheduled queries in pack, next page + s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/packs/%d/scheduled", pack.ID), nil, http.StatusOK, &getInPackResp, "page", "1", "per_page", "2") + require.Len(t, getInPackResp.Scheduled, 0) + + // get non-existing scheduled query + var getResp getScheduledQueryResponse + s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/schedule/%d", sq1.ID+1), nil, http.StatusNotFound, &getResp) + + // get existing scheduled query + s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/schedule/%d", sq1.ID), nil, http.StatusOK, &getResp) + assert.Equal(t, sq1.ID, getResp.Scheduled.ID) + assert.Equal(t, sq1.Interval, getResp.Scheduled.Interval) + + // modify scheduled query + var modResp modifyScheduledQueryResponse + reqMod := fleet.ScheduledQueryPayload{ + Interval: ptr.Uint(4), + } + s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/schedule/%d", sq1.ID), reqMod, http.StatusOK, &modResp) + assert.Equal(t, sq1.ID, modResp.Scheduled.ID) + assert.Equal(t, uint(4), modResp.Scheduled.Interval) + + // modify non-existing scheduled query + reqMod = fleet.ScheduledQueryPayload{ + Interval: ptr.Uint(5), + } + s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/schedule/%d", sq1.ID+1), reqMod, http.StatusNotFound, &modResp) + + // delete non-existing scheduled query + var delResp deleteScheduledQueryResponse + s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/schedule/%d", sq1.ID+1), nil, http.StatusNotFound, &delResp) + + // delete existing scheduled query + s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/schedule/%d", sq1.ID), nil, http.StatusOK, &delResp) + + // get the now-deleted scheduled query + s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/schedule/%d", sq1.ID), nil, http.StatusNotFound, &getResp) +} diff --git a/server/service/scheduled_queries.go b/server/service/scheduled_queries.go new file mode 100644 index 0000000000..bd0f64e519 --- /dev/null +++ b/server/service/scheduled_queries.go @@ -0,0 +1,304 @@ +package service + +import ( + "context" + + "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" + "github.com/fleetdm/fleet/v4/server/fleet" +) + +//////////////////////////////////////////////////////////////////////////////// +// Get Scheduled Queries In Pack +//////////////////////////////////////////////////////////////////////////////// + +type getScheduledQueriesInPackRequest struct { + ID uint `url:"id"` + // TODO(mna): was not set in the old pattern + ListOptions fleet.ListOptions `url:"list_options"` +} + +type scheduledQueryResponse struct { + fleet.ScheduledQuery +} + +type getScheduledQueriesInPackResponse struct { + Scheduled []scheduledQueryResponse `json:"scheduled"` + Err error `json:"error,omitempty"` +} + +func (r getScheduledQueriesInPackResponse) error() error { return r.Err } + +func getScheduledQueriesInPackEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*getScheduledQueriesInPackRequest) + resp := getScheduledQueriesInPackResponse{Scheduled: []scheduledQueryResponse{}} + + queries, err := svc.GetScheduledQueriesInPack(ctx, req.ID, req.ListOptions) + if err != nil { + return getScheduledQueriesInPackResponse{Err: err}, nil + } + + for _, q := range queries { + resp.Scheduled = append(resp.Scheduled, scheduledQueryResponse{ + ScheduledQuery: *q, + }) + } + + return resp, nil +} + +func (svc *Service) GetScheduledQueriesInPack(ctx context.Context, id uint, opts fleet.ListOptions) ([]*fleet.ScheduledQuery, error) { + // Scheduled queries are currently authorized the same as packs. + if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil { + return nil, err + } + + return svc.ds.ListScheduledQueriesInPack(ctx, id, opts) +} + +//////////////////////////////////////////////////////////////////////////////// +// Schedule Query +//////////////////////////////////////////////////////////////////////////////// + +type scheduleQueryRequest struct { + PackID uint `json:"pack_id"` + QueryID uint `json:"query_id"` + Interval uint `json:"interval"` + Snapshot *bool `json:"snapshot"` + Removed *bool `json:"removed"` + Platform *string `json:"platform"` + Version *string `json:"version"` + Shard *uint `json:"shard"` +} + +type scheduleQueryResponse struct { + Scheduled *scheduledQueryResponse `json:"scheduled,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r scheduleQueryResponse) error() error { return r.Err } + +func scheduleQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*scheduleQueryRequest) + + scheduled, err := svc.ScheduleQuery(ctx, &fleet.ScheduledQuery{ + PackID: req.PackID, + QueryID: req.QueryID, + Interval: req.Interval, + Snapshot: req.Snapshot, + Removed: req.Removed, + Platform: req.Platform, + Version: req.Version, + Shard: req.Shard, + }) + if err != nil { + return scheduleQueryResponse{Err: err}, nil + } + return scheduleQueryResponse{Scheduled: &scheduledQueryResponse{ + ScheduledQuery: *scheduled, + }}, nil +} + +func (svc *Service) ScheduleQuery(ctx context.Context, sq *fleet.ScheduledQuery) (*fleet.ScheduledQuery, error) { + // Scheduled queries are currently authorized the same as packs. + if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { + return nil, err + } + + return svc.unauthorizedScheduleQuery(ctx, sq) +} + +func (svc *Service) unauthorizedScheduleQuery(ctx context.Context, sq *fleet.ScheduledQuery) (*fleet.ScheduledQuery, error) { + // Fill in the name with query name if it is unset (because the UI + // doesn't provide a way to set it) + if sq.Name == "" { + query, err := svc.ds.Query(ctx, sq.QueryID) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "lookup name for query") + } + + packQueries, err := svc.ds.ListScheduledQueriesInPack(ctx, sq.PackID, fleet.ListOptions{}) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "find existing scheduled queries") + } + _ = packQueries + + sq.Name = findNextNameForQuery(query.Name, packQueries) + sq.QueryName = query.Name + } else if sq.QueryName == "" { + query, err := svc.ds.Query(ctx, sq.QueryID) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "lookup name for query") + } + sq.QueryName = query.Name + } + return svc.ds.NewScheduledQuery(ctx, sq) +} + +// Add "-1" suffixes to the query name until it is unique +func findNextNameForQuery(name string, scheduled []*fleet.ScheduledQuery) string { + for _, q := range scheduled { + if name == q.Name { + return findNextNameForQuery(name+"-1", scheduled) + } + } + return name +} + +//////////////////////////////////////////////////////////////////////////////// +// Get Scheduled Query +//////////////////////////////////////////////////////////////////////////////// + +type getScheduledQueryRequest struct { + ID uint `url:"id"` +} + +type getScheduledQueryResponse struct { + Scheduled *scheduledQueryResponse `json:"scheduled,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r getScheduledQueryResponse) error() error { return r.Err } + +func getScheduledQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*getScheduledQueryRequest) + + sq, err := svc.GetScheduledQuery(ctx, req.ID) + if err != nil { + return getScheduledQueryResponse{Err: err}, nil + } + + return getScheduledQueryResponse{ + Scheduled: &scheduledQueryResponse{ + ScheduledQuery: *sq, + }, + }, nil +} + +func (svc *Service) GetScheduledQuery(ctx context.Context, id uint) (*fleet.ScheduledQuery, error) { + // Scheduled queries are currently authorized the same as packs. + if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil { + return nil, err + } + + return svc.ds.ScheduledQuery(ctx, id) +} + +//////////////////////////////////////////////////////////////////////////////// +// Modify Scheduled Query +//////////////////////////////////////////////////////////////////////////////// + +type modifyScheduledQueryRequest struct { + ID uint `json:"-" url:"id"` + fleet.ScheduledQueryPayload +} + +type modifyScheduledQueryResponse struct { + Scheduled *scheduledQueryResponse `json:"scheduled,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r modifyScheduledQueryResponse) error() error { return r.Err } + +func modifyScheduledQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*modifyScheduledQueryRequest) + + sq, err := svc.ModifyScheduledQuery(ctx, req.ID, req.ScheduledQueryPayload) + if err != nil { + return modifyScheduledQueryResponse{Err: err}, nil + } + + return modifyScheduledQueryResponse{ + Scheduled: &scheduledQueryResponse{ + ScheduledQuery: *sq, + }, + }, nil +} + +func (svc *Service) ModifyScheduledQuery(ctx context.Context, id uint, p fleet.ScheduledQueryPayload) (*fleet.ScheduledQuery, error) { + // Scheduled queries are currently authorized the same as packs. + if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { + return nil, err + } + + return svc.unauthorizedModifyScheduledQuery(ctx, id, p) +} + +func (svc *Service) unauthorizedModifyScheduledQuery(ctx context.Context, id uint, p fleet.ScheduledQueryPayload) (*fleet.ScheduledQuery, error) { + sq, err := svc.ds.ScheduledQuery(ctx, id) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "getting scheduled query to modify") + } + + if p.PackID != nil { + sq.PackID = *p.PackID + } + + if p.QueryID != nil { + sq.QueryID = *p.QueryID + } + + if p.Interval != nil { + sq.Interval = *p.Interval + } + + if p.Snapshot != nil { + sq.Snapshot = p.Snapshot + } + + if p.Removed != nil { + sq.Removed = p.Removed + } + + if p.Platform != nil { + sq.Platform = p.Platform + } + + if p.Version != nil { + sq.Version = p.Version + } + + if p.Shard != nil { + if p.Shard.Valid { + val := uint(p.Shard.Int64) + sq.Shard = &val + } else { + sq.Shard = nil + } + } + + return svc.ds.SaveScheduledQuery(ctx, sq) +} + +//////////////////////////////////////////////////////////////////////////////// +// Delete Scheduled Query +//////////////////////////////////////////////////////////////////////////////// + +type deleteScheduledQueryRequest struct { + ID uint `url:"id"` +} + +type deleteScheduledQueryResponse struct { + Err error `json:"error,omitempty"` +} + +func (r deleteScheduledQueryResponse) error() error { return r.Err } + +func deleteScheduledQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*deleteScheduledQueryRequest) + + err := svc.DeleteScheduledQuery(ctx, req.ID) + if err != nil { + return deleteScheduledQueryResponse{Err: err}, nil + } + + return deleteScheduledQueryResponse{}, nil +} + +func (svc *Service) DeleteScheduledQuery(ctx context.Context, id uint) error { + // Scheduled queries are currently authorized the same as packs. + if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { + return err + } + + return svc.ds.DeleteScheduledQuery(ctx, id) +} diff --git a/server/service/service_scheduled_queries_test.go b/server/service/scheduled_queries_test.go similarity index 58% rename from server/service/service_scheduled_queries_test.go rename to server/service/scheduled_queries_test.go index 1f96c101f9..d22d7e31fb 100644 --- a/server/service/service_scheduled_queries_test.go +++ b/server/service/scheduled_queries_test.go @@ -4,13 +4,103 @@ import ( "context" "testing" + "github.com/fleetdm/fleet/v4/server/contexts/viewer" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/mock" + "github.com/fleetdm/fleet/v4/server/ptr" "github.com/fleetdm/fleet/v4/server/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestScheduledQueriesAuth(t *testing.T) { + ds := new(mock.Store) + svc := newTestService(ds, nil, nil) + + ds.ListScheduledQueriesInPackFunc = func(ctx context.Context, id uint, opts fleet.ListOptions) ([]*fleet.ScheduledQuery, error) { + return nil, nil + } + ds.NewScheduledQueryFunc = func(ctx context.Context, sq *fleet.ScheduledQuery, opts ...fleet.OptionalArg) (*fleet.ScheduledQuery, error) { + return sq, nil + } + ds.QueryFunc = func(ctx context.Context, id uint) (*fleet.Query, error) { + return &fleet.Query{}, nil + } + ds.ScheduledQueryFunc = func(ctx context.Context, id uint) (*fleet.ScheduledQuery, error) { + return &fleet.ScheduledQuery{}, nil + } + ds.SaveScheduledQueryFunc = func(ctx context.Context, sq *fleet.ScheduledQuery) (*fleet.ScheduledQuery, error) { + return sq, nil + } + ds.DeleteScheduledQueryFunc = func(ctx context.Context, id uint) error { + return nil + } + + var testCases = []struct { + name string + user *fleet.User + shouldFailWrite bool + shouldFailRead bool + }{ + { + "global admin", + &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}, + false, + false, + }, + { + "global maintainer", + &fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)}, + false, + false, + }, + { + "global observer", + &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)}, + true, + true, + }, + { + "team admin", + &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleAdmin}}}, + true, + false, + }, + { + "team maintainer", + &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}}, + true, + false, + }, + { + "team observer", + &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleObserver}}}, + true, + true, + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + ctx := viewer.NewContext(context.Background(), viewer.Viewer{User: tt.user}) + + _, err := svc.GetScheduledQueriesInPack(ctx, 1, fleet.ListOptions{}) + checkAuthErr(t, tt.shouldFailRead, err) + + _, err = svc.ScheduleQuery(ctx, &fleet.ScheduledQuery{}) + checkAuthErr(t, tt.shouldFailWrite, err) + + _, err = svc.GetScheduledQuery(ctx, 1) + checkAuthErr(t, tt.shouldFailRead, err) + + _, err = svc.ModifyScheduledQuery(ctx, 1, fleet.ScheduledQueryPayload{}) + checkAuthErr(t, tt.shouldFailWrite, err) + + err = svc.DeleteScheduledQuery(ctx, 1) + checkAuthErr(t, tt.shouldFailWrite, err) + }) + } +} + func TestScheduleQuery(t *testing.T) { ds := new(mock.Store) svc := newTestService(ds, nil, nil) diff --git a/server/service/service_scheduled_queries.go b/server/service/service_scheduled_queries.go deleted file mode 100644 index 1efe0d9b7b..0000000000 --- a/server/service/service_scheduled_queries.go +++ /dev/null @@ -1,133 +0,0 @@ -package service - -import ( - "context" - - "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" - "github.com/fleetdm/fleet/v4/server/fleet" -) - -// Scheduled queries are currently authorized the same as packs. - -func (svc *Service) GetScheduledQueriesInPack(ctx context.Context, id uint, opts fleet.ListOptions) ([]*fleet.ScheduledQuery, error) { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil { - return nil, err - } - - return svc.ds.ListScheduledQueriesInPack(ctx, id, opts) -} - -func (svc *Service) GetScheduledQuery(ctx context.Context, id uint) (*fleet.ScheduledQuery, error) { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil { - return nil, err - } - - return svc.ds.ScheduledQuery(ctx, id) -} - -func (svc *Service) ScheduleQuery(ctx context.Context, sq *fleet.ScheduledQuery) (*fleet.ScheduledQuery, error) { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { - return nil, err - } - - return svc.unauthorizedScheduleQuery(ctx, sq) -} - -func (svc *Service) unauthorizedScheduleQuery(ctx context.Context, sq *fleet.ScheduledQuery) (*fleet.ScheduledQuery, error) { - // Fill in the name with query name if it is unset (because the UI - // doesn't provide a way to set it) - if sq.Name == "" { - query, err := svc.ds.Query(ctx, sq.QueryID) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "lookup name for query") - } - - packQueries, err := svc.ds.ListScheduledQueriesInPack(ctx, sq.PackID, fleet.ListOptions{}) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "find existing scheduled queries") - } - _ = packQueries - - sq.Name = findNextNameForQuery(query.Name, packQueries) - sq.QueryName = query.Name - } else if sq.QueryName == "" { - query, err := svc.ds.Query(ctx, sq.QueryID) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "lookup name for query") - } - sq.QueryName = query.Name - } - return svc.ds.NewScheduledQuery(ctx, sq) -} - -// Add "-1" suffixes to the query name until it is unique -func findNextNameForQuery(name string, scheduled []*fleet.ScheduledQuery) string { - for _, q := range scheduled { - if name == q.Name { - return findNextNameForQuery(name+"-1", scheduled) - } - } - return name -} - -func (svc *Service) ModifyScheduledQuery(ctx context.Context, id uint, p fleet.ScheduledQueryPayload) (*fleet.ScheduledQuery, error) { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { - return nil, err - } - - return svc.unauthorizedModifyScheduledQuery(ctx, id, p) -} - -func (svc *Service) unauthorizedModifyScheduledQuery(ctx context.Context, id uint, p fleet.ScheduledQueryPayload) (*fleet.ScheduledQuery, error) { - sq, err := svc.ds.ScheduledQuery(ctx, id) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "getting scheduled query to modify") - } - - if p.PackID != nil { - sq.PackID = *p.PackID - } - - if p.QueryID != nil { - sq.QueryID = *p.QueryID - } - - if p.Interval != nil { - sq.Interval = *p.Interval - } - - if p.Snapshot != nil { - sq.Snapshot = p.Snapshot - } - - if p.Removed != nil { - sq.Removed = p.Removed - } - - if p.Platform != nil { - sq.Platform = p.Platform - } - - if p.Version != nil { - sq.Version = p.Version - } - - if p.Shard != nil { - if p.Shard.Valid { - val := uint(p.Shard.Int64) - sq.Shard = &val - } else { - sq.Shard = nil - } - } - - return svc.ds.SaveScheduledQuery(ctx, sq) -} - -func (svc *Service) DeleteScheduledQuery(ctx context.Context, id uint) error { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { - return err - } - - return svc.ds.DeleteScheduledQuery(ctx, id) -} diff --git a/server/service/transport_scheduled_queries.go b/server/service/transport_scheduled_queries.go deleted file mode 100644 index 6a7bbbda0d..0000000000 --- a/server/service/transport_scheduled_queries.go +++ /dev/null @@ -1,62 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - "net/http" -) - -func decodeGetScheduledQueriesInPackRequest(ctx context.Context, r *http.Request) (interface{}, error) { - id, err := uintFromRequest(r, "id") - if err != nil { - return nil, err - } - var req getScheduledQueriesInPackRequest - req.ID = uint(id) - return req, nil -} - -func decodeScheduleQueryRequest(ctx context.Context, r *http.Request) (interface{}, error) { - var req scheduleQueryRequest - - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, err - } - - return req, nil -} - -func decodeModifyScheduledQueryRequest(ctx context.Context, r *http.Request) (interface{}, error) { - id, err := uintFromRequest(r, "id") - if err != nil { - return nil, err - } - var req modifyScheduledQueryRequest - - if err := json.NewDecoder(r.Body).Decode(&req.payload); err != nil { - return nil, err - } - - req.ID = uint(id) - return req, nil -} - -func decodeDeleteScheduledQueryRequest(ctx context.Context, r *http.Request) (interface{}, error) { - id, err := uintFromRequest(r, "id") - if err != nil { - return nil, err - } - var req deleteScheduledQueryRequest - req.ID = uint(id) - return req, nil -} - -func decodeGetScheduledQueryRequest(ctx context.Context, r *http.Request) (interface{}, error) { - id, err := uintFromRequest(r, "id") - if err != nil { - return nil, err - } - var req getScheduledQueryRequest - req.ID = uint(id) - return req, nil -} diff --git a/server/service/transport_scheduled_queries_test.go b/server/service/transport_scheduled_queries_test.go index e48c4f647c..dcbcfd1723 100644 --- a/server/service/transport_scheduled_queries_test.go +++ b/server/service/transport_scheduled_queries_test.go @@ -1,17 +1,8 @@ package service -import ( - "bytes" - "context" - "net/http" - "net/http/httptest" - "testing" - - "github.com/gorilla/mux" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) +// TODO(mna): delete after covering those in integration tests +/* func TestDecodeGetScheduledQueriesInPackRequest(t *testing.T) { router := mux.NewRouter() router.HandleFunc("/api/v1/fleet/packs/{id}/scheduled", func(writer http.ResponseWriter, request *http.Request) { @@ -101,19 +92,4 @@ func TestDecodeDeleteScheduledQueryRequest(t *testing.T) { httptest.NewRequest("DELETE", "/api/v1/fleet/scheduled/1", nil), ) } - -func TestDecodeGetScheduledQueryRequest(t *testing.T) { - router := mux.NewRouter() - router.HandleFunc("/api/v1/fleet/scheduled/{id}", func(writer http.ResponseWriter, request *http.Request) { - r, err := decodeGetScheduledQueryRequest(context.Background(), request) - assert.Nil(t, err) - - params := r.(getScheduledQueryRequest) - assert.Equal(t, uint(1), params.ID) - }).Methods("GET") - - router.ServeHTTP( - httptest.NewRecorder(), - httptest.NewRequest("GET", "/api/v1/fleet/scheduled/1", nil), - ) -} +*/