diff --git a/server/service/hosts.go b/server/service/hosts.go index f1d2120d4c..c3c0bc9b19 100644 --- a/server/service/hosts.go +++ b/server/service/hosts.go @@ -4,6 +4,7 @@ import ( "context" "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/pkg/errors" ) ///////////////////////////////////////////////////////////////////////////////// @@ -43,7 +44,7 @@ func deleteHostsEndpoint(ctx context.Context, request interface{}, svc fleet.Ser } func (svc Service) DeleteHosts(ctx context.Context, ids []uint, opts fleet.HostListOptions, lid *uint) error { - if err := svc.authz.Authorize(ctx, &fleet.Host{}, fleet.ActionWrite); err != nil { + if err := svc.authz.Authorize(ctx, &fleet.Host{}, fleet.ActionList); err != nil { return err } @@ -52,6 +53,10 @@ func (svc Service) DeleteHosts(ctx context.Context, ids []uint, opts fleet.HostL } if len(ids) > 0 { + err := svc.checkWriteForHostIDs(ctx, ids) + if err != nil { + return err + } return svc.ds.DeleteHosts(ctx, ids) } @@ -63,5 +68,25 @@ func (svc Service) DeleteHosts(ctx context.Context, ids []uint, opts fleet.HostL if len(hostIDs) == 0 { return nil } + + err = svc.checkWriteForHostIDs(ctx, hostIDs) + if err != nil { + return err + } return svc.ds.DeleteHosts(ctx, hostIDs) } + +func (svc Service) checkWriteForHostIDs(ctx context.Context, ids []uint) error { + for _, id := range ids { + host, err := svc.ds.Host(ctx, id) + if err != nil { + return errors.Wrap(err, "get host for delete") + } + + // Authorize again with team loaded now that we have team_id + if err := svc.authz.Authorize(ctx, host, fleet.ActionWrite); err != nil { + return err + } + } + return nil +} diff --git a/server/service/service_hosts_test.go b/server/service/service_hosts_test.go index 33263d0a50..b8804c80d5 100644 --- a/server/service/service_hosts_test.go +++ b/server/service/service_hosts_test.go @@ -252,6 +252,9 @@ func TestHostAuth(t *testing.T) { ds.SaveHostFunc = func(ctx context.Context, host *fleet.Host) error { return nil } + ds.DeleteHostsFunc = func(ctx context.Context, ids []uint) error { + return nil + } var testCases = []struct { name string @@ -340,6 +343,12 @@ func TestHostAuth(t *testing.T) { err = svc.DeleteHost(ctx, 2) checkAuthErr(t, tt.shouldFailGlobalWrite, err) + err = svc.DeleteHosts(ctx, []uint{1}, fleet.HostListOptions{}, nil) + checkAuthErr(t, tt.shouldFailTeamWrite, err) + + err = svc.DeleteHosts(ctx, []uint{2}, fleet.HostListOptions{}, nil) + checkAuthErr(t, tt.shouldFailGlobalWrite, err) + err = svc.AddHostsToTeam(ctx, ptr.Uint(1), []uint{1}) checkAuthErr(t, tt.shouldFailTeamWrite, err)