diff --git a/server/service/endpoint_packs.go b/server/service/endpoint_packs.go index 0ddda645a7..7a12a91522 100644 --- a/server/service/endpoint_packs.go +++ b/server/service/endpoint_packs.go @@ -47,41 +47,6 @@ func packResponseForPack(ctx context.Context, svc fleet.Service, pack fleet.Pack }, nil } -//////////////////////////////////////////////////////////////////////////////// -// Get Pack -//////////////////////////////////////////////////////////////////////////////// - -type getPackRequest struct { - ID uint -} - -type getPackResponse struct { - Pack packResponse `json:"pack,omitempty"` - Err error `json:"error,omitempty"` -} - -func (r getPackResponse) error() error { return r.Err } - -func makeGetPackEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(getPackRequest) - - pack, err := svc.GetPack(ctx, req.ID) - if err != nil { - return getPackResponse{Err: err}, nil - } - - resp, err := packResponseForPack(ctx, svc, *pack) - if err != nil { - return getPackResponse{Err: err}, nil - } - - return getPackResponse{ - Pack: *resp, - }, nil - } -} - //////////////////////////////////////////////////////////////////////////////// // List Packs //////////////////////////////////////////////////////////////////////////////// diff --git a/server/service/handler.go b/server/service/handler.go index 93c7ad1fdf..bd9220b215 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -61,7 +61,6 @@ type FleetEndpoints struct { CreateDistributedQueryCampaignByNames endpoint.Endpoint CreatePack endpoint.Endpoint ModifyPack endpoint.Endpoint - GetPack endpoint.Endpoint ListPacks endpoint.Endpoint DeletePack endpoint.Endpoint DeletePackByID endpoint.Endpoint @@ -184,7 +183,6 @@ func MakeFleetServerEndpoints(svc fleet.Service, urlPrefix string, limitStore th CreateDistributedQueryCampaignByNames: authenticatedUser(svc, makeCreateDistributedQueryCampaignByNamesEndpoint(svc)), CreatePack: authenticatedUser(svc, makeCreatePackEndpoint(svc)), ModifyPack: authenticatedUser(svc, makeModifyPackEndpoint(svc)), - GetPack: authenticatedUser(svc, makeGetPackEndpoint(svc)), ListPacks: authenticatedUser(svc, makeListPacksEndpoint(svc)), DeletePack: authenticatedUser(svc, makeDeletePackEndpoint(svc)), DeletePackByID: authenticatedUser(svc, makeDeletePackByIDEndpoint(svc)), @@ -295,7 +293,6 @@ type fleetHandlers struct { CreateDistributedQueryCampaignByNames http.Handler CreatePack http.Handler ModifyPack http.Handler - GetPack http.Handler ListPacks http.Handler DeletePack http.Handler DeletePackByID http.Handler @@ -405,7 +402,6 @@ func makeKitHandlers(e FleetEndpoints, opts []kithttp.ServerOption) *fleetHandle CreateDistributedQueryCampaignByNames: newServer(e.CreateDistributedQueryCampaignByNames, decodeCreateDistributedQueryCampaignByNamesRequest), CreatePack: newServer(e.CreatePack, decodeCreatePackRequest), ModifyPack: newServer(e.ModifyPack, decodeModifyPackRequest), - GetPack: newServer(e.GetPack, decodeGetPackRequest), ListPacks: newServer(e.ListPacks, decodeListPacksRequest), DeletePack: newServer(e.DeletePack, decodeDeletePackRequest), DeletePackByID: newServer(e.DeletePackByID, decodeDeletePackByIDRequest), @@ -615,7 +611,6 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) { r.Handle("/api/v1/fleet/packs", h.CreatePack).Methods("POST").Name("create_pack") r.Handle("/api/v1/fleet/packs/{id}", h.ModifyPack).Methods("PATCH").Name("modify_pack") - r.Handle("/api/v1/fleet/packs/{id}", h.GetPack).Methods("GET").Name("get_pack") r.Handle("/api/v1/fleet/packs", h.ListPacks).Methods("GET").Name("list_packs") r.Handle("/api/v1/fleet/packs/{name}", h.DeletePack).Methods("DELETE").Name("delete_pack") r.Handle("/api/v1/fleet/packs/id/{id}", h.DeletePackByID).Methods("DELETE").Name("delete_pack_by_id") @@ -719,6 +714,8 @@ func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, opts []kitht e.GET("/api/v1/fleet/teams/{team_id}/policies/{policy_id}", getTeamPolicyByIDEndpoint, getTeamPolicyByIDRequest{}) e.POST("/api/v1/fleet/teams/{team_id}/policies/delete", deleteTeamPoliciesEndpoint, deleteTeamPoliciesRequest{}) + e.GET("/api/v1/fleet/packs/{id:[0-9]+}", getPackEndpoint, getPackRequest{}) + e.GET("/api/v1/fleet/software", listSoftwareEndpoint, listSoftwareRequest{}) e.POST("/api/v1/fleet/hosts/delete", deleteHostsEndpoint, deleteHostsRequest{}) diff --git a/server/service/handler_test.go b/server/service/handler_test.go index a8315209a5..b4cfa0d77f 100644 --- a/server/service/handler_test.go +++ b/server/service/handler_test.go @@ -2,13 +2,18 @@ package service import ( "fmt" + "net/http" "net/http/httptest" + "regexp" "testing" + "github.com/fleetdm/fleet/v4/server/config" "github.com/fleetdm/fleet/v4/server/mock" kitlog "github.com/go-kit/kit/log" "github.com/gorilla/mux" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/throttled/throttled/v2/store/memstore" ) @@ -109,10 +114,6 @@ func TestAPIRoutes(t *testing.T) { verb: "POST", uri: "/api/v1/fleet/queries/run", }, - { - verb: "GET", - uri: "/api/v1/fleet/packs/1", - }, { verb: "GET", uri: "/api/v1/fleet/packs", @@ -180,10 +181,6 @@ func TestAPIRoutes(t *testing.T) { verb: "DELETE", uri: "/api/v1/fleet/labels/1", }, - { - verb: "GET", - uri: "/api/v1/fleet/hosts/1", - }, { verb: "GET", uri: "/api/v1/fleet/hosts", @@ -206,6 +203,77 @@ func TestAPIRoutes(t *testing.T) { httptest.NewRequest(route.verb, route.uri, nil), ) assert.NotEqual(st, 404, recorder.Code) + assert.NotEqual(st, 405, recorder.Code, route.verb) // if it matches a path but with wrong verb + }) + } +} + +func TestAPIRoutesConflicts(t *testing.T) { + ds := new(mock.Store) + + svc := newTestService(ds, nil, nil) + limitStore, _ := memstore.New(0) + h := MakeHandler(svc, config.TestConfig(), kitlog.NewNopLogger(), limitStore) + router := h.(*mux.Router) + + type testCase struct { + name string + path string + verb string + want int + } + var cases []testCase + + // build the test cases: for each route, generate a request designed to match + // it, and override its handler to return a unique status code. If the + // request doesn't result in that status code, then some other route + // conflicts with it and took precedence - a route conflict. The route's name + // is used to name the sub-test for that route. + status := 200 + reSimpleVar, reNumVar := regexp.MustCompile(`\{(\w+)\}`), regexp.MustCompile(`\{\w+:[^\}]+\}`) + err := router.Walk(func(route *mux.Route, router *mux.Router, ancestores []*mux.Route) error { + name := route.GetName() + path, err := route.GetPathTemplate() + if err != nil { + // all our routes should have paths + return errors.Wrap(err, name) + } + meths, err := route.GetMethods() + if err != nil || len(meths) == 0 { + // only route without method is distributed_query_results (websocket) + if name != "distributed_query_results" { + return errors.Wrap(err, name+" "+path) + } + return nil + } + path = reSimpleVar.ReplaceAllString(path, "$1") + // for now at least, the only times we use regexp-constrained vars is + // for numeric arguments. + path = reNumVar.ReplaceAllString(path, "1") + + routeStatus := status + route.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(routeStatus) }) + for _, meth := range meths { + cases = append(cases, testCase{ + name: name, + path: path, + verb: meth, + want: status, + }) + } + + status++ + return nil + }) + require.NoError(t, err) + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + t.Log(c.verb, c.path) + req := httptest.NewRequest(c.verb, c.path, nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + require.Equal(t, c.want, rr.Code) }) } } diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index 5a073a06c8..56ebb1499b 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -563,3 +563,19 @@ func (s *integrationTestSuite) TestCountSoftware() { ) assert.Equal(t, 1, resp.Count) } + +func (s *integrationTestSuite) TestGetPack() { + t := s.T() + + pack := &fleet.Pack{ + Name: t.Name(), + } + pack, err := s.ds.NewPack(context.Background(), pack) + require.NoError(t, err) + + var packResp getPackResponse + s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/packs/%d", pack.ID), nil, http.StatusOK, &packResp) + require.Equal(t, packResp.Pack.ID, pack.ID) + + s.Do("GET", fmt.Sprintf("/api/v1/fleet/packs/%d", pack.ID+1), nil, http.StatusNotFound) +} diff --git a/server/service/packs.go b/server/service/packs.go new file mode 100644 index 0000000000..310cf60520 --- /dev/null +++ b/server/service/packs.go @@ -0,0 +1,47 @@ +package service + +import ( + "context" + + "github.com/fleetdm/fleet/v4/server/fleet" +) + +//////////////////////////////////////////////////////////////////////////////// +// Get Pack +//////////////////////////////////////////////////////////////////////////////// + +type getPackRequest struct { + ID uint `url:"id"` +} + +type getPackResponse struct { + Pack packResponse `json:"pack,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r getPackResponse) error() error { return r.Err } + +func getPackEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*getPackRequest) + pack, err := svc.GetPack(ctx, req.ID) + if err != nil { + return getPackResponse{Err: err}, nil + } + + resp, err := packResponseForPack(ctx, svc, *pack) + if err != nil { + return getPackResponse{Err: err}, nil + } + + return getPackResponse{ + Pack: *resp, + }, nil +} + +func (svc *Service) GetPack(ctx context.Context, id uint) (*fleet.Pack, error) { + if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil { + return nil, err + } + + return svc.ds.Pack(ctx, id) +} diff --git a/server/service/packs_test.go b/server/service/packs_test.go new file mode 100644 index 0000000000..69bf4da2e8 --- /dev/null +++ b/server/service/packs_test.go @@ -0,0 +1,32 @@ +package service + +import ( + "context" + "testing" + + "github.com/fleetdm/fleet/v4/server/authz" + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/mock" + "github.com/fleetdm/fleet/v4/server/test" + "github.com/stretchr/testify/require" +) + +func TestGetPack(t *testing.T) { + ds := new(mock.Store) + svc := newTestService(ds, nil, nil) + + ds.PackFunc = func(ctx context.Context, id uint) (*fleet.Pack, error) { + return &fleet.Pack{ + ID: 1, + TeamIDs: []uint{1}, + }, nil + } + + pack, err := svc.GetPack(test.UserContext(test.UserAdmin), 1) + require.NoError(t, err) + require.Equal(t, uint(1), pack.ID) + + _, err = svc.GetPack(test.UserContext(test.UserNoRoles), 1) + require.Error(t, err) + require.Contains(t, err.Error(), authz.ForbiddenErrorMessage) +} diff --git a/server/service/service_packs.go b/server/service/service_packs.go index 84affb4b47..a6a19213fb 100644 --- a/server/service/service_packs.go +++ b/server/service/service_packs.go @@ -76,14 +76,6 @@ func (svc *Service) ListPacks(ctx context.Context, opt fleet.PackListOptions) ([ return svc.ds.ListPacks(ctx, opt) } -func (svc *Service) GetPack(ctx context.Context, id uint) (*fleet.Pack, error) { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil { - return nil, err - } - - return svc.ds.Pack(ctx, id) -} - func (svc *Service) NewPack(ctx context.Context, p fleet.PackPayload) (*fleet.Pack, error) { if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { return nil, err diff --git a/server/service/service_packs_test.go b/server/service/service_packs_test.go index 7bf1e66d2e..5e5d718a1b 100644 --- a/server/service/service_packs_test.go +++ b/server/service/service_packs_test.go @@ -33,25 +33,6 @@ func TestServiceListPacks(t *testing.T) { assert.Len(t, queries, 1) } -func TestGetPack(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - svc := newTestService(ds, nil, nil) - - pack := &fleet.Pack{ - Name: "foo", - } - _, err := ds.NewPack(context.Background(), pack) - assert.Nil(t, err) - assert.NotZero(t, pack.ID) - - packVerify, err := svc.GetPack(test.UserContext(test.UserAdmin), pack.ID) - assert.Nil(t, err) - - assert.Equal(t, pack.ID, packVerify.ID) -} - func TestNewSavesTargets(t *testing.T) { ds := new(mock.Store) svc := newTestService(ds, nil, nil)