Refactor ListHosts to new endpoint pattern (#2396)

This commit is contained in:
Martin Angers 2021-10-11 10:37:48 -04:00 committed by GitHub
parent fce3e42abb
commit 5e1f872ccb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 100 additions and 97 deletions

View file

@ -84,42 +84,6 @@ func makeHostByIdentifierEndpoint(svc fleet.Service) endpoint.Endpoint {
}
}
////////////////////////////////////////////////////////////////////////////////
// List Hosts
////////////////////////////////////////////////////////////////////////////////
type listHostsRequest struct {
ListOptions fleet.HostListOptions
}
type listHostsResponse struct {
Hosts []HostResponse `json:"hosts"`
Err error `json:"error,omitempty"`
}
func (r listHostsResponse) error() error { return r.Err }
func makeListHostsEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(listHostsRequest)
hosts, err := svc.ListHosts(ctx, req.ListOptions)
if err != nil {
return listHostsResponse{Err: err}, nil
}
hostResponses := make([]HostResponse, len(hosts))
for i, host := range hosts {
h, err := hostResponseForHost(ctx, svc, host)
if err != nil {
return listHostsResponse{Err: err}, nil
}
hostResponses[i] = *h
}
return listHostsResponse{Hosts: hostResponses}, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Get Host Summary
////////////////////////////////////////////////////////////////////////////////

View file

