Refactor GetPack to new endpoint pattern (#2409)

This commit is contained in:
Martin Angers 2021-10-11 10:17:21 -04:00 committed by GitHub
parent 7f3d3ad96c
commit fce3e42abb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 173 additions and 75 deletions

View file

@ -47,41 +47,6 @@ func packResponseForPack(ctx context.Context, svc fleet.Service, pack fleet.Pack
}, nil
}
////////////////////////////////////////////////////////////////////////////////
// Get Pack
////////////////////////////////////////////////////////////////////////////////
type getPackRequest struct {
ID uint
}
type getPackResponse struct {
Pack packResponse `json:"pack,omitempty"`
Err error `json:"error,omitempty"`
}
func (r getPackResponse) error() error { return r.Err }
func makeGetPackEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(getPackRequest)
pack, err := svc.GetPack(ctx, req.ID)
if err != nil {
return getPackResponse{Err: err}, nil
}
resp, err := packResponseForPack(ctx, svc, *pack)
if err != nil {
return getPackResponse{Err: err}, nil
}
return getPackResponse{
Pack: *resp,
}, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// List Packs
////////////////////////////////////////////////////////////////////////////////

View file

@ -61,7 +61,6 @@ type FleetEndpoints struct {
CreateDistributedQueryCampaignByNames endpoint.Endpoint
CreatePack endpoint.Endpoint
ModifyPack endpoint.Endpoint
GetPack endpoint.Endpoint
ListPacks endpoint.Endpoint
DeletePack endpoint.Endpoint
DeletePackByID endpoint.Endpoint
@ -184,7 +183,6 @@ func MakeFleetServerEndpoints(svc fleet.Service, urlPrefix string, limitStore th
CreateDistributedQueryCampaignByNames: authenticatedUser(svc, makeCreateDistributedQueryCampaignByNamesEndpoint(svc)),
CreatePack: authenticatedUser(svc, makeCreatePackEndpoint(svc)),
ModifyPack: authenticatedUser(svc, makeModifyPackEndpoint(svc)),
GetPack: authenticatedUser(svc, makeGetPackEndpoint(svc)),
ListPacks: authenticatedUser(svc, makeListPacksEndpoint(svc)),
DeletePack: authenticatedUser(svc, makeDeletePackEndpoint(svc)),
DeletePackByID: authenticatedUser(svc, makeDeletePackByIDEndpoint(svc)),
@ -295,7 +293,6 @@ type fleetHandlers struct {
CreateDistributedQueryCampaignByNames http.Handler
CreatePack http.Handler
ModifyPack http.Handler
GetPack http.Handler
ListPacks http.Handler
DeletePack http.Handler
DeletePackByID http.Handler
@ -405,7 +402,6 @@ func makeKitHandlers(e FleetEndpoints, opts []kithttp.ServerOption) *fleetHandle
CreateDistributedQueryCampaignByNames: newServer(e.CreateDistributedQueryCampaignByNames, decodeCreateDistributedQueryCampaignByNamesRequest),
CreatePack: newServer(e.CreatePack, decodeCreatePackRequest),
ModifyPack: newServer(e.ModifyPack, decodeModifyPackRequest),
GetPack: newServer(e.GetPack, decodeGetPackRequest),
ListPacks: newServer(e.ListPacks, decodeListPacksRequest),
DeletePack: newServer(e.DeletePack, decodeDeletePackRequest),
DeletePackByID: newServer(e.DeletePackByID, decodeDeletePackByIDRequest),
@ -615,7 +611,6 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) {
r.Handle("/api/v1/fleet/packs", h.CreatePack).Methods("POST").Name("create_pack")
r.Handle("/api/v1/fleet/packs/{id}", h.ModifyPack).Methods("PATCH").Name("modify_pack")
r.Handle("/api/v1/fleet/packs/{id}", h.GetPack).Methods("GET").Name("get_pack")
r.Handle("/api/v1/fleet/packs", h.ListPacks).Methods("GET").Name("list_packs")
r.Handle("/api/v1/fleet/packs/{name}", h.DeletePack).Methods("DELETE").Name("delete_pack")
r.Handle("/api/v1/fleet/packs/id/{id}", h.DeletePackByID).Methods("DELETE").Name("delete_pack_by_id")
@ -719,6 +714,8 @@ func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, opts []kitht
e.GET("/api/v1/fleet/teams/{team_id}/policies/{policy_id}", getTeamPolicyByIDEndpoint, getTeamPolicyByIDRequest{})
e.POST("/api/v1/fleet/teams/{team_id}/policies/delete", deleteTeamPoliciesEndpoint, deleteTeamPoliciesRequest{})
e.GET("/api/v1/fleet/packs/{id:[0-9]+}", getPackEndpoint, getPackRequest{})
e.GET("/api/v1/fleet/software", listSoftwareEndpoint, listSoftwareRequest{})
e.POST("/api/v1/fleet/hosts/delete", deleteHostsEndpoint, deleteHostsRequest{})

View file

@ -2,13 +2,18 @@ package service
import (
"fmt"
"net/http"
"net/http/httptest"
"regexp"
"testing"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/mock"
kitlog "github.com/go-kit/kit/log"
"github.com/gorilla/mux"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/throttled/throttled/v2/store/memstore"
)
@ -109,10 +114,6 @@ func TestAPIRoutes(t *testing.T) {
verb: "POST",
uri: "/api/v1/fleet/queries/run",
},
{
verb: "GET",
uri: "/api/v1/fleet/packs/1",
},
{
verb: "GET",
uri: "/api/v1/fleet/packs",
@ -180,10 +181,6 @@ func TestAPIRoutes(t *testing.T) {
verb: "DELETE",
uri: "/api/v1/fleet/labels/1",
},
{
verb: "GET",
uri: "/api/v1/fleet/hosts/1",
},
{
verb: "GET",
uri: "/api/v1/fleet/hosts",
@ -206,6 +203,77 @@ func TestAPIRoutes(t *testing.T) {
httptest.NewRequest(route.verb, route.uri, nil),
)
assert.NotEqual(st, 404, recorder.Code)
assert.NotEqual(st, 405, recorder.Code, route.verb) // if it matches a path but with wrong verb
})
}
}
func TestAPIRoutesConflicts(t *testing.T) {
ds := new(mock.Store)
svc := newTestService(ds, nil, nil)
limitStore, _ := memstore.New(0)
h := MakeHandler(svc, config.TestConfig(), kitlog.NewNopLogger(), limitStore)
router := h.(*mux.Router)
type testCase struct {
name string
path string
verb string
want int
}
var cases []testCase
// build the test cases: for each route, generate a request designed to match
// it, and override its handler to return a unique status code. If the
// request doesn't result in that status code, then some other route
// conflicts with it and took precedence - a route conflict. The route's name
// is used to name the sub-test for that route.
status := 200
reSimpleVar, reNumVar := regexp.MustCompile(`\{(\w+)\}`), regexp.MustCompile(`\{\w+:[^\}]+\}`)
err := router.Walk(func(route *mux.Route, router *mux.Router, ancestores []*mux.Route) error {
name := route.GetName()
path, err := route.GetPathTemplate()
if err != nil {
// all our routes should have paths
return errors.Wrap(err, name)
}
meths, err := route.GetMethods()
if err != nil || len(meths) == 0 {
// only route without method is distributed_query_results (websocket)
if name != "distributed_query_results" {
return errors.Wrap(err, name+" "+path)
}
return nil
}
path = reSimpleVar.ReplaceAllString(path, "$1")
// for now at least, the only times we use regexp-constrained vars is
// for numeric arguments.
path = reNumVar.ReplaceAllString(path, "1")
routeStatus := status
route.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(routeStatus) })
for _, meth := range meths {
cases = append(cases, testCase{
name: name,
path: path,
verb: meth,
want: status,
})
}
status++
return nil
})
require.NoError(t, err)
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
t.Log(c.verb, c.path)
req := httptest.NewRequest(c.verb, c.path, nil)
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
require.Equal(t, c.want, rr.Code)
})
}
}

View file

@ -563,3 +563,19 @@ func (s *integrationTestSuite) TestCountSoftware() {
)
assert.Equal(t, 1, resp.Count)
}
func (s *integrationTestSuite) TestGetPack() {
t := s.T()
pack := &fleet.Pack{
Name: t.Name(),
}
pack, err := s.ds.NewPack(context.Background(), pack)
require.NoError(t, err)
var packResp getPackResponse
s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/packs/%d", pack.ID), nil, http.StatusOK, &packResp)
require.Equal(t, packResp.Pack.ID, pack.ID)
s.Do("GET", fmt.Sprintf("/api/v1/fleet/packs/%d", pack.ID+1), nil, http.StatusNotFound)
}

47
server/service/packs.go Normal file
View file

@ -0,0 +1,47 @@
package service
import (
"context"
"github.com/fleetdm/fleet/v4/server/fleet"
)
////////////////////////////////////////////////////////////////////////////////
// Get Pack
////////////////////////////////////////////////////////////////////////////////
type getPackRequest struct {
ID uint `url:"id"`
}
type getPackResponse struct {
Pack packResponse `json:"pack,omitempty"`
Err error `json:"error,omitempty"`
}
func (r getPackResponse) error() error { return r.Err }
func getPackEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*getPackRequest)
pack, err := svc.GetPack(ctx, req.ID)
if err != nil {
return getPackResponse{Err: err}, nil
}
resp, err := packResponseForPack(ctx, svc, *pack)
if err != nil {
return getPackResponse{Err: err}, nil
}
return getPackResponse{
Pack: *resp,
}, nil
}
func (svc *Service) GetPack(ctx context.Context, id uint) (*fleet.Pack, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.ds.Pack(ctx, id)
}

View file

@ -0,0 +1,32 @@
package service
import (
"context"
"testing"
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/stretchr/testify/require"
)
func TestGetPack(t *testing.T) {
ds := new(mock.Store)
svc := newTestService(ds, nil, nil)
ds.PackFunc = func(ctx context.Context, id uint) (*fleet.Pack, error) {
return &fleet.Pack{
ID: 1,
TeamIDs: []uint{1},
}, nil
}
pack, err := svc.GetPack(test.UserContext(test.UserAdmin), 1)
require.NoError(t, err)
require.Equal(t, uint(1), pack.ID)
_, err = svc.GetPack(test.UserContext(test.UserNoRoles), 1)
require.Error(t, err)
require.Contains(t, err.Error(), authz.ForbiddenErrorMessage)
}

View file

@ -76,14 +76,6 @@ func (svc *Service) ListPacks(ctx context.Context, opt fleet.PackListOptions) ([
return svc.ds.ListPacks(ctx, opt)
}
func (svc *Service) GetPack(ctx context.Context, id uint) (*fleet.Pack, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.ds.Pack(ctx, id)
}
func (svc *Service) NewPack(ctx context.Context, p fleet.PackPayload) (*fleet.Pack, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil {
return nil, err

View file

@ -33,25 +33,6 @@ func TestServiceListPacks(t *testing.T) {
assert.Len(t, queries, 1)
}
func TestGetPack(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
defer ds.Close()
svc := newTestService(ds, nil, nil)
pack := &fleet.Pack{
Name: "foo",
}
_, err := ds.NewPack(context.Background(), pack)
assert.Nil(t, err)
assert.NotZero(t, pack.ID)
packVerify, err := svc.GetPack(test.UserContext(test.UserAdmin), pack.ID)
assert.Nil(t, err)
assert.Equal(t, pack.ID, packVerify.ID)
}
func TestNewSavesTargets(t *testing.T) {
ds := new(mock.Store)
svc := newTestService(ds, nil, nil)