From 73e1c801ee2fcdcea1a390e25ab71e1d21cb61f0 Mon Sep 17 00:00:00 2001 From: Martin Angers Date: Wed, 15 Dec 2021 09:35:40 -0500 Subject: [PATCH] Migrate packs endpoints to new pattern (#3244) --- server/service/endpoint_packs.go | 270 ------------- server/service/handler.go | 48 +-- server/service/handler_test.go | 16 - server/service/integration_core_test.go | 78 +++- server/service/packs.go | 485 ++++++++++++++++++++++++ server/service/packs_test.go | 270 +++++++++++++ server/service/service_packs.go | 242 ------------ server/service/service_packs_test.go | 299 --------------- server/service/transport.go | 2 +- server/service/transport_packs.go | 76 ---- server/service/transport_packs_test.go | 104 ----- 11 files changed, 832 insertions(+), 1058 deletions(-) delete mode 100644 server/service/endpoint_packs.go delete mode 100644 server/service/service_packs.go delete mode 100644 server/service/service_packs_test.go delete mode 100644 server/service/transport_packs.go delete mode 100644 server/service/transport_packs_test.go diff --git a/server/service/endpoint_packs.go b/server/service/endpoint_packs.go deleted file mode 100644 index 7a12a91522..0000000000 --- a/server/service/endpoint_packs.go +++ /dev/null @@ -1,270 +0,0 @@ -package service - -import ( - "context" - - "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/go-kit/kit/endpoint" -) - -type packResponse struct { - fleet.Pack - QueryCount uint `json:"query_count"` - - // All current hosts in the pack. Hosts which are selected explicty and - // hosts which are part of a label. - TotalHostsCount uint `json:"total_hosts_count"` - - // IDs of hosts which were explicitly selected. - HostIDs []uint `json:"host_ids"` - LabelIDs []uint `json:"label_ids"` - TeamIDs []uint `json:"team_ids"` -} - -func packResponseForPack(ctx context.Context, svc fleet.Service, pack fleet.Pack) (*packResponse, error) { - opts := fleet.ListOptions{} - queries, err := svc.GetScheduledQueriesInPack(ctx, pack.ID, opts) - if err != nil { - return nil, err - } - - hostMetrics, err := svc.CountHostsInTargets( - ctx, - nil, - fleet.HostTargets{HostIDs: pack.HostIDs, LabelIDs: pack.LabelIDs, TeamIDs: pack.TeamIDs}, - ) - if err != nil { - return nil, err - } - - return &packResponse{ - Pack: pack, - QueryCount: uint(len(queries)), - TotalHostsCount: hostMetrics.TotalHosts, - HostIDs: pack.HostIDs, - LabelIDs: pack.LabelIDs, - TeamIDs: pack.TeamIDs, - }, nil -} - -//////////////////////////////////////////////////////////////////////////////// -// List Packs -//////////////////////////////////////////////////////////////////////////////// - -type listPacksRequest struct { - ListOptions fleet.ListOptions -} - -type listPacksResponse struct { - Packs []packResponse `json:"packs"` - Err error `json:"error,omitempty"` -} - -func (r listPacksResponse) error() error { return r.Err } - -func makeListPacksEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(listPacksRequest) - packs, err := svc.ListPacks(ctx, fleet.PackListOptions{ListOptions: req.ListOptions, IncludeSystemPacks: false}) - if err != nil { - return getPackResponse{Err: err}, nil - } - - resp := listPacksResponse{Packs: make([]packResponse, len(packs))} - for i, pack := range packs { - packResp, err := packResponseForPack(ctx, svc, *pack) - if err != nil { - return getPackResponse{Err: err}, nil - } - resp.Packs[i] = *packResp - } - return resp, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Create Pack -//////////////////////////////////////////////////////////////////////////////// - -type createPackRequest struct { - payload fleet.PackPayload -} - -type createPackResponse struct { - Pack packResponse `json:"pack,omitempty"` - Err error `json:"error,omitempty"` -} - -func (r createPackResponse) error() error { return r.Err } - -func makeCreatePackEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(createPackRequest) - pack, err := svc.NewPack(ctx, req.payload) - if err != nil { - return createPackResponse{Err: err}, nil - } - - resp, err := packResponseForPack(ctx, svc, *pack) - if err != nil { - return createPackResponse{Err: err}, nil - } - - return createPackResponse{ - Pack: *resp, - }, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Modify Pack -//////////////////////////////////////////////////////////////////////////////// - -type modifyPackRequest struct { - ID uint - payload fleet.PackPayload -} - -type modifyPackResponse struct { - Pack packResponse `json:"pack,omitempty"` - Err error `json:"error,omitempty"` -} - -func (r modifyPackResponse) error() error { return r.Err } - -func makeModifyPackEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(modifyPackRequest) - pack, err := svc.ModifyPack(ctx, req.ID, req.payload) - if err != nil { - return modifyPackResponse{Err: err}, nil - } - - resp, err := packResponseForPack(ctx, svc, *pack) - if err != nil { - return modifyPackResponse{Err: err}, nil - } - - return modifyPackResponse{ - Pack: *resp, - }, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Delete Pack -//////////////////////////////////////////////////////////////////////////////// - -type deletePackRequest struct { - Name string -} - -type deletePackResponse struct { - Err error `json:"error,omitempty"` -} - -func (r deletePackResponse) error() error { return r.Err } - -func makeDeletePackEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(deletePackRequest) - err := svc.DeletePack(ctx, req.Name) - if err != nil { - return deletePackResponse{Err: err}, nil - } - return deletePackResponse{}, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Delete Pack By ID -//////////////////////////////////////////////////////////////////////////////// - -type deletePackByIDRequest struct { - ID uint -} - -type deletePackByIDResponse struct { - Err error `json:"error,omitempty"` -} - -func (r deletePackByIDResponse) error() error { return r.Err } - -func makeDeletePackByIDEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(deletePackByIDRequest) - err := svc.DeletePackByID(ctx, req.ID) - if err != nil { - return deletePackByIDResponse{Err: err}, nil - } - return deletePackByIDResponse{}, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Apply Pack Spec -//////////////////////////////////////////////////////////////////////////////// - -type applyPackSpecsRequest struct { - Specs []*fleet.PackSpec `json:"specs"` -} - -type applyPackSpecsResponse struct { - Err error `json:"error,omitempty"` -} - -func (r applyPackSpecsResponse) error() error { return r.Err } - -func makeApplyPackSpecsEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(applyPackSpecsRequest) - _, err := svc.ApplyPackSpecs(ctx, req.Specs) - if err != nil { - return applyPackSpecsResponse{Err: err}, nil - } - return applyPackSpecsResponse{}, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Get Pack Spec -//////////////////////////////////////////////////////////////////////////////// - -type getPackSpecsResponse struct { - Specs []*fleet.PackSpec `json:"specs"` - Err error `json:"error,omitempty"` -} - -func (r getPackSpecsResponse) error() error { return r.Err } - -func makeGetPackSpecsEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - specs, err := svc.GetPackSpecs(ctx) - if err != nil { - return getPackSpecsResponse{Err: err}, nil - } - return getPackSpecsResponse{Specs: specs}, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Get Pack Spec -//////////////////////////////////////////////////////////////////////////////// - -type getPackSpecResponse struct { - Spec *fleet.PackSpec `json:"specs,omitempty"` - Err error `json:"error,omitempty"` -} - -func (r getPackSpecResponse) error() error { return r.Err } - -func makeGetPackSpecEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(getGenericSpecRequest) - spec, err := svc.GetPackSpec(ctx, req.Name) - if err != nil { - return getPackSpecResponse{Err: err}, nil - } - return getPackSpecResponse{Spec: spec}, nil - } -} diff --git a/server/service/handler.go b/server/service/handler.go index 4856c92e8a..1149574dc6 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -60,19 +60,11 @@ type FleetEndpoints struct { GetQuerySpec endpoint.Endpoint CreateDistributedQueryCampaign endpoint.Endpoint CreateDistributedQueryCampaignByNames endpoint.Endpoint - CreatePack endpoint.Endpoint - ModifyPack endpoint.Endpoint - ListPacks endpoint.Endpoint - DeletePack endpoint.Endpoint - DeletePackByID endpoint.Endpoint GetScheduledQueriesInPack endpoint.Endpoint ScheduleQuery endpoint.Endpoint GetScheduledQuery endpoint.Endpoint ModifyScheduledQuery endpoint.Endpoint DeleteScheduledQuery endpoint.Endpoint - ApplyPackSpecs endpoint.Endpoint - GetPackSpecs endpoint.Endpoint - GetPackSpec endpoint.Endpoint EnrollAgent endpoint.Endpoint GetClientConfig endpoint.Endpoint GetDistributedQueries endpoint.Endpoint @@ -167,19 +159,11 @@ func MakeFleetServerEndpoints(svc fleet.Service, urlPrefix string, limitStore th GetQuerySpec: authenticatedUser(svc, makeGetQuerySpecEndpoint(svc)), CreateDistributedQueryCampaign: authenticatedUser(svc, makeCreateDistributedQueryCampaignEndpoint(svc)), CreateDistributedQueryCampaignByNames: authenticatedUser(svc, makeCreateDistributedQueryCampaignByNamesEndpoint(svc)), - CreatePack: authenticatedUser(svc, makeCreatePackEndpoint(svc)), - ModifyPack: authenticatedUser(svc, makeModifyPackEndpoint(svc)), - ListPacks: authenticatedUser(svc, makeListPacksEndpoint(svc)), - DeletePack: authenticatedUser(svc, makeDeletePackEndpoint(svc)), - DeletePackByID: authenticatedUser(svc, makeDeletePackByIDEndpoint(svc)), GetScheduledQueriesInPack: authenticatedUser(svc, makeGetScheduledQueriesInPackEndpoint(svc)), ScheduleQuery: authenticatedUser(svc, makeScheduleQueryEndpoint(svc)), GetScheduledQuery: authenticatedUser(svc, makeGetScheduledQueryEndpoint(svc)), ModifyScheduledQuery: authenticatedUser(svc, makeModifyScheduledQueryEndpoint(svc)), DeleteScheduledQuery: authenticatedUser(svc, makeDeleteScheduledQueryEndpoint(svc)), - ApplyPackSpecs: authenticatedUser(svc, makeApplyPackSpecsEndpoint(svc)), - GetPackSpecs: authenticatedUser(svc, makeGetPackSpecsEndpoint(svc)), - GetPackSpec: authenticatedUser(svc, makeGetPackSpecEndpoint(svc)), CreateLabel: authenticatedUser(svc, makeCreateLabelEndpoint(svc)), ModifyLabel: authenticatedUser(svc, makeModifyLabelEndpoint(svc)), GetLabel: authenticatedUser(svc, makeGetLabelEndpoint(svc)), @@ -262,19 +246,11 @@ type fleetHandlers struct { GetQuerySpec http.Handler CreateDistributedQueryCampaign http.Handler CreateDistributedQueryCampaignByNames http.Handler - CreatePack http.Handler - ModifyPack http.Handler - ListPacks http.Handler - DeletePack http.Handler - DeletePackByID http.Handler GetScheduledQueriesInPack http.Handler ScheduleQuery http.Handler GetScheduledQuery http.Handler ModifyScheduledQuery http.Handler DeleteScheduledQuery http.Handler - ApplyPackSpecs http.Handler - GetPackSpecs http.Handler - GetPackSpec http.Handler EnrollAgent http.Handler GetClientConfig http.Handler GetDistributedQueries http.Handler @@ -356,19 +332,11 @@ func makeKitHandlers(e FleetEndpoints, opts []kithttp.ServerOption) *fleetHandle GetQuerySpec: newServer(e.GetQuerySpec, decodeGetGenericSpecRequest), CreateDistributedQueryCampaign: newServer(e.CreateDistributedQueryCampaign, decodeCreateDistributedQueryCampaignRequest), CreateDistributedQueryCampaignByNames: newServer(e.CreateDistributedQueryCampaignByNames, decodeCreateDistributedQueryCampaignByNamesRequest), - CreatePack: newServer(e.CreatePack, decodeCreatePackRequest), - ModifyPack: newServer(e.ModifyPack, decodeModifyPackRequest), - ListPacks: newServer(e.ListPacks, decodeListPacksRequest), - DeletePack: newServer(e.DeletePack, decodeDeletePackRequest), - DeletePackByID: newServer(e.DeletePackByID, decodeDeletePackByIDRequest), GetScheduledQueriesInPack: newServer(e.GetScheduledQueriesInPack, decodeGetScheduledQueriesInPackRequest), ScheduleQuery: newServer(e.ScheduleQuery, decodeScheduleQueryRequest), GetScheduledQuery: newServer(e.GetScheduledQuery, decodeGetScheduledQueryRequest), ModifyScheduledQuery: newServer(e.ModifyScheduledQuery, decodeModifyScheduledQueryRequest), DeleteScheduledQuery: newServer(e.DeleteScheduledQuery, decodeDeleteScheduledQueryRequest), - ApplyPackSpecs: newServer(e.ApplyPackSpecs, decodeApplyPackSpecsRequest), - GetPackSpecs: newServer(e.GetPackSpecs, decodeNoParamsRequest), - GetPackSpec: newServer(e.GetPackSpec, decodeGetGenericSpecRequest), EnrollAgent: newServer(e.EnrollAgent, decodeEnrollAgentRequest), GetClientConfig: newServer(e.GetClientConfig, decodeGetClientConfigRequest), GetDistributedQueries: newServer(e.GetDistributedQueries, decodeGetDistributedQueriesRequest), @@ -551,19 +519,11 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) { r.Handle("/api/v1/fleet/queries/run", h.CreateDistributedQueryCampaign).Methods("POST").Name("create_distributed_query_campaign") r.Handle("/api/v1/fleet/queries/run_by_names", h.CreateDistributedQueryCampaignByNames).Methods("POST").Name("create_distributed_query_campaign_by_names") - r.Handle("/api/v1/fleet/packs", h.CreatePack).Methods("POST").Name("create_pack") - r.Handle("/api/v1/fleet/packs/{id:[0-9]+}", h.ModifyPack).Methods("PATCH").Name("modify_pack") - r.Handle("/api/v1/fleet/packs", h.ListPacks).Methods("GET").Name("list_packs") - r.Handle("/api/v1/fleet/packs/{name}", h.DeletePack).Methods("DELETE").Name("delete_pack") - r.Handle("/api/v1/fleet/packs/id/{id:[0-9]+}", h.DeletePackByID).Methods("DELETE").Name("delete_pack_by_id") r.Handle("/api/v1/fleet/packs/{id:[0-9]+}/scheduled", h.GetScheduledQueriesInPack).Methods("GET").Name("get_scheduled_queries_in_pack") r.Handle("/api/v1/fleet/schedule", h.ScheduleQuery).Methods("POST").Name("schedule_query") r.Handle("/api/v1/fleet/schedule/{id:[0-9]+}", h.GetScheduledQuery).Methods("GET").Name("get_scheduled_query") r.Handle("/api/v1/fleet/schedule/{id:[0-9]+}", h.ModifyScheduledQuery).Methods("PATCH").Name("modify_scheduled_query") r.Handle("/api/v1/fleet/schedule/{id:[0-9]+}", h.DeleteScheduledQuery).Methods("DELETE").Name("delete_scheduled_query") - r.Handle("/api/v1/fleet/spec/packs", h.ApplyPackSpecs).Methods("POST").Name("apply_pack_specs") - r.Handle("/api/v1/fleet/spec/packs", h.GetPackSpecs).Methods("GET").Name("get_pack_specs") - r.Handle("/api/v1/fleet/spec/packs/{name}", h.GetPackSpec).Methods("GET").Name("get_pack_spec") r.Handle("/api/v1/fleet/labels", h.CreateLabel).Methods("POST").Name("create_label") r.Handle("/api/v1/fleet/labels/{id:[0-9]+}", h.ModifyLabel).Methods("PATCH").Name("modify_label") @@ -642,6 +602,14 @@ func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, opts []kitht e.POST("/api/v1/fleet/spec/policies", applyPolicySpecsEndpoint, applyPolicySpecsRequest{}) e.GET("/api/v1/fleet/packs/{id:[0-9]+}", getPackEndpoint, getPackRequest{}) + e.POST("/api/v1/fleet/packs", createPackEndpoint, createPackRequest{}) + e.PATCH("/api/v1/fleet/packs/{id:[0-9]+}", modifyPackEndpoint, modifyPackRequest{}) + e.GET("/api/v1/fleet/packs", listPacksEndpoint, listPacksRequest{}) + e.DELETE("/api/v1/fleet/packs/{name}", deletePackEndpoint, deletePackRequest{}) + e.DELETE("/api/v1/fleet/packs/id/{id:[0-9]+}", deletePackByIDEndpoint, deletePackByIDRequest{}) + e.POST("/api/v1/fleet/spec/packs", applyPackSpecsEndpoint, applyPackSpecsRequest{}) + e.GET("/api/v1/fleet/spec/packs", getPackSpecsEndpoint, nil) + e.GET("/api/v1/fleet/spec/packs/{name}", getPackSpecEndpoint, getGenericSpecRequest{}) e.GET("/api/v1/fleet/software", listSoftwareEndpoint, listSoftwareRequest{}) e.GET("/api/v1/fleet/software/count", countSoftwareEndpoint, countSoftwareRequest{}) diff --git a/server/service/handler_test.go b/server/service/handler_test.go index f46037c1da..31b312bb12 100644 --- a/server/service/handler_test.go +++ b/server/service/handler_test.go @@ -113,22 +113,6 @@ func TestAPIRoutes(t *testing.T) { verb: "POST", uri: "/api/v1/fleet/queries/run", }, - { - verb: "GET", - uri: "/api/v1/fleet/packs", - }, - { - verb: "POST", - uri: "/api/v1/fleet/packs", - }, - { - verb: "PATCH", - uri: "/api/v1/fleet/packs/1", - }, - { - verb: "DELETE", - uri: "/api/v1/fleet/packs/1", - }, { verb: "GET", uri: "/api/v1/fleet/packs/1/scheduled", diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index 8cfe48c90c..d028408ff3 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -6,7 +6,9 @@ import ( "fmt" "io/ioutil" "net/http" + "net/url" "reflect" + "strings" "testing" "time" @@ -627,20 +629,76 @@ func (s *integrationTestSuite) TestCountSoftware() { assert.Equal(t, 1, resp.Count) } -func (s *integrationTestSuite) TestGetPack() { +func (s *integrationTestSuite) TestPacks() { t := s.T() - pack := &fleet.Pack{ - Name: t.Name(), - } - pack, err := s.ds.NewPack(context.Background(), pack) - require.NoError(t, err) - var packResp getPackResponse - s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/packs/%d", pack.ID), nil, http.StatusOK, &packResp) - require.Equal(t, packResp.Pack.ID, pack.ID) + // get non-existing pack + s.Do("GET", "/api/v1/fleet/packs/999", nil, http.StatusNotFound) - s.Do("GET", fmt.Sprintf("/api/v1/fleet/packs/%d", pack.ID+1), nil, http.StatusNotFound) + // create some packs + packs := make([]fleet.Pack, 3) + for i := range packs { + req := &createPackRequest{ + PackPayload: fleet.PackPayload{ + Name: ptr.String(fmt.Sprintf("%s_%d", strings.ReplaceAll(t.Name(), "/", "_"), i)), + }, + } + + var createResp createPackResponse + s.DoJSON("POST", "/api/v1/fleet/packs", req, http.StatusOK, &createResp) + packs[i] = createResp.Pack.Pack + } + + // get existing pack + s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/packs/%d", packs[0].ID), nil, http.StatusOK, &packResp) + require.Equal(t, packs[0].ID, packResp.Pack.ID) + + // list packs + var listResp listPacksResponse + s.DoJSON("GET", "/api/v1/fleet/packs", nil, http.StatusOK, &listResp, "per_page", "2", "order_key", "name") + require.Len(t, listResp.Packs, 2) + assert.Equal(t, packs[0].ID, listResp.Packs[0].ID) + assert.Equal(t, packs[1].ID, listResp.Packs[1].ID) + + // get page 1 + s.DoJSON("GET", "/api/v1/fleet/packs", nil, http.StatusOK, &listResp, "page", "1", "per_page", "2", "order_key", "name") + require.Len(t, listResp.Packs, 1) + assert.Equal(t, packs[2].ID, listResp.Packs[0].ID) + + // get page 2, empty + s.DoJSON("GET", "/api/v1/fleet/packs", nil, http.StatusOK, &listResp, "page", "2", "per_page", "2", "order_key", "name") + require.Len(t, listResp.Packs, 0) + + var delResp deletePackResponse + // delete non-existing pack by name + s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/packs/%s", "zzz"), nil, http.StatusNotFound, &delResp) + + // delete existing pack by name + s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/packs/%s", url.PathEscape(packs[0].Name)), nil, http.StatusOK, &delResp) + + // delete non-existing pack by id + var delIDResp deletePackByIDResponse + s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/packs/id/%d", packs[2].ID+1), nil, http.StatusNotFound, &delIDResp) + + // delete existing pack by id + s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/packs/id/%d", packs[1].ID), nil, http.StatusOK, &delIDResp) + + var modResp modifyPackResponse + // modify non-existing pack + req := &fleet.PackPayload{Name: ptr.String("updated_" + packs[2].Name)} + s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/packs/%d", packs[2].ID+1), req, http.StatusNotFound, &modResp) + + // modify existing pack + req = &fleet.PackPayload{Name: ptr.String("updated_" + packs[2].Name)} + s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/packs/%d", packs[2].ID), req, http.StatusOK, &modResp) + require.Equal(t, packs[2].ID, modResp.Pack.ID) + require.Contains(t, modResp.Pack.Name, "updated_") + + // list packs, only packs[2] remains + s.DoJSON("GET", "/api/v1/fleet/packs", nil, http.StatusOK, &listResp, "per_page", "2", "order_key", "name") + require.Len(t, listResp.Packs, 1) + assert.Equal(t, packs[2].ID, listResp.Packs[0].ID) } func (s *integrationTestSuite) TestListHosts() { diff --git a/server/service/packs.go b/server/service/packs.go index 310cf60520..eb19e0aab6 100644 --- a/server/service/packs.go +++ b/server/service/packs.go @@ -2,10 +2,52 @@ package service import ( "context" + "fmt" + "github.com/fleetdm/fleet/v4/server/authz" "github.com/fleetdm/fleet/v4/server/fleet" ) +type packResponse struct { + fleet.Pack + QueryCount uint `json:"query_count"` + + // All current hosts in the pack. Hosts which are selected explicty and + // hosts which are part of a label. + TotalHostsCount uint `json:"total_hosts_count"` + + // IDs of hosts which were explicitly selected. + HostIDs []uint `json:"host_ids"` + LabelIDs []uint `json:"label_ids"` + TeamIDs []uint `json:"team_ids"` +} + +func packResponseForPack(ctx context.Context, svc fleet.Service, pack fleet.Pack) (*packResponse, error) { + opts := fleet.ListOptions{} + queries, err := svc.GetScheduledQueriesInPack(ctx, pack.ID, opts) + if err != nil { + return nil, err + } + + hostMetrics, err := svc.CountHostsInTargets( + ctx, + nil, + fleet.HostTargets{HostIDs: pack.HostIDs, LabelIDs: pack.LabelIDs, TeamIDs: pack.TeamIDs}, + ) + if err != nil { + return nil, err + } + + return &packResponse{ + Pack: pack, + QueryCount: uint(len(queries)), + TotalHostsCount: hostMetrics.TotalHosts, + HostIDs: pack.HostIDs, + LabelIDs: pack.LabelIDs, + TeamIDs: pack.TeamIDs, + }, nil +} + //////////////////////////////////////////////////////////////////////////////// // Get Pack //////////////////////////////////////////////////////////////////////////////// @@ -45,3 +87,446 @@ func (svc *Service) GetPack(ctx context.Context, id uint) (*fleet.Pack, error) { return svc.ds.Pack(ctx, id) } + +//////////////////////////////////////////////////////////////////////////////// +// Create Pack +//////////////////////////////////////////////////////////////////////////////// + +type createPackRequest struct { + fleet.PackPayload +} + +type createPackResponse struct { + Pack packResponse `json:"pack,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r createPackResponse) error() error { return r.Err } + +func createPackEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*createPackRequest) + pack, err := svc.NewPack(ctx, req.PackPayload) + if err != nil { + return createPackResponse{Err: err}, nil + } + + resp, err := packResponseForPack(ctx, svc, *pack) + if err != nil { + return createPackResponse{Err: err}, nil + } + + return createPackResponse{ + Pack: *resp, + }, nil +} + +func (svc *Service) NewPack(ctx context.Context, p fleet.PackPayload) (*fleet.Pack, error) { + if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { + return nil, err + } + + var pack fleet.Pack + + if p.Name != nil { + pack.Name = *p.Name + } + + if p.Description != nil { + pack.Description = *p.Description + } + + if p.Platform != nil { + pack.Platform = *p.Platform + } + + if p.Disabled != nil { + pack.Disabled = *p.Disabled + } + + if p.HostIDs != nil { + pack.HostIDs = *p.HostIDs + } + + if p.LabelIDs != nil { + pack.LabelIDs = *p.LabelIDs + } + + if p.TeamIDs != nil { + pack.TeamIDs = *p.TeamIDs + } + + _, err := svc.ds.NewPack(ctx, &pack) + if err != nil { + return nil, err + } + + if err := svc.ds.NewActivity( + ctx, + authz.UserFromContext(ctx), + fleet.ActivityTypeCreatedPack, + &map[string]interface{}{"pack_id": pack.ID, "pack_name": pack.Name}, + ); err != nil { + return nil, err + } + + return &pack, nil +} + +//////////////////////////////////////////////////////////////////////////////// +// Modify Pack +//////////////////////////////////////////////////////////////////////////////// + +type modifyPackRequest struct { + ID uint `json:"-" url:"id"` + fleet.PackPayload +} + +type modifyPackResponse struct { + Pack packResponse `json:"pack,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r modifyPackResponse) error() error { return r.Err } + +func modifyPackEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*modifyPackRequest) + pack, err := svc.ModifyPack(ctx, req.ID, req.PackPayload) + if err != nil { + return modifyPackResponse{Err: err}, nil + } + + resp, err := packResponseForPack(ctx, svc, *pack) + if err != nil { + return modifyPackResponse{Err: err}, nil + } + + return modifyPackResponse{ + Pack: *resp, + }, nil +} + +func (svc *Service) ModifyPack(ctx context.Context, id uint, p fleet.PackPayload) (*fleet.Pack, error) { + if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { + return nil, err + } + + pack, err := svc.ds.Pack(ctx, id) + if err != nil { + return nil, err + } + + if p.Name != nil && pack.EditablePackType() { + pack.Name = *p.Name + } + + if p.Description != nil && pack.EditablePackType() { + pack.Description = *p.Description + } + + if p.Platform != nil { + pack.Platform = *p.Platform + } + + if p.Disabled != nil { + pack.Disabled = *p.Disabled + } + + if p.HostIDs != nil && pack.EditablePackType() { + pack.HostIDs = *p.HostIDs + } + + if p.LabelIDs != nil && pack.EditablePackType() { + pack.LabelIDs = *p.LabelIDs + } + + if p.TeamIDs != nil && pack.EditablePackType() { + pack.TeamIDs = *p.TeamIDs + } + + err = svc.ds.SavePack(ctx, pack) + if err != nil { + return nil, err + } + + if err := svc.ds.NewActivity( + ctx, + authz.UserFromContext(ctx), + fleet.ActivityTypeEditedPack, + &map[string]interface{}{"pack_id": pack.ID, "pack_name": pack.Name}, + ); err != nil { + return nil, err + } + + return pack, err +} + +//////////////////////////////////////////////////////////////////////////////// +// List Packs +//////////////////////////////////////////////////////////////////////////////// + +type listPacksRequest struct { + ListOptions fleet.ListOptions `url:"list_options"` +} + +type listPacksResponse struct { + Packs []packResponse `json:"packs"` + Err error `json:"error,omitempty"` +} + +func (r listPacksResponse) error() error { return r.Err } + +func listPacksEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*listPacksRequest) + packs, err := svc.ListPacks(ctx, fleet.PackListOptions{ListOptions: req.ListOptions, IncludeSystemPacks: false}) + if err != nil { + return getPackResponse{Err: err}, nil + } + + resp := listPacksResponse{Packs: make([]packResponse, len(packs))} + for i, pack := range packs { + packResp, err := packResponseForPack(ctx, svc, *pack) + if err != nil { + return getPackResponse{Err: err}, nil + } + resp.Packs[i] = *packResp + } + return resp, nil +} + +func (svc *Service) ListPacks(ctx context.Context, opt fleet.PackListOptions) ([]*fleet.Pack, error) { + if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil { + return nil, err + } + + return svc.ds.ListPacks(ctx, opt) +} + +//////////////////////////////////////////////////////////////////////////////// +// Delete Pack +//////////////////////////////////////////////////////////////////////////////// + +type deletePackRequest struct { + Name string `url:"name"` +} + +type deletePackResponse struct { + Err error `json:"error,omitempty"` +} + +func (r deletePackResponse) error() error { return r.Err } + +func deletePackEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*deletePackRequest) + err := svc.DeletePack(ctx, req.Name) + if err != nil { + return deletePackResponse{Err: err}, nil + } + return deletePackResponse{}, nil +} + +func (svc *Service) DeletePack(ctx context.Context, name string) error { + if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { + return err + } + + pack, _, err := svc.ds.PackByName(ctx, name) + if err != nil { + return err + } + // if there is a pack by this name, ensure it is not type Global or Team + if pack != nil && !pack.EditablePackType() { + return fmt.Errorf("cannot delete pack_type %s", *pack.Type) + } + + if err := svc.ds.DeletePack(ctx, name); err != nil { + return err + } + + return svc.ds.NewActivity( + ctx, + authz.UserFromContext(ctx), + fleet.ActivityTypeDeletedPack, + &map[string]interface{}{"pack_name": name}, + ) +} + +//////////////////////////////////////////////////////////////////////////////// +// Delete Pack By ID +//////////////////////////////////////////////////////////////////////////////// + +type deletePackByIDRequest struct { + ID uint `url:"id"` +} + +type deletePackByIDResponse struct { + Err error `json:"error,omitempty"` +} + +func (r deletePackByIDResponse) error() error { return r.Err } + +func deletePackByIDEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*deletePackByIDRequest) + err := svc.DeletePackByID(ctx, req.ID) + if err != nil { + return deletePackByIDResponse{Err: err}, nil + } + return deletePackByIDResponse{}, nil +} + +func (svc *Service) DeletePackByID(ctx context.Context, id uint) error { + if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { + return err + } + + pack, err := svc.ds.Pack(ctx, id) + if err != nil { + return err + } + if pack != nil && !pack.EditablePackType() { + return fmt.Errorf("cannot delete pack_type %s", *pack.Type) + } + if err := svc.ds.DeletePack(ctx, pack.Name); err != nil { + return err + } + + return svc.ds.NewActivity( + ctx, + authz.UserFromContext(ctx), + fleet.ActivityTypeDeletedPack, + &map[string]interface{}{"pack_name": pack.Name}, + ) +} + +//////////////////////////////////////////////////////////////////////////////// +// Apply Pack Spec +//////////////////////////////////////////////////////////////////////////////// + +type applyPackSpecsRequest struct { + Specs []*fleet.PackSpec `json:"specs"` +} + +type applyPackSpecsResponse struct { + Err error `json:"error,omitempty"` +} + +func (r applyPackSpecsResponse) error() error { return r.Err } + +func applyPackSpecsEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*applyPackSpecsRequest) + _, err := svc.ApplyPackSpecs(ctx, req.Specs) + if err != nil { + return applyPackSpecsResponse{Err: err}, nil + } + return applyPackSpecsResponse{}, nil +} + +func (svc *Service) ApplyPackSpecs(ctx context.Context, specs []*fleet.PackSpec) ([]*fleet.PackSpec, error) { + if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { + return nil, err + } + + packs, err := svc.ds.ListPacks(ctx, fleet.PackListOptions{IncludeSystemPacks: true}) + if err != nil { + return nil, err + } + + namePacks := make(map[string]*fleet.Pack, len(packs)) + for _, pack := range packs { + namePacks[pack.Name] = pack + } + + var result []*fleet.PackSpec + + // loop over incoming specs filtering out possible edits to Global or Team Packs + for _, spec := range specs { + // see for known limitations https://github.com/fleetdm/fleet/pull/1558#discussion_r684218301 + // check to see if incoming spec is already in the list of packs + if p, ok := namePacks[spec.Name]; ok { + // as long as pack is editable, we'll apply it + if p.EditablePackType() { + result = append(result, spec) + } + } else { + // incoming spec is new, let's apply it + result = append(result, spec) + } + } + + if err := svc.ds.ApplyPackSpecs(ctx, result); err != nil { + return nil, err + } + + return result, svc.ds.NewActivity( + ctx, + authz.UserFromContext(ctx), + fleet.ActivityTypeAppliedSpecPack, + &map[string]interface{}{}, + ) +} + +//////////////////////////////////////////////////////////////////////////////// +// Get Pack Specs +//////////////////////////////////////////////////////////////////////////////// + +type getPackSpecsResponse struct { + Specs []*fleet.PackSpec `json:"specs"` + Err error `json:"error,omitempty"` +} + +func (r getPackSpecsResponse) error() error { return r.Err } + +func getPackSpecsEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + specs, err := svc.GetPackSpecs(ctx) + if err != nil { + return getPackSpecsResponse{Err: err}, nil + } + return getPackSpecsResponse{Specs: specs}, nil +} + +func (svc *Service) GetPackSpecs(ctx context.Context) ([]*fleet.PackSpec, error) { + if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil { + return nil, err + } + + return svc.ds.GetPackSpecs(ctx) +} + +//////////////////////////////////////////////////////////////////////////////// +// Get Pack Spec +//////////////////////////////////////////////////////////////////////////////// + +type getPackSpecResponse struct { + Spec *fleet.PackSpec `json:"specs,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r getPackSpecResponse) error() error { return r.Err } + +func getPackSpecEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*getGenericSpecRequest) + spec, err := svc.GetPackSpec(ctx, req.Name) + if err != nil { + return getPackSpecResponse{Err: err}, nil + } + return getPackSpecResponse{Spec: spec}, nil +} + +func (svc *Service) GetPackSpec(ctx context.Context, name string) (*fleet.PackSpec, error) { + if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil { + return nil, err + } + + return svc.ds.GetPackSpec(ctx, name) +} + +//////////////////////////////////////////////////////////////////////////////// +// List Packs For Host, not exposed via an endpoint +//////////////////////////////////////////////////////////////////////////////// + +func (svc *Service) ListPacksForHost(ctx context.Context, hid uint) ([]*fleet.Pack, error) { + if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil { + return nil, err + } + + return svc.ds.ListPacksForHost(ctx, hid) +} diff --git a/server/service/packs_test.go b/server/service/packs_test.go index 69bf4da2e8..60686c8fcc 100644 --- a/server/service/packs_test.go +++ b/server/service/packs_test.go @@ -5,9 +5,12 @@ import ( "testing" "github.com/fleetdm/fleet/v4/server/authz" + "github.com/fleetdm/fleet/v4/server/datastore/mysql" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/mock" + "github.com/fleetdm/fleet/v4/server/ptr" "github.com/fleetdm/fleet/v4/server/test" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -30,3 +33,270 @@ func TestGetPack(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), authz.ForbiddenErrorMessage) } + +func TestNewPackSavesTargets(t *testing.T) { + ds := new(mock.Store) + svc := newTestService(ds, nil, nil) + + ds.NewPackFunc = func(ctx context.Context, pack *fleet.Pack, opts ...fleet.OptionalArg) (*fleet.Pack, error) { + return pack, nil + } + ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activityType string, details *map[string]interface{}) error { + return nil + } + + packPayload := fleet.PackPayload{ + Name: ptr.String("foo"), + HostIDs: &[]uint{123}, + LabelIDs: &[]uint{456}, + TeamIDs: &[]uint{789}, + } + pack, err := svc.NewPack(test.UserContext(test.UserAdmin), packPayload) + require.NoError(t, err) + + require.Len(t, pack.HostIDs, 1) + require.Len(t, pack.LabelIDs, 1) + require.Len(t, pack.TeamIDs, 1) + assert.Equal(t, uint(123), pack.HostIDs[0]) + assert.Equal(t, uint(456), pack.LabelIDs[0]) + assert.Equal(t, uint(789), pack.TeamIDs[0]) + assert.True(t, ds.NewPackFuncInvoked) + assert.True(t, ds.NewActivityFuncInvoked) +} + +func TestPacksWithDS(t *testing.T) { + ds := mysql.CreateMySQLDS(t) + + cases := []struct { + name string + fn func(t *testing.T, ds *mysql.Datastore) + }{ + {"ModifyPack", testPacksModifyPack}, + {"ListPacks", testPacksListPacks}, + {"DeletePack", testPacksDeletePack}, + {"DeletePackByID", testPacksDeletePackByID}, + {"ApplyPackSpecs", testPacksApplyPackSpecs}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + defer mysql.TruncateTables(t, ds) + c.fn(t, ds) + }) + } +} + +func testPacksModifyPack(t *testing.T, ds *mysql.Datastore) { + svc := newTestService(ds, nil, nil) + test.AddAllHostsLabel(t, ds) + users := createTestUsers(t, ds) + + globalPack, err := ds.EnsureGlobalPack(context.Background()) + require.NoError(t, err) + + labelids := []uint{1, 2, 3} + hostids := []uint{4, 5, 6} + teamids := []uint{7, 8, 9} + packPayload := fleet.PackPayload{ + Name: ptr.String("foo"), + Description: ptr.String("bar"), + LabelIDs: &labelids, + HostIDs: &hostids, + TeamIDs: &teamids, + } + + user := users["admin1@example.com"] + pack, _ := svc.ModifyPack(test.UserContext(&user), globalPack.ID, packPayload) + + require.Equal(t, "Global", pack.Name, "name for global pack should not change") + require.Equal(t, "Global pack", pack.Description, "description for global pack should not change") + require.Len(t, pack.LabelIDs, 1) + require.Len(t, pack.HostIDs, 0) + require.Len(t, pack.TeamIDs, 0) +} + +func testPacksListPacks(t *testing.T, ds *mysql.Datastore) { + svc := newTestService(ds, nil, nil) + + queries, err := svc.ListPacks(test.UserContext(test.UserAdmin), fleet.PackListOptions{IncludeSystemPacks: false}) + require.NoError(t, err) + assert.Len(t, queries, 0) + + _, err = ds.NewPack(context.Background(), &fleet.Pack{ + Name: "foo", + }) + require.NoError(t, err) + + queries, err = svc.ListPacks(test.UserContext(test.UserAdmin), fleet.PackListOptions{IncludeSystemPacks: false}) + require.NoError(t, err) + assert.Len(t, queries, 1) +} + +func testPacksDeletePack(t *testing.T, ds *mysql.Datastore) { + test.AddAllHostsLabel(t, ds) + + gp, err := ds.EnsureGlobalPack(context.Background()) + require.NoError(t, err) + + users := createTestUsers(t, ds) + user := users["admin1@example.com"] + + team1, err := ds.NewTeam(context.Background(), &fleet.Team{ + ID: 42, + Name: "team1", + Description: "desc team1", + }) + require.NoError(t, err) + + tp, err := ds.EnsureTeamPack(context.Background(), team1.ID) + require.NoError(t, err) + + type args struct { + ctx context.Context + name string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "cannot delete global pack", + args: args{ + ctx: test.UserContext(&user), + name: gp.Name, + }, + wantErr: true, + }, + { + name: "cannot delete team pack", + args: args{ + ctx: test.UserContext(&user), + name: tp.Name, + }, + wantErr: true, + }, + { + name: "delete pack that doesn't exist", + args: args{ + ctx: test.UserContext(&user), + name: "foo", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc := newTestService(ds, nil, nil) + if err := svc.DeletePack(tt.args.ctx, tt.args.name); (err != nil) != tt.wantErr { + t.Errorf("DeletePack() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func testPacksDeletePackByID(t *testing.T, ds *mysql.Datastore) { + test.AddAllHostsLabel(t, ds) + + globalPack, err := ds.EnsureGlobalPack(context.Background()) + require.NoError(t, err) + + type args struct { + ctx context.Context + id uint + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "cannot delete global pack", + args: args{ + ctx: test.UserContext(test.UserAdmin), + id: globalPack.ID, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc := newTestService(ds, nil, nil) + if err := svc.DeletePackByID(tt.args.ctx, tt.args.id); (err != nil) != tt.wantErr { + t.Errorf("DeletePackByID() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func testPacksApplyPackSpecs(t *testing.T, ds *mysql.Datastore) { + test.AddAllHostsLabel(t, ds) + + global, err := ds.EnsureGlobalPack(context.Background()) + require.NoError(t, err) + + users := createTestUsers(t, ds) + user := users["admin1@example.com"] + + team1, err := ds.NewTeam(context.Background(), &fleet.Team{ + ID: 42, + Name: "team1", + Description: "desc team1", + }) + require.NoError(t, err) + + teamPack, err := ds.EnsureTeamPack(context.Background(), team1.ID) + require.NoError(t, err) + + type args struct { + ctx context.Context + specs []*fleet.PackSpec + } + tests := []struct { + name string + args args + want []*fleet.PackSpec + wantErr bool + }{ + { + name: "cannot modify global pack", + args: args{ + ctx: test.UserContext(&user), + specs: []*fleet.PackSpec{ + {Name: global.Name, Description: "bar", Platform: "baz"}, + {Name: "Foo Pack", Description: "Foo Desc", Platform: "MacOS"}, + {Name: "Bar Pack", Description: "Bar Desc", Platform: "MacOS"}, + }, + }, + want: []*fleet.PackSpec{ + {Name: "Foo Pack", Description: "Foo Desc", Platform: "MacOS"}, + {Name: "Bar Pack", Description: "Bar Desc", Platform: "MacOS"}, + }, + wantErr: false, + }, + { + name: "cannot modify team pack", + args: args{ + ctx: test.UserContext(&user), + specs: []*fleet.PackSpec{ + {Name: teamPack.Name, Description: "Desc", Platform: "windows"}, + {Name: "Test", Description: "Test Desc", Platform: "linux"}, + }, + }, + want: []*fleet.PackSpec{ + {Name: "Test", Description: "Test Desc", Platform: "linux"}, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc := newTestService(ds, nil, nil) + got, err := svc.ApplyPackSpecs(tt.args.ctx, tt.args.specs) + if (err != nil) != tt.wantErr { + t.Errorf("ApplyPackSpecs() error = %v, wantErr %v", err, tt.wantErr) + return + } + require.Equal(t, tt.want, got) + }) + } +} diff --git a/server/service/service_packs.go b/server/service/service_packs.go deleted file mode 100644 index a6a19213fb..0000000000 --- a/server/service/service_packs.go +++ /dev/null @@ -1,242 +0,0 @@ -package service - -import ( - "context" - "fmt" - - "github.com/fleetdm/fleet/v4/server/authz" - "github.com/fleetdm/fleet/v4/server/fleet" -) - -func (svc *Service) ApplyPackSpecs(ctx context.Context, specs []*fleet.PackSpec) ([]*fleet.PackSpec, error) { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { - return nil, err - } - - packs, err := svc.ds.ListPacks(ctx, fleet.PackListOptions{IncludeSystemPacks: true}) - if err != nil { - return nil, err - } - - namePacks := make(map[string]*fleet.Pack, len(packs)) - for _, pack := range packs { - namePacks[pack.Name] = pack - } - - var result []*fleet.PackSpec - - // loop over incoming specs filtering out possible edits to Global or Team Packs - for _, spec := range specs { - // see for known limitations https://github.com/fleetdm/fleet/pull/1558#discussion_r684218301 - // check to see if incoming spec is already in the list of packs - if p, ok := namePacks[spec.Name]; ok { - // as long as pack is editable, we'll apply it - if p.EditablePackType() { - result = append(result, spec) - } - } else { - // incoming spec is new, let's apply it - result = append(result, spec) - } - } - - if err := svc.ds.ApplyPackSpecs(ctx, result); err != nil { - return nil, err - } - - return result, svc.ds.NewActivity( - ctx, - authz.UserFromContext(ctx), - fleet.ActivityTypeAppliedSpecPack, - &map[string]interface{}{}, - ) -} - -func (svc *Service) GetPackSpecs(ctx context.Context) ([]*fleet.PackSpec, error) { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil { - return nil, err - } - - return svc.ds.GetPackSpecs(ctx) -} - -func (svc *Service) GetPackSpec(ctx context.Context, name string) (*fleet.PackSpec, error) { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil { - return nil, err - } - - return svc.ds.GetPackSpec(ctx, name) -} - -func (svc *Service) ListPacks(ctx context.Context, opt fleet.PackListOptions) ([]*fleet.Pack, error) { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil { - return nil, err - } - - return svc.ds.ListPacks(ctx, opt) -} - -func (svc *Service) NewPack(ctx context.Context, p fleet.PackPayload) (*fleet.Pack, error) { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { - return nil, err - } - - var pack fleet.Pack - - if p.Name != nil { - pack.Name = *p.Name - } - - if p.Description != nil { - pack.Description = *p.Description - } - - if p.Platform != nil { - pack.Platform = *p.Platform - } - - if p.Disabled != nil { - pack.Disabled = *p.Disabled - } - - if p.HostIDs != nil { - pack.HostIDs = *p.HostIDs - } - - if p.LabelIDs != nil { - pack.LabelIDs = *p.LabelIDs - } - - if p.TeamIDs != nil { - pack.TeamIDs = *p.TeamIDs - } - - _, err := svc.ds.NewPack(ctx, &pack) - if err != nil { - return nil, err - } - - if err := svc.ds.NewActivity( - ctx, - authz.UserFromContext(ctx), - fleet.ActivityTypeCreatedPack, - &map[string]interface{}{"pack_id": pack.ID, "pack_name": pack.Name}, - ); err != nil { - return nil, err - } - - return &pack, nil -} - -func (svc *Service) ModifyPack(ctx context.Context, id uint, p fleet.PackPayload) (*fleet.Pack, error) { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { - return nil, err - } - - pack, err := svc.ds.Pack(ctx, id) - if err != nil { - return nil, err - } - - if p.Name != nil && pack.EditablePackType() { - pack.Name = *p.Name - } - - if p.Description != nil && pack.EditablePackType() { - pack.Description = *p.Description - } - - if p.Platform != nil { - pack.Platform = *p.Platform - } - - if p.Disabled != nil { - pack.Disabled = *p.Disabled - } - - if p.HostIDs != nil && pack.EditablePackType() { - pack.HostIDs = *p.HostIDs - } - - if p.LabelIDs != nil && pack.EditablePackType() { - pack.LabelIDs = *p.LabelIDs - } - - if p.TeamIDs != nil && pack.EditablePackType() { - pack.TeamIDs = *p.TeamIDs - } - - err = svc.ds.SavePack(ctx, pack) - if err != nil { - return nil, err - } - - if err := svc.ds.NewActivity( - ctx, - authz.UserFromContext(ctx), - fleet.ActivityTypeEditedPack, - &map[string]interface{}{"pack_id": pack.ID, "pack_name": pack.Name}, - ); err != nil { - return nil, err - } - - return pack, err -} - -func (svc *Service) DeletePack(ctx context.Context, name string) error { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { - return err - } - - pack, _, err := svc.ds.PackByName(ctx, name) - if err != nil { - return err - } - // if there is a pack by this name, ensure it is not type Global or Team - if pack != nil && !pack.EditablePackType() { - return fmt.Errorf("cannot delete pack_type %s", *pack.Type) - } - - if err := svc.ds.DeletePack(ctx, name); err != nil { - return err - } - - return svc.ds.NewActivity( - ctx, - authz.UserFromContext(ctx), - fleet.ActivityTypeDeletedPack, - &map[string]interface{}{"pack_name": name}, - ) -} - -func (svc *Service) DeletePackByID(ctx context.Context, id uint) error { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { - return err - } - - pack, err := svc.ds.Pack(ctx, id) - if err != nil { - return err - } - if pack != nil && !pack.EditablePackType() { - return fmt.Errorf("cannot delete pack_type %s", *pack.Type) - } - if err := svc.ds.DeletePack(ctx, pack.Name); err != nil { - return err - } - - return svc.ds.NewActivity( - ctx, - authz.UserFromContext(ctx), - fleet.ActivityTypeDeletedPack, - &map[string]interface{}{"pack_name": pack.Name}, - ) -} - -func (svc *Service) ListPacksForHost(ctx context.Context, hid uint) ([]*fleet.Pack, error) { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil { - return nil, err - } - - return svc.ds.ListPacksForHost(ctx, hid) -} diff --git a/server/service/service_packs_test.go b/server/service/service_packs_test.go deleted file mode 100644 index 5e5d718a1b..0000000000 --- a/server/service/service_packs_test.go +++ /dev/null @@ -1,299 +0,0 @@ -package service - -import ( - "context" - "testing" - - "github.com/fleetdm/fleet/v4/server/datastore/mysql" - "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/fleetdm/fleet/v4/server/mock" - "github.com/fleetdm/fleet/v4/server/ptr" - "github.com/fleetdm/fleet/v4/server/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestServiceListPacks(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - svc := newTestService(ds, nil, nil) - - queries, err := svc.ListPacks(test.UserContext(test.UserAdmin), fleet.PackListOptions{IncludeSystemPacks: false}) - assert.Nil(t, err) - assert.Len(t, queries, 0) - - _, err = ds.NewPack(context.Background(), &fleet.Pack{ - Name: "foo", - }) - assert.Nil(t, err) - - queries, err = svc.ListPacks(test.UserContext(test.UserAdmin), fleet.PackListOptions{IncludeSystemPacks: false}) - assert.Nil(t, err) - assert.Len(t, queries, 1) -} - -func TestNewSavesTargets(t *testing.T) { - ds := new(mock.Store) - svc := newTestService(ds, nil, nil) - - ds.NewPackFunc = func(ctx context.Context, pack *fleet.Pack, opts ...fleet.OptionalArg) (*fleet.Pack, error) { - return pack, nil - } - ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activityType string, details *map[string]interface{}) error { - return nil - } - - packPayload := fleet.PackPayload{ - Name: ptr.String("foo"), - HostIDs: &[]uint{123}, - LabelIDs: &[]uint{456}, - TeamIDs: &[]uint{789}, - } - pack, _ := svc.NewPack(test.UserContext(test.UserAdmin), packPayload) - - require.Len(t, pack.HostIDs, 1) - require.Len(t, pack.LabelIDs, 1) - require.Len(t, pack.TeamIDs, 1) - assert.Equal(t, uint(123), pack.HostIDs[0]) - assert.Equal(t, uint(456), pack.LabelIDs[0]) - assert.Equal(t, uint(789), pack.TeamIDs[0]) -} - -func TestService_ModifyPack_GlobalPack(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - svc := newTestService(ds, nil, nil) - test.AddAllHostsLabel(t, ds) - users := createTestUsers(t, ds) - - globalPack, err := ds.EnsureGlobalPack(context.Background()) - require.NoError(t, err) - - labelids := []uint{1, 2, 3} - hostids := []uint{4, 5, 6} - teamids := []uint{7, 8, 9} - packPayload := fleet.PackPayload{ - Name: ptr.String("foo"), - Description: ptr.String("bar"), - LabelIDs: &labelids, - HostIDs: &hostids, - TeamIDs: &teamids, - } - - user := users["admin1@example.com"] - pack, _ := svc.ModifyPack(test.UserContext(&user), globalPack.ID, packPayload) - - require.Equal(t, "Global", pack.Name, "name for global pack should not change") - require.Equal(t, "Global pack", pack.Description, "description for global pack should not change") - require.Len(t, pack.LabelIDs, 1) - require.Len(t, pack.HostIDs, 0) - require.Len(t, pack.TeamIDs, 0) -} - -func TestService_DeletePackByID_GlobalPack(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - test.AddAllHostsLabel(t, ds) - - globalPack, err := ds.EnsureGlobalPack(context.Background()) - require.NoError(t, err) - - type fields struct { - ds fleet.Datastore - } - type args struct { - ctx context.Context - id uint - } - tests := []struct { - name string - fields fields - args args - wantErr bool - }{ - { - name: "cannot delete global pack", - fields: fields{ - ds, - }, - args: args{ - ctx: test.UserContext(test.UserAdmin), - id: globalPack.ID, - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - svc := newTestService(tt.fields.ds, nil, nil) - if err := svc.DeletePackByID(tt.args.ctx, tt.args.id); (err != nil) != tt.wantErr { - t.Errorf("DeletePackByID() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func TestService_ApplyPackSpecs(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - test.AddAllHostsLabel(t, ds) - - global, err := ds.EnsureGlobalPack(context.Background()) - require.NoError(t, err) - - users := createTestUsers(t, ds) - user := users["admin1@example.com"] - - team1, err := ds.NewTeam(context.Background(), &fleet.Team{ - ID: 42, - Name: "team1", - Description: "desc team1", - }) - require.NoError(t, err) - - teamPack, err := ds.EnsureTeamPack(context.Background(), team1.ID) - require.NoError(t, err) - - type fields struct { - ds fleet.Datastore - } - type args struct { - ctx context.Context - specs []*fleet.PackSpec - } - tests := []struct { - name string - fields fields - args args - want []*fleet.PackSpec - wantErr bool - }{ - { - name: "cannot modify global pack", - fields: fields{ - ds, - }, - args: args{ - ctx: test.UserContext(&user), - specs: []*fleet.PackSpec{ - {Name: global.Name, Description: "bar", Platform: "baz"}, - {Name: "Foo Pack", Description: "Foo Desc", Platform: "MacOS"}, - {Name: "Bar Pack", Description: "Bar Desc", Platform: "MacOS"}, - }, - }, - want: []*fleet.PackSpec{ - {Name: "Foo Pack", Description: "Foo Desc", Platform: "MacOS"}, - {Name: "Bar Pack", Description: "Bar Desc", Platform: "MacOS"}, - }, - wantErr: false, - }, - { - name: "cannot modify team pack", - fields: fields{ - ds, - }, - args: args{ - ctx: test.UserContext(&user), - specs: []*fleet.PackSpec{ - {Name: teamPack.Name, Description: "Desc", Platform: "windows"}, - {Name: "Test", Description: "Test Desc", Platform: "linux"}, - }, - }, - want: []*fleet.PackSpec{ - {Name: "Test", Description: "Test Desc", Platform: "linux"}, - }, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - svc := newTestService(tt.fields.ds, nil, nil) - got, err := svc.ApplyPackSpecs(tt.args.ctx, tt.args.specs) - if (err != nil) != tt.wantErr { - t.Errorf("ApplyPackSpecs() error = %v, wantErr %v", err, tt.wantErr) - return - } - require.Equal(t, tt.want, got) - }) - } -} - -func TestService_DeletePack(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - test.AddAllHostsLabel(t, ds) - - gp, err := ds.EnsureGlobalPack(context.Background()) - require.NoError(t, err) - - users := createTestUsers(t, ds) - user := users["admin1@example.com"] - - team1, err := ds.NewTeam(context.Background(), &fleet.Team{ - ID: 42, - Name: "team1", - Description: "desc team1", - }) - require.NoError(t, err) - - tp, err := ds.EnsureTeamPack(context.Background(), team1.ID) - require.NoError(t, err) - - type fields struct { - ds fleet.Datastore - } - type args struct { - ctx context.Context - name string - } - tests := []struct { - name string - fields fields - args args - wantErr bool - }{ - { - name: "cannot delete global pack", - fields: fields{ - ds: ds, - }, - args: args{ - ctx: test.UserContext(&user), - name: gp.Name, - }, - wantErr: true, - }, - { - name: "cannot delete team pack", - fields: fields{ - ds: ds, - }, - args: args{ - ctx: test.UserContext(&user), - name: tp.Name, - }, - wantErr: true, - }, - { - name: "delete pack that doesn't exist", - fields: fields{ - ds: ds, - }, - args: args{ - ctx: test.UserContext(&user), - name: "foo", - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - svc := newTestService(tt.fields.ds, nil, nil) - if err := svc.DeletePack(tt.args.ctx, tt.args.name); (err != nil) != tt.wantErr { - t.Errorf("DeletePack() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} diff --git a/server/service/transport.go b/server/service/transport.go index 4cc301b180..3091f6f09c 100644 --- a/server/service/transport.go +++ b/server/service/transport.go @@ -299,7 +299,7 @@ func decodeNoParamsRequest(ctx context.Context, r *http.Request) (interface{}, e } type getGenericSpecRequest struct { - Name string + Name string `url:"name"` } func decodeGetGenericSpecRequest(ctx context.Context, r *http.Request) (interface{}, error) { diff --git a/server/service/transport_packs.go b/server/service/transport_packs.go deleted file mode 100644 index 0901d703ff..0000000000 --- a/server/service/transport_packs.go +++ /dev/null @@ -1,76 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - "net/http" -) - -func decodeCreatePackRequest(ctx context.Context, r *http.Request) (interface{}, error) { - var req createPackRequest - if err := json.NewDecoder(r.Body).Decode(&req.payload); err != nil { - return nil, err - } - - return req, nil -} - -func decodeModifyPackRequest(ctx context.Context, r *http.Request) (interface{}, error) { - id, err := uintFromRequest(r, "id") - if err != nil { - return nil, err - } - var req modifyPackRequest - if err := json.NewDecoder(r.Body).Decode(&req.payload); err != nil { - return nil, err - } - req.ID = uint(id) - return req, nil -} - -func decodeDeletePackRequest(ctx context.Context, r *http.Request) (interface{}, error) { - name, err := stringFromRequest(r, "name") - if err != nil { - return nil, err - } - var req deletePackRequest - req.Name = name - return req, nil -} - -func decodeDeletePackByIDRequest(ctx context.Context, r *http.Request) (interface{}, error) { - id, err := uintFromRequest(r, "id") - if err != nil { - return nil, err - } - var req deletePackByIDRequest - req.ID = uint(id) - return req, nil -} - -func decodeGetPackRequest(ctx context.Context, r *http.Request) (interface{}, error) { - id, err := uintFromRequest(r, "id") - if err != nil { - return nil, err - } - var req getPackRequest - req.ID = uint(id) - return req, nil -} - -func decodeListPacksRequest(ctx context.Context, r *http.Request) (interface{}, error) { - opt, err := listOptionsFromRequest(r) - if err != nil { - return nil, err - } - return listPacksRequest{ListOptions: opt}, nil -} - -func decodeApplyPackSpecsRequest(ctx context.Context, r *http.Request) (interface{}, error) { - var req applyPackSpecsRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, err - } - return req, nil - -} diff --git a/server/service/transport_packs_test.go b/server/service/transport_packs_test.go deleted file mode 100644 index 97569a2901..0000000000 --- a/server/service/transport_packs_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package service - -import ( - "bytes" - "context" - "net/http" - "net/http/httptest" - "testing" - - "github.com/gorilla/mux" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestDecodeCreatePackRequest(t *testing.T) { - router := mux.NewRouter() - router.HandleFunc("/api/v1/fleet/packs", func(writer http.ResponseWriter, request *http.Request) { - r, err := decodeCreatePackRequest(context.Background(), request) - assert.Nil(t, err) - - params := r.(createPackRequest) - assert.Equal(t, "foo", *params.payload.Name) - assert.Equal(t, "bar", *params.payload.Description) - require.NotNil(t, params.payload.HostIDs) - assert.Len(t, *params.payload.HostIDs, 3) - require.NotNil(t, params.payload.LabelIDs) - assert.Len(t, *params.payload.LabelIDs, 2) - }).Methods("POST") - - var body bytes.Buffer - body.Write([]byte(`{ - "name": "foo", - "description": "bar", - "host_ids": [1, 2, 3], - "label_ids": [1, 5] - }`)) - - router.ServeHTTP( - httptest.NewRecorder(), - httptest.NewRequest("POST", "/api/v1/fleet/packs", &body), - ) -} - -func TestDecodeModifyPackRequest(t *testing.T) { - router := mux.NewRouter() - router.HandleFunc("/api/v1/fleet/packs/{id}", func(writer http.ResponseWriter, request *http.Request) { - r, err := decodeModifyPackRequest(context.Background(), request) - assert.Nil(t, err) - - params := r.(modifyPackRequest) - assert.Equal(t, uint(1), params.ID) - assert.Equal(t, "foo", *params.payload.Name) - assert.Equal(t, "bar", *params.payload.Description) - require.NotNil(t, params.payload.HostIDs) - assert.Len(t, *params.payload.HostIDs, 3) - require.NotNil(t, params.payload.LabelIDs) - assert.Len(t, *params.payload.LabelIDs, 2) - }).Methods("PATCH") - - var body bytes.Buffer - body.Write([]byte(`{ - "name": "foo", - "description": "bar", - "host_ids": [1, 2, 3], - "label_ids": [1, 5] - }`)) - - router.ServeHTTP( - httptest.NewRecorder(), - httptest.NewRequest("PATCH", "/api/v1/fleet/packs/1", &body), - ) -} - -func TestDecodeDeletePackRequest(t *testing.T) { - router := mux.NewRouter() - router.HandleFunc("/api/v1/fleet/packs/{name}", func(writer http.ResponseWriter, request *http.Request) { - r, err := decodeDeletePackRequest(context.Background(), request) - assert.Nil(t, err) - - params := r.(deletePackRequest) - assert.Equal(t, "packaday", params.Name) - }).Methods("DELETE") - - router.ServeHTTP( - httptest.NewRecorder(), - httptest.NewRequest("DELETE", "/api/v1/fleet/packs/packaday", nil), - ) -} - -func TestDecodeGetPackRequest(t *testing.T) { - router := mux.NewRouter() - router.HandleFunc("/api/v1/fleet/packs/{id}", func(writer http.ResponseWriter, request *http.Request) { - r, err := decodeGetPackRequest(context.Background(), request) - assert.Nil(t, err) - - params := r.(getPackRequest) - assert.Equal(t, uint(1), params.ID) - }).Methods("GET") - - router.ServeHTTP( - httptest.NewRecorder(), - httptest.NewRequest("GET", "/api/v1/fleet/packs/1", nil), - ) -}