diff --git a/changes/issue-1498-team-schedules b/changes/issue-1498-team-schedules new file mode 100644 index 0000000000..5e28f3b8a1 --- /dev/null +++ b/changes/issue-1498-team-schedules @@ -0,0 +1 @@ +* Add team level schedules diff --git a/server/datastore/mysql/packs.go b/server/datastore/mysql/packs.go index 6a9825a628..9b1d11c610 100644 --- a/server/datastore/mysql/packs.go +++ b/server/datastore/mysql/packs.go @@ -404,7 +404,7 @@ func (d *Datastore) EnsureGlobalPack() (*fleet.Pack, error) { func (d *Datastore) insertNewGlobalPack() (*fleet.Pack, error) { var packID uint - d.withTx(func(tx *sqlx.Tx) error { + err := d.withTx(func(tx *sqlx.Tx) error { res, err := tx.Exec( `INSERT INTO packs (name, description, platform, pack_type) VALUES ('Global', 'Global pack', '','global')`, ) @@ -424,6 +424,63 @@ func (d *Datastore) insertNewGlobalPack() (*fleet.Pack, error) { } return nil }) + if err != nil { + return nil, err + } + + return d.Pack(packID) +} + +func (d *Datastore) EnsureTeamPack(teamID uint) (*fleet.Pack, error) { + pack := &fleet.Pack{} + t, err := d.Team(teamID) + if err != nil || t == nil { + return nil, errors.Wrap(err, "Error finding team") + } + + teamType := fmt.Sprintf("team-%d", teamID) + err = d.db.Get(pack, `SELECT * FROM packs WHERE pack_type = ?`, teamType) + if err == sql.ErrNoRows { + return d.insertNewTeamPack(teamID) + } else if err != nil { + return nil, errors.Wrap(err, "get pack") + } + + if err := d.loadPackTargets(pack); err != nil { + return nil, err + } + + return pack, nil +} + +func (d *Datastore) insertNewTeamPack(teamID uint) (*fleet.Pack, error) { + var packID uint + teamType := fmt.Sprintf("team-%d", teamID) + err := d.withTx(func(tx *sqlx.Tx) error { + res, err := tx.Exec( + `INSERT INTO packs (name, description, platform, pack_type) + VALUES (?, 'Schedule additional queries for all hosts assigned to this team.', '',?)`, + teamType, teamType, + ) + if err != nil { + return err + } + packId, err := res.LastInsertId() + if err != nil { + return err + } + packID = uint(packId) + if _, err := tx.Exec( + `INSERT INTO pack_targets (pack_id, type, target_id) VALUES (?, ?, ?)`, + packID, fleet.TargetTeam, teamID, + ); err != nil { + return errors.Wrap(err, "adding team id target to pack") + } + return nil + }) + if err != nil { + return nil, err + } return d.Pack(packID) } diff --git a/server/datastore/mysql/packs_test.go b/server/datastore/mysql/packs_test.go index 7f9d91c31b..806b886537 100644 --- a/server/datastore/mysql/packs_test.go +++ b/server/datastore/mysql/packs_test.go @@ -1,6 +1,7 @@ package mysql import ( + "fmt" "testing" "github.com/WatchBeam/clock" @@ -444,3 +445,51 @@ func TestEnsureGlobalPack(t *testing.T) { assert.Equal(t, gp.ID, packs[0].ID) assert.Equal(t, "global", *gp.Type) } + +func TestEnsureTeamPack(t *testing.T) { + ds := CreateMySQLDS(t) + defer ds.Close() + + packs, err := ds.ListPacks(fleet.ListOptions{}) + require.Nil(t, err) + assert.Len(t, packs, 0) + + _, err = ds.EnsureTeamPack(12) + require.Error(t, err) + + team1, err := ds.NewTeam(&fleet.Team{Name: "team1"}) + require.NoError(t, err) + + tp, err := ds.EnsureTeamPack(team1.ID) + require.NoError(t, err) + + packs, err = ds.ListPacks(fleet.ListOptions{}) + require.Nil(t, err) + assert.Len(t, packs, 1) + assert.Equal(t, tp.ID, packs[0].ID) + assert.Equal(t, fmt.Sprintf("team-%d", team1.ID), *tp.Type) + assert.Equal(t, []uint{team1.ID}, tp.TeamIDs) + + _, err = ds.EnsureTeamPack(team1.ID) + require.NoError(t, err) + + packs, err = ds.ListPacks(fleet.ListOptions{}) + require.Nil(t, err) + assert.Len(t, packs, 1) + assert.Equal(t, tp.ID, packs[0].ID) + + team2, err := ds.NewTeam(&fleet.Team{Name: "team2"}) + require.NoError(t, err) + + tp2, err := ds.EnsureTeamPack(team2.ID) + require.NoError(t, err) + + packs, err = ds.ListPacks(fleet.ListOptions{}) + require.Nil(t, err) + assert.Len(t, packs, 2) + assert.Equal(t, tp.ID, packs[0].ID) + assert.Equal(t, tp2.ID, packs[1].ID) + + assert.Equal(t, fmt.Sprintf("team-%d", team2.ID), *tp2.Type) + assert.Equal(t, []uint{team2.ID}, tp2.TeamIDs) +} diff --git a/server/datastore/mysql/scheduled_queries.go b/server/datastore/mysql/scheduled_queries.go index cdaa645ff0..05c5d356ce 100644 --- a/server/datastore/mysql/scheduled_queries.go +++ b/server/datastore/mysql/scheduled_queries.go @@ -57,6 +57,7 @@ func (d *Datastore) insertScheduledQuery(tx *sqlx.Tx, sq *fleet.ScheduledQuery) query := ` INSERT INTO scheduled_queries ( query_name, + query_id, name, pack_id, snapshot, @@ -67,11 +68,11 @@ func (d *Datastore) insertScheduledQuery(tx *sqlx.Tx, sq *fleet.ScheduledQuery) shard, denylist ) - SELECT name, ?, ?, ?, ?, ?, ?, ?, ?, ? + SELECT name, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? FROM queries WHERE id = ? ` - result, err := execFunc(query, sq.Name, sq.PackID, sq.Snapshot, sq.Removed, sq.Interval, sq.Platform, sq.Version, sq.Shard, sq.Denylist, sq.QueryID) + result, err := execFunc(query, sq.QueryID, sq.Name, sq.PackID, sq.Snapshot, sq.Removed, sq.Interval, sq.Platform, sq.Version, sq.Shard, sq.Denylist, sq.QueryID) if err != nil { return nil, errors.Wrap(err, "insert scheduled query") } diff --git a/server/fleet/packs.go b/server/fleet/packs.go index 5b9db69642..f3c6f11b84 100644 --- a/server/fleet/packs.go +++ b/server/fleet/packs.go @@ -38,6 +38,9 @@ type PackStore interface { // EnsureGlobalPack gets or inserts a pack with type global EnsureGlobalPack() (*Pack, error) + + // EnsureTeamPack gets or inserts a pack with type global + EnsureTeamPack(teamID uint) (*Pack, error) } // PackService is the service interface for managing query packs. diff --git a/server/fleet/service.go b/server/fleet/service.go index d9ebfc71c8..6bcbe2e7dd 100644 --- a/server/fleet/service.go +++ b/server/fleet/service.go @@ -22,4 +22,5 @@ type Service interface { UserRolesService GlobalScheduleService TranslatorService + TeamScheduleService } diff --git a/server/fleet/team_schedule.go b/server/fleet/team_schedule.go new file mode 100644 index 0000000000..62ea20a2ea --- /dev/null +++ b/server/fleet/team_schedule.go @@ -0,0 +1,10 @@ +package fleet + +import "context" + +type TeamScheduleService interface { + TeamScheduleQuery(ctx context.Context, teamID uint, sq *ScheduledQuery) (*ScheduledQuery, error) + GetTeamScheduledQueries(ctx context.Context, teamID uint, opts ListOptions) ([]*ScheduledQuery, error) + ModifyTeamScheduledQueries(ctx context.Context, teamID uint, scheduledQueryID uint, q ScheduledQueryPayload) (*ScheduledQuery, error) + DeleteTeamScheduledQueries(ctx context.Context, teamID uint, id uint) error +} diff --git a/server/mock/datastore_packs.go b/server/mock/datastore_packs.go index 0464a30e64..2e75586b44 100644 --- a/server/mock/datastore_packs.go +++ b/server/mock/datastore_packs.go @@ -42,6 +42,8 @@ type ListExplicitHostsInPackFunc func(pid uint, opt fleet.ListOptions) ([]uint, type EnsureGlobalPackFunc func() (*fleet.Pack, error) +type EnsureTeamPackFunc func(teamID uint) (*fleet.Pack, error) + type PackStore struct { ApplyPackSpecsFunc ApplyPackSpecsFunc ApplyPackSpecsFuncInvoked bool @@ -96,10 +98,19 @@ type PackStore struct { EnsureGlobalPackFunc EnsureGlobalPackFunc EnsureGlobalPackFuncInvoked bool + + EnsureTeamPackFunc EnsureTeamPackFunc + EnsureTeamPackFuncInvoked bool } func (s *PackStore) EnsureGlobalPack() (*fleet.Pack, error) { - panic("implement me") + s.EnsureGlobalPackFuncInvoked = true + return s.EnsureGlobalPackFunc() +} + +func (s *PackStore) EnsureTeamPack(teamID uint) (*fleet.Pack, error) { + s.EnsureTeamPackFuncInvoked = true + return s.EnsureTeamPackFunc(teamID) } func (s *PackStore) ApplyPackSpecs(specs []*fleet.PackSpec) error { diff --git a/server/service/endpoint_utils.go b/server/service/endpoint_utils.go index 607d33e591..0730893fbb 100644 --- a/server/service/endpoint_utils.go +++ b/server/service/endpoint_utils.go @@ -8,6 +8,7 @@ import ( "github.com/fleetdm/fleet/v4/server/fleet" "github.com/go-kit/kit/endpoint" + "github.com/pkg/errors" ) type handlerFunc func(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) @@ -23,6 +24,122 @@ func makeDecoderForType(v interface{}) func(ctx context.Context, r *http.Request } } +func makeDecoderForIDs(v interface{}, idKeys ...string) func(ctx context.Context, r *http.Request) (interface{}, error) { + t := reflect.TypeOf(v) + return func(ctx context.Context, r *http.Request) (interface{}, error) { + value := reflect.New(t) + for _, idKey := range idKeys { + err := setIDFromKey(r, t, value, idKey) + if err != nil { + return nil, err + } + } + + return value.Interface(), nil + } +} + +func makeDecoderForTypeAndIDs(v interface{}, idKeys ...string) func(ctx context.Context, r *http.Request) (interface{}, error) { + t := reflect.TypeOf(v) + return func(ctx context.Context, r *http.Request) (interface{}, error) { + req, err := makeDecoderForType(v)(ctx, r) + if err != nil { + return nil, err + } + + value := reflect.ValueOf(req) + for _, idKey := range idKeys { + err := setIDFromKey(r, t, value, idKey) + if err != nil { + return nil, err + } + } + + return req, nil + } +} + +func setIDFromKey(r *http.Request, t reflect.Type, v reflect.Value, idKey string) error { + id, err := idFromRequest(r, idKey) + if err != nil { + return err + } + name := "" + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if f.Tag.Get("url") == idKey { + name = f.Name + } + } + if name == "" { + return errors.Errorf("%s not found in URL", idKey) + } + + field := v.Elem().FieldByName(name) + field.SetUint(uint64(id)) + + return nil +} + +func makeDecoderForOptionsAndIDs(v interface{}, idKeys ...string) func(ctx context.Context, r *http.Request) (interface{}, error) { + t := reflect.TypeOf(v) + return func(ctx context.Context, r *http.Request) (interface{}, error) { + req, err := makeDecoderForIDs(v, idKeys...)(ctx, r) + if err != nil { + return nil, err + } + + value := reflect.ValueOf(req) + err = setListOptions(r, t, value) + if err != nil { + return nil, err + } + + return req, nil + } +} + +func makeDecoderForTypeOptionsAndIDs(v interface{}, idKeys ...string) func(ctx context.Context, r *http.Request) (interface{}, error) { + t := reflect.TypeOf(v) + return func(ctx context.Context, r *http.Request) (interface{}, error) { + req, err := makeDecoderForTypeAndIDs(v, idKeys...)(ctx, r) + if err != nil { + return nil, err + } + + value := reflect.ValueOf(req) + err = setListOptions(r, t, value) + if err != nil { + return nil, err + } + + return req, nil + } +} + +func setListOptions(r *http.Request, t reflect.Type, v reflect.Value) error { + opt, err := listOptionsFromRequest(r) + if err != nil { + return err + } + name := "" + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if f.Tag.Get("url") == "list_options" { + name = f.Name + } + } + // ListOptions are optional + if name == "" { + return nil + } + + field := v.Elem().FieldByName(name) + field.Set(reflect.ValueOf(opt)) + + return nil +} + func makeAuthenticatedServiceEndpoint(svc fleet.Service, f handlerFunc) endpoint.Endpoint { return authenticatedUser(svc, makeServiceEndpoint(svc, f)) } diff --git a/server/service/handler.go b/server/service/handler.go index bd723639e3..3ea1ed76d0 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -682,6 +682,11 @@ func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, opts []kitht handle("POST", "/api/v1/fleet/users/roles/spec", makeApplyUserRoleSpecsEndpoint(svc, opts), "apply_user_roles_spec", r) handle("POST", "/api/v1/fleet/translate", makeTranslatorEndpoint(svc, opts), "translator", r) handle("POST", "/api/v1/fleet/spec/teams", makeApplyTeamSpecsEndpoint(svc, opts), "apply_team_specs", r) + + handle("GET", "/api/v1/fleet/team/{team_id}/schedule", makeGetTeamScheduleEndpoint(svc, opts), "get_team_schedule", r) + handle("POST", "/api/v1/fleet/team/{team_id}/schedule", makeTeamScheduleQueryEndpoint(svc, opts), "add_to_team_schedule", r) + handle("PATCH", "/api/v1/fleet/team/{team_id}/schedule/{scheduled_query_id}", makeModifyTeamScheduleEndpoint(svc, opts), "edit_team_schedule", r) + handle("DELETE", "/api/v1/fleet/team/{team_id}/schedule/{scheduled_query_id}", makeDeleteTeamScheduleEndpoint(svc, opts), "delete_team_schedule", r) } func handle(verb, path string, handler http.Handler, name string, r *mux.Router) { diff --git a/server/service/integration_test.go b/server/service/integration_test.go index 08344c7d9a..143a61bf9c 100644 --- a/server/service/integration_test.go +++ b/server/service/integration_test.go @@ -157,6 +157,25 @@ func doReq( return resp } +func doRawReq( + t *testing.T, + body []byte, + method string, + server *httptest.Server, + path string, + token string, + expectedStatusCode int, +) *http.Response { + requestBody := &nopCloser{bytes.NewBuffer(body)} + req, _ := http.NewRequest(method, server.URL+path, requestBody) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token)) + client := &http.Client{} + resp, err := client.Do(req) + require.Nil(t, err) + assert.Equal(t, expectedStatusCode, resp.StatusCode) + return resp +} + func doJSONReq( t *testing.T, params interface{}, @@ -338,11 +357,7 @@ func TestGlobalSchedule(t *testing.T) { require.NoError(t, err) gsParams := fleet.ScheduledQueryPayload{QueryID: ptr.Uint(qr.ID), Interval: ptr.Uint(42)} - type responseType struct { - Scheduled *fleet.ScheduledQuery `json:"scheduled,omitempty"` - Err error `json:"error,omitempty"` - } - r := responseType{} + r := globalScheduleQueryResponse{} doJSONReq(t, gsParams, "POST", server, "/api/v1/fleet/global/schedule", token, http.StatusOK, &r) require.Nil(t, r.Err) @@ -366,7 +381,7 @@ func TestGlobalSchedule(t *testing.T) { require.Len(t, gs.GlobalSchedule, 1) assert.Equal(t, uint(55), gs.GlobalSchedule[0].Interval) - r = responseType{} + r = globalScheduleQueryResponse{} doJSONReq( t, nil, "DELETE", server, fmt.Sprintf("/api/v1/fleet/global/schedule/%d", id), @@ -458,6 +473,73 @@ func TestTranslator(t *testing.T) { assert.Equal(t, users[payload.List[0].Payload.Identifier].ID, payload.List[0].Payload.ID) } +func TestTeamSchedule(t *testing.T) { + ds := mysql.CreateMySQLDS(t) + defer ds.Close() + + test.AddAllHostsLabel(t, ds) + + _, server := RunServerForTestsWithDS(t, ds) + token := getTestAdminToken(t, server) + + team1, err := ds.NewTeam(&fleet.Team{ + ID: 42, + Name: "team1", + Description: "desc team1", + }) + require.NoError(t, err) + + ts := getTeamScheduleResponse{} + doJSONReq(t, nil, "GET", server, fmt.Sprintf("/api/v1/fleet/team/%d/schedule", team1.ID), token, http.StatusOK, &ts) + assert.Len(t, ts.Scheduled, 0) + + qr, err := ds.NewQuery(&fleet.Query{Name: "TestQuery", Description: "Some description", Query: "select * from osquery;", ObserverCanRun: true}) + require.NoError(t, err) + + gsParams := teamScheduleQueryRequest{ScheduledQueryPayload: fleet.ScheduledQueryPayload{QueryID: &qr.ID, Interval: ptr.Uint(42)}} + r := teamScheduleQueryResponse{} + doJSONReq(t, gsParams, "POST", server, fmt.Sprintf("/api/v1/fleet/team/%d/schedule", team1.ID), token, http.StatusOK, &r) + require.Nil(t, r.Err) + + ts = getTeamScheduleResponse{} + doJSONReq(t, nil, "GET", server, fmt.Sprintf("/api/v1/fleet/team/%d/schedule", team1.ID), token, http.StatusOK, &ts) + assert.Len(t, ts.Scheduled, 1) + assert.Equal(t, uint(42), ts.Scheduled[0].Interval) + assert.Equal(t, "TestQuery", ts.Scheduled[0].Name) + assert.Equal(t, qr.ID, ts.Scheduled[0].QueryID) + id := ts.Scheduled[0].ID + + modifyResp := modifyTeamScheduleResponse{} + modifyParams := modifyTeamScheduleRequest{ScheduledQueryPayload: fleet.ScheduledQueryPayload{Interval: ptr.Uint(55)}} + doJSONReq( + t, modifyParams, "PATCH", server, + fmt.Sprintf("/api/v1/fleet/team/%d/schedule/%d", team1.ID, id), + token, http.StatusOK, &modifyResp, + ) + + // just to satisfy my paranoia, wanted to make sure the contents of the json would work + doRawReq(t, []byte(`{"interval": 77}`), "PATCH", server, + fmt.Sprintf("/api/v1/fleet/team/%d/schedule/%d", team1.ID, id), + token, http.StatusOK) + + ts = getTeamScheduleResponse{} + doJSONReq(t, nil, "GET", server, fmt.Sprintf("/api/v1/fleet/team/%d/schedule", team1.ID), token, http.StatusOK, &ts) + assert.Len(t, ts.Scheduled, 1) + assert.Equal(t, uint(77), ts.Scheduled[0].Interval) + + deleteResp := deleteTeamScheduleResponse{} + doJSONReq( + t, nil, "DELETE", server, + fmt.Sprintf("/api/v1/fleet/team/%d/schedule/%d", team1.ID, id), + token, http.StatusOK, &deleteResp, + ) + require.Nil(t, r.Err) + + ts = getTeamScheduleResponse{} + doJSONReq(t, nil, "GET", server, fmt.Sprintf("/api/v1/fleet/team/%d/schedule", team1.ID), token, http.StatusOK, &ts) + assert.Len(t, ts.Scheduled, 0) +} + func TestLogger(t *testing.T) { buf := new(bytes.Buffer) logger := log.NewJSONLogger(buf) diff --git a/server/service/team_schedule.go b/server/service/team_schedule.go new file mode 100644 index 0000000000..ee9e001d71 --- /dev/null +++ b/server/service/team_schedule.go @@ -0,0 +1,220 @@ +package service + +import ( + "context" + "net/http" + + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/ptr" + kithttp "github.com/go-kit/kit/transport/http" + "gopkg.in/guregu/null.v3" +) + +type getTeamScheduleRequest struct { + TeamID uint `url:"team_id"` + ListOptions fleet.ListOptions `url:"list_options"` +} + +type getTeamScheduleResponse struct { + Scheduled []scheduledQueryResponse `json:"scheduled"` + Err error `json:"error,omitempty"` +} + +func (r getTeamScheduleResponse) error() error { return r.Err } + +func makeGetTeamScheduleEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler { + return newServer( + makeAuthenticatedServiceEndpoint(svc, getTeamScheduleEndpoint), + makeDecoderForOptionsAndIDs(getTeamScheduleRequest{}, "team_id"), + opts, + ) +} + +func getTeamScheduleEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*getTeamScheduleRequest) + resp := getTeamScheduleResponse{Scheduled: []scheduledQueryResponse{}} + queries, err := svc.GetTeamScheduledQueries(ctx, req.TeamID, req.ListOptions) + if err != nil { + return getTeamScheduleResponse{Err: err}, nil + } + for _, q := range queries { + resp.Scheduled = append(resp.Scheduled, scheduledQueryResponse{ + ScheduledQuery: *q, + }) + } + return resp, nil +} + +func (svc Service) GetTeamScheduledQueries(ctx context.Context, teamID uint, opts fleet.ListOptions) ([]*fleet.ScheduledQuery, error) { + if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil { + return nil, err + } + + gp, err := svc.ds.EnsureTeamPack(teamID) + if err != nil { + return nil, err + } + + return svc.ds.ListScheduledQueriesInPack(gp.ID, opts) +} + +///////////////////////////////////////////////////////////////////////////////// +// Add +///////////////////////////////////////////////////////////////////////////////// + +type teamScheduleQueryRequest struct { + TeamID uint `url:"team_id"` + fleet.ScheduledQueryPayload +} + +type teamScheduleQueryResponse struct { + Scheduled *fleet.ScheduledQuery `json:"scheduled,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r teamScheduleQueryResponse) error() error { return r.Err } + +func makeTeamScheduleQueryEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler { + return newServer( + makeAuthenticatedServiceEndpoint(svc, teamScheduleQueryEndpoint), + makeDecoderForTypeAndIDs(teamScheduleQueryRequest{}, "team_id"), + opts, + ) +} + +func uintValueOrZero(v *uint) uint { + if v == nil { + return 0 + } + return *v +} + +func nullIntToPtrUint(v *null.Int) *uint { + if v == nil { + return nil + } + return ptr.Uint(uint(v.ValueOrZero())) +} + +func teamScheduleQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*teamScheduleQueryRequest) + resp, err := svc.TeamScheduleQuery(ctx, req.TeamID, &fleet.ScheduledQuery{ + QueryID: uintValueOrZero(req.QueryID), + Interval: uintValueOrZero(req.Interval), + Snapshot: req.Snapshot, + Removed: req.Removed, + Platform: req.Platform, + Version: req.Version, + Shard: nullIntToPtrUint(req.Shard), + }) + if err != nil { + return teamScheduleQueryResponse{Err: err}, nil + } + _ = resp + return teamScheduleQueryResponse{}, nil +} + +func (svc Service) TeamScheduleQuery(ctx context.Context, teamID uint, q *fleet.ScheduledQuery) (*fleet.ScheduledQuery, error) { + if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil { + return nil, err + } + + gp, err := svc.ds.EnsureTeamPack(teamID) + if err != nil { + return nil, err + } + q.PackID = gp.ID + + return svc.ScheduleQuery(ctx, q) +} + +///////////////////////////////////////////////////////////////////////////////// +// Modify +///////////////////////////////////////////////////////////////////////////////// + +type modifyTeamScheduleRequest struct { + TeamID uint `url:"team_id"` + ScheduledQueryID uint `url:"scheduled_query_id"` + fleet.ScheduledQueryPayload +} + +type modifyTeamScheduleResponse struct { + Scheduled *fleet.ScheduledQuery `json:"scheduled,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r modifyTeamScheduleResponse) error() error { return r.Err } + +func makeModifyTeamScheduleEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler { + return newServer( + makeAuthenticatedServiceEndpoint(svc, modifyTeamScheduleEndpoint), + makeDecoderForTypeAndIDs(modifyTeamScheduleRequest{}, "team_id", "scheduled_query_id"), + opts, + ) +} + +func modifyTeamScheduleEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*modifyTeamScheduleRequest) + resp, err := svc.ModifyTeamScheduledQueries(ctx, req.TeamID, req.ScheduledQueryID, req.ScheduledQueryPayload) + if err != nil { + return modifyTeamScheduleResponse{Err: err}, nil + } + _ = resp + return modifyTeamScheduleResponse{}, nil +} + +func (svc Service) ModifyTeamScheduledQueries(ctx context.Context, teamID uint, scheduledQueryID uint, query fleet.ScheduledQueryPayload) (*fleet.ScheduledQuery, error) { + if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { + return nil, err + } + + gp, err := svc.ds.EnsureTeamPack(teamID) + if err != nil { + return nil, err + } + + query.PackID = ptr.Uint(gp.ID) + + return svc.ModifyScheduledQuery(ctx, scheduledQueryID, query) +} + +///////////////////////////////////////////////////////////////////////////////// +// Delete +///////////////////////////////////////////////////////////////////////////////// + +type deleteTeamScheduleRequest struct { + TeamID uint `url:"team_id"` + ScheduledQueryID uint `url:"scheduled_query_id"` +} + +type deleteTeamScheduleResponse struct { + Scheduled *fleet.ScheduledQuery `json:"scheduled,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r deleteTeamScheduleResponse) error() error { return r.Err } + +func makeDeleteTeamScheduleEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler { + return newServer( + makeAuthenticatedServiceEndpoint(svc, deleteTeamScheduleEndpoint), + makeDecoderForIDs(deleteTeamScheduleRequest{}, "team_id", "scheduled_query_id"), + opts, + ) +} + +func deleteTeamScheduleEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*deleteTeamScheduleRequest) + err := svc.DeleteTeamScheduledQueries(ctx, req.TeamID, req.ScheduledQueryID) + if err != nil { + return deleteTeamScheduleResponse{Err: err}, nil + } + return deleteTeamScheduleResponse{}, nil +} + +func (svc Service) DeleteTeamScheduledQueries(ctx context.Context, teamID uint, scheduledQueryID uint) error { + if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { + return err + } + _ = teamID + return svc.DeleteScheduledQuery(ctx, scheduledQueryID) +}