@ -96,7 +96,6 @@ type FleetEndpoints struct {
HostByIdentifier endpoint.Endpoint
DeleteHost endpoint.Endpoint
RefetchHost endpoint.Endpoint
ListHosts endpoint.Endpoint
GetHostSummary endpoint.Endpoint
AddHostsToTeam endpoint.Endpoint
AddHostsToTeamByFilter endpoint.Endpoint
@ -199,7 +198,6 @@ func MakeFleetServerEndpoints(svc fleet.Service, urlPrefix string, limitStore th
ModifyGlobalSchedule: authenticatedUser(svc, makeModifyGlobalScheduleEndpoint(svc)),
DeleteGlobalSchedule: authenticatedUser(svc, makeDeleteGlobalScheduleEndpoint(svc)),
HostByIdentifier: authenticatedUser(svc, makeHostByIdentifierEndpoint(svc)),
ListHosts: authenticatedUser(svc, makeListHostsEndpoint(svc)),
GetHostSummary: authenticatedUser(svc, makeGetHostSummaryEndpoint(svc)),
DeleteHost: authenticatedUser(svc, makeDeleteHostEndpoint(svc)),
AddHostsToTeam: authenticatedUser(svc, makeAddHostsToTeamEndpoint(svc)),
@ -328,7 +326,6 @@ type fleetHandlers struct {
HostByIdentifier http.Handler
DeleteHost http.Handler
RefetchHost http.Handler
ListHosts http.Handler
GetHostSummary http.Handler
AddHostsToTeam http.Handler
AddHostsToTeamByFilter http.Handler
@ -437,7 +434,6 @@ func makeKitHandlers(e FleetEndpoints, opts []kithttp.ServerOption) *fleetHandle
HostByIdentifier: newServer(e.HostByIdentifier, decodeHostByIdentifierRequest),
DeleteHost: newServer(e.DeleteHost, decodeDeleteHostRequest),
RefetchHost: newServer(e.RefetchHost, decodeRefetchHostRequest),
ListHosts: newServer(e.ListHosts, decodeListHostsRequest),
GetHostSummary: newServer(e.GetHostSummary, decodeNoParamsRequest),
AddHostsToTeam: newServer(e.AddHostsToTeam, decodeAddHostsToTeamRequest),
AddHostsToTeamByFilter: newServer(e.AddHostsToTeamByFilter, decodeAddHostsToTeamByFilterRequest),
@ -639,7 +635,6 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) {
r.Handle("/api/v1/fleet/spec/labels", h.GetLabelSpecs).Methods("GET").Name("get_label_specs")
r.Handle("/api/v1/fleet/spec/labels/{name}", h.GetLabelSpec).Methods("GET").Name("get_label_spec")
r.Handle("/api/v1/fleet/hosts", h.ListHosts).Methods("GET").Name("list_hosts")
r.Handle("/api/v1/fleet/host_summary", h.GetHostSummary).Methods("GET").Name("get_host_summary")
r.Handle("/api/v1/fleet/hosts/identifier/{identifier}", h.HostByIdentifier).Methods("GET").Name("host_by_identifier")
r.Handle("/api/v1/fleet/hosts/{id}", h.DeleteHost).Methods("DELETE").Name("delete_host")
@ -718,6 +713,7 @@ func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, opts []kitht
e.GET("/api/v1/fleet/software", listSoftwareEndpoint, listSoftwareRequest{})
e.GET("/api/v1/fleet/hosts", listHostsEndpoint, listHostsRequest{})
e.POST("/api/v1/fleet/hosts/delete", deleteHostsEndpoint, deleteHostsRequest{})
e.GET("/api/v1/fleet/hosts/{id:[0-9]+}", getHostEndpoint, getHostRequest{})
e.GET("/api/v1/fleet/hosts/count", countHostsEndpoint, countHostsRequest{})

View file

@ -181,10 +181,6 @@ func TestAPIRoutes(t *testing.T) {
verb: "DELETE",
uri: "/api/v1/fleet/labels/1",
},
{
verb: "GET",
uri: "/api/v1/fleet/hosts",
},
{
verb: "DELETE",
uri: "/api/v1/fleet/hosts/1",

View file

@ -3,10 +3,59 @@ package service
import (
"context"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/pkg/errors"
)
////////////////////////////////////////////////////////////////////////////////
// List Hosts
////////////////////////////////////////////////////////////////////////////////
type listHostsRequest struct {
Opts fleet.HostListOptions `url:"host_options"`
}
type listHostsResponse struct {
Hosts []HostResponse `json:"hosts"`
Err error `json:"error,omitempty"`
}
func (r listHostsResponse) error() error { return r.Err }
func listHostsEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*listHostsRequest)
hosts, err := svc.ListHosts(ctx, req.Opts)
if err != nil {
return listHostsResponse{Err: err}, nil
}
hostResponses := make([]HostResponse, len(hosts))
for i, host := range hosts {
h, err := hostResponseForHost(ctx, svc, host)
if err != nil {
return listHostsResponse{Err: err}, nil
}
hostResponses[i] = *h
}
return listHostsResponse{Hosts: hostResponses}, nil
}
func (svc Service) ListHosts(ctx context.Context, opt fleet.HostListOptions) ([]*fleet.Host, error) {
if err := svc.authz.Authorize(ctx, &fleet.Host{}, fleet.ActionList); err != nil {
return nil, err
}
vc, ok := viewer.FromContext(ctx)
if !ok {
return nil, fleet.ErrNoContext
}
filter := fleet.TeamFilter{User: vc.User, IncludeObserver: true}
return svc.ds.ListHosts(ctx, filter, opt)
}
/////////////////////////////////////////////////////////////////////////////////
// Delete
/////////////////////////////////////////////////////////////////////////////////

View file

@ -0,0 +1,37 @@
package service
import (
"context"
"testing"
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/stretchr/testify/require"
)
func TestListHosts(t *testing.T) {
ds := new(mock.Store)
svc := newTestService(ds, nil, nil)
ds.ListHostsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) {
return []*fleet.Host{
{ID: 1},
}, nil
}
hosts, err := svc.ListHosts(test.UserContext(test.UserAdmin), fleet.HostListOptions{})
require.NoError(t, err)
require.Len(t, hosts, 1)
// anyone can list hosts
hosts, err = svc.ListHosts(test.UserContext(test.UserNoRoles), fleet.HostListOptions{})
require.NoError(t, err)
require.Len(t, hosts, 1)
// a user is required
_, err = svc.ListHosts(context.Background(), fleet.HostListOptions{})
require.Error(t, err)
require.Contains(t, err.Error(), authz.ForbiddenErrorMessage)
}

View file

@ -579,3 +579,16 @@ func (s *integrationTestSuite) TestGetPack() {
s.Do("GET", fmt.Sprintf("/api/v1/fleet/packs/%d", pack.ID+1), nil, http.StatusNotFound)
}
func (s *integrationTestSuite) TestListHosts() {
t := s.T()
hosts := s.createHosts(t)
var resp listHostsResponse
s.DoJSON("GET", "/api/v1/fleet/hosts", nil, http.StatusOK, &resp)
require.Len(t, resp.Hosts, len(hosts))
s.DoJSON("GET", "/api/v1/fleet/hosts", nil, http.StatusOK, &resp, "per_page", "1")
require.Len(t, resp.Hosts, 1)
}

View file

@ -8,20 +8,6 @@ import (
"github.com/pkg/errors"
)
func (svc Service) ListHosts(ctx context.Context, opt fleet.HostListOptions) ([]*fleet.Host, error) {
if err := svc.authz.Authorize(ctx, &fleet.Host{}, fleet.ActionList); err != nil {
return nil, err
}
vc, ok := viewer.FromContext(ctx)
if !ok {
return nil, fleet.ErrNoContext
}
filter := fleet.TeamFilter{User: vc.User, IncludeObserver: true}
return svc.ds.ListHosts(ctx, filter, opt)
}
func (svc Service) GetHost(ctx context.Context, id uint) (*fleet.HostDetail, error) {
// First ensure the user has access to list hosts, then check the specific
// host once team_id is loaded.

View file

@ -3,7 +3,6 @@ package service
import (
"context"
"testing"
"time"
"github.com/WatchBeam/clock"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
@ -16,34 +15,6 @@ import (
"github.com/stretchr/testify/require"
)
func TestListHosts(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
defer ds.Close()
svc := newTestService(ds, nil, nil)
hosts, err := svc.ListHosts(test.UserContext(test.UserAdmin), fleet.HostListOptions{})
assert.Nil(t, err)
assert.Len(t, hosts, 0)
storedTime := time.Now().UTC()
_, err = ds.NewHost(context.Background(), &fleet.Host{
Hostname: "foo",
SeenTime: storedTime,
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
})
require.NoError(t, err)
hosts, err = svc.ListHosts(test.UserContext(test.UserAdmin), fleet.HostListOptions{})
require.NoError(t, err)
require.Len(t, hosts, 1)
format := "%Y-%m-%d %HH:%MM:%SS %Z"
assert.Equal(t, storedTime.Format(format), hosts[0].SeenTime.Format(format))
}
func TestDeleteHost(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
defer ds.Close()

View file

@ -30,15 +30,6 @@ func decodeRefetchHostRequest(ctx context.Context, r *http.Request) (interface{}
return refetchHostRequest{ID: id}, nil
}
func decodeListHostsRequest(ctx context.Context, r *http.Request) (interface{}, error) {
hopt, err := hostListOptionsFromRequest(r)
if err != nil {
return nil, err
}
return listHostsRequest{ListOptions: hopt}, nil
}
func decodeAddHostsToTeamRequest(ctx context.Context, r *http.Request) (interface{}, error) {
var req addHostsToTeamRequest