mirror of
https://github.com/fleetdm/fleet
synced 2026-05-23 08:58:41 +00:00
Add ability to bulk execute scripts based on filters (#29149)
for #28700 # Checklist for submitter - [X] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`. ## Details This PR adds the ability to use filters to select a subset of hosts to run a script on, using the existing batch execution system. Due to the scale limitations of the framework, we limit this to 5,000 hosts (we may lift this limit in the future as we iterate on this feature). The implementation follows the same basic strategy as the "transfer hosts to team by filter" endpoint. If filters are supplied, they are used to get host records using `ListHosts` or `ListHostsInLabel`. If IDs are supplied, `ListHostsLiteByIDs` is used. From there, we do the same validation as in the previous iteration, and send the host IDs to the batch execution function. There are many avenues for optimization here, some of which I already have in a branch, but this is a very low-touch solution to get us larger batch sizes right now. To do this at true scale warrants some cross-team architecture discussions. ## Testing **Automated:** New automated tests were added for the existing `BatchExecuteScript` service method, and verified that they still pass with the code updates. **Manual testing:** * Tested running a script on a subset of hosts on a single page (no team and real team) * Tested running a script on a subset of hosts using a query filter (no team and real team) * Tested running a script on a subset of hosts using a label filter (no team and real team)
This commit is contained in:
parent
cbeb311b97
commit
bb09925d73
5 changed files with 179 additions and 9 deletions
1
changes/28700-add-bulk-execute-by-script
Normal file
1
changes/28700-add-bulk-execute-by-script
Normal file
|
|
@ -0,0 +1 @@
|
|||
- Added ability to execute scripts on up to 5,000 hosts at a time using filters
|
||||
|
|
@ -1176,7 +1176,7 @@ type Service interface {
|
|||
BatchSetScripts(ctx context.Context, maybeTmID *uint, maybeTmName *string, payloads []ScriptPayload, dryRun bool) ([]ScriptResponse, error)
|
||||
|
||||
// BatchScriptExecute runs a script on many hosts. It creates and returns a batch execution ID
|
||||
BatchScriptExecute(ctx context.Context, scriptID uint, hostIDs []uint) (string, error)
|
||||
BatchScriptExecute(ctx context.Context, scriptID uint, hostIDs []uint, filters *map[string]interface{}) (string, error)
|
||||
|
||||
BatchScriptExecutionSummary(ctx context.Context, batchExecutionID string) (*BatchExecutionSummary, error)
|
||||
|
||||
|
|
|
|||
|
|
@ -6601,7 +6601,7 @@ func (s *integrationEnterpriseTestSuite) TestRunBatchScript() {
|
|||
s.DoJSON("POST", "/api/latest/fleet/scripts/run/batch", batchScriptRunRequest{
|
||||
ScriptID: script.ID,
|
||||
HostIDs: []uint{host1.ID, host2.ID, host3Team1.ID},
|
||||
}, http.StatusBadRequest, &batchRes)
|
||||
}, http.StatusUnprocessableEntity, &batchRes)
|
||||
require.Empty(t, batchRes.BatchExecutionID)
|
||||
|
||||
// Bad script ID
|
||||
|
|
|
|||
|
|
@ -1094,10 +1094,10 @@ func (svc *Service) authorizeScriptByID(ctx context.Context, scriptID uint, auth
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type batchScriptRunRequest struct {
|
||||
ScriptID uint `json:"script_id"`
|
||||
HostIDs []uint `json:"host_ids"`
|
||||
ScriptID uint `json:"script_id"`
|
||||
HostIDs []uint `json:"host_ids"`
|
||||
Filters *map[string]interface{} `json:"filters"`
|
||||
}
|
||||
|
||||
type batchScriptRunResponse struct {
|
||||
BatchExecutionID string `json:"batch_execution_id"`
|
||||
Err error `json:"error,omitempty"`
|
||||
|
|
@ -1107,14 +1107,21 @@ func (r batchScriptRunResponse) Error() error { return r.Err }
|
|||
|
||||
func batchScriptRunEndpoint(ctx context.Context, request any, svc fleet.Service) (fleet.Errorer, error) {
|
||||
req := request.(*batchScriptRunRequest)
|
||||
batchID, err := svc.BatchScriptExecute(ctx, req.ScriptID, req.HostIDs)
|
||||
batchID, err := svc.BatchScriptExecute(ctx, req.ScriptID, req.HostIDs, req.Filters)
|
||||
if err != nil {
|
||||
return batchScriptRunResponse{Err: err}, nil
|
||||
}
|
||||
return batchScriptRunResponse{BatchExecutionID: batchID}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) BatchScriptExecute(ctx context.Context, scriptID uint, hostIDs []uint) (string, error) {
|
||||
const MAX_BATCH_EXECUTION_HOSTS = 5000
|
||||
|
||||
func (svc *Service) BatchScriptExecute(ctx context.Context, scriptID uint, hostIDs []uint, filters *map[string]interface{}) (string, error) {
|
||||
// If we are given both host IDs and filters, return an error
|
||||
if len(hostIDs) > 0 && filters != nil {
|
||||
return "", fleet.NewInvalidArgumentError("filters", "cannot specify both host_ids and filters")
|
||||
}
|
||||
|
||||
// First check if scripts are disabled globally. If so, no need for further processing.
|
||||
cfg, err := svc.ds.AppConfig(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -1139,7 +1146,63 @@ func (svc *Service) BatchScriptExecute(ctx context.Context, scriptID uint, hostI
|
|||
userId = &ctxUser.ID
|
||||
}
|
||||
|
||||
batchID, err := svc.ds.BatchExecuteScript(ctx, userId, scriptID, hostIDs)
|
||||
var hosts []*fleet.Host
|
||||
|
||||
// If we are given filters, we need to get the hosts matching those filters
|
||||
if filters != nil {
|
||||
opt, lid, err := hostListOptionsFromFilters(filters)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if opt == nil {
|
||||
return "", fleet.NewInvalidArgumentError("filters", "filters must be a valid set of host list options")
|
||||
}
|
||||
|
||||
if opt.TeamFilter == nil {
|
||||
return "", fleet.NewInvalidArgumentError("filters", "filters must include a team filter")
|
||||
}
|
||||
|
||||
filter := fleet.TeamFilter{User: ctxUser, IncludeObserver: true}
|
||||
|
||||
// Load hosts, either from label if provided or from all hosts.
|
||||
if lid != nil {
|
||||
hosts, err = svc.ds.ListHostsInLabel(ctx, filter, *lid, *opt)
|
||||
} else {
|
||||
opt.DisableIssues = true // intentionally ignore failing policies
|
||||
hosts, err = svc.ds.ListHosts(ctx, filter, *opt)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else {
|
||||
// Get the hosts matching the host IDs
|
||||
hosts, err = svc.ds.ListHostsLiteByIDs(ctx, hostIDs)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if len(hosts) == 0 {
|
||||
return "", &fleet.BadRequestError{Message: "no hosts match the specified host IDs"}
|
||||
}
|
||||
|
||||
if len(hosts) > MAX_BATCH_EXECUTION_HOSTS {
|
||||
return "", fleet.NewInvalidArgumentError("filters", "too_many_hosts")
|
||||
}
|
||||
|
||||
hostIDsToExecute := make([]uint, 0, len(hosts))
|
||||
for _, host := range hosts {
|
||||
hostIDsToExecute = append(hostIDsToExecute, host.ID)
|
||||
if host.TeamID == nil && script.TeamID == nil {
|
||||
continue
|
||||
}
|
||||
if host.TeamID == nil || script.TeamID == nil || *host.TeamID != *script.TeamID {
|
||||
return "", fleet.NewInvalidArgumentError("host_ids", "all hosts must be on the same team as the script")
|
||||
}
|
||||
}
|
||||
|
||||
batchID, err := svc.ds.BatchExecuteScript(ctx, userId, scriptID, hostIDsToExecute)
|
||||
if err != nil {
|
||||
return "", fleet.NewUserMessageError(err, http.StatusBadRequest)
|
||||
}
|
||||
|
|
@ -1147,7 +1210,7 @@ func (svc *Service) BatchScriptExecute(ctx context.Context, scriptID uint, hostI
|
|||
if err := svc.NewActivity(ctx, ctxUser, fleet.ActivityTypeRanScriptBatch{
|
||||
ScriptName: script.Name,
|
||||
BatchExeuctionID: batchID,
|
||||
HostCount: uint(len(hostIDs)),
|
||||
HostCount: uint(len(hostIDsToExecute)),
|
||||
TeamID: script.TeamID,
|
||||
}); err != nil {
|
||||
return "", ctxerr.Wrap(ctx, err, "creating activity for batch run scripts")
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package service
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -862,3 +863,108 @@ func TestHostScriptDetailsAuth(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchScriptExecute(t *testing.T) {
|
||||
ds := new(mock.Store)
|
||||
license := &fleet.LicenseInfo{Tier: fleet.TierPremium, Expiration: time.Now().Add(24 * time.Hour)}
|
||||
svc, ctx := newTestService(t, ds, nil, nil, &TestServerOpts{License: license, SkipCreateTestUsers: true})
|
||||
|
||||
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
|
||||
return &fleet.AppConfig{}, nil
|
||||
}
|
||||
|
||||
t.Run("error if hosts do not all belong to the same team as script", func(t *testing.T) {
|
||||
ds.ListHostsLiteByIDsFunc = func(ctx context.Context, ids []uint) ([]*fleet.Host, error) {
|
||||
return []*fleet.Host{
|
||||
{ID: 1, TeamID: ptr.Uint(1)},
|
||||
{ID: 2, TeamID: ptr.Uint(1)},
|
||||
{ID: 3, TeamID: ptr.Uint(2)},
|
||||
}, nil
|
||||
}
|
||||
ds.ScriptFunc = func(ctx context.Context, id uint) (*fleet.Script, error) {
|
||||
if id == 1 {
|
||||
return &fleet.Script{ID: id, TeamID: ptr.Uint(1)}, nil
|
||||
}
|
||||
return &fleet.Script{ID: id}, nil
|
||||
}
|
||||
ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})
|
||||
_, err := svc.BatchScriptExecute(ctx, 1, []uint{1, 2, 3}, nil)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "all hosts must be on the same team as the script")
|
||||
})
|
||||
|
||||
t.Run("error if both host_ids and filters are specified", func(t *testing.T) {
|
||||
ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})
|
||||
_, err := svc.BatchScriptExecute(ctx, 1, []uint{1, 2, 3}, &map[string]interface{}{"foo": "bar"})
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "cannot specify both host_ids and filters")
|
||||
})
|
||||
|
||||
t.Run("error if filters are specified but no team_id", func(t *testing.T) {
|
||||
ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})
|
||||
_, err := svc.BatchScriptExecute(ctx, 1, nil, &map[string]interface{}{"label_id": float64(123)})
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "filters must include a team filter")
|
||||
})
|
||||
|
||||
t.Run("error if filters match too many hosts", func(t *testing.T) {
|
||||
hosts := make([]*fleet.Host, 5001)
|
||||
for i := 0; i < 5001; i++ {
|
||||
hosts[i] = &fleet.Host{ID: uint(i + 1), TeamID: ptr.Uint(1)} // nolint:gosec // ignore G115
|
||||
}
|
||||
ds.ListHostsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) {
|
||||
return hosts, nil
|
||||
}
|
||||
ds.ListHostsLiteByIDsFunc = func(ctx context.Context, ids []uint) ([]*fleet.Host, error) {
|
||||
return hosts, nil
|
||||
}
|
||||
ds.ScriptFunc = func(ctx context.Context, id uint) (*fleet.Script, error) {
|
||||
if id == 1 {
|
||||
return &fleet.Script{ID: id, TeamID: ptr.Uint(1)}, nil
|
||||
}
|
||||
return &fleet.Script{ID: id}, nil
|
||||
}
|
||||
ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})
|
||||
_, err := svc.BatchScriptExecute(ctx, 1, nil, &map[string]interface{}{"team_id": float64(1)})
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "too_many_hosts")
|
||||
})
|
||||
|
||||
t.Run("happy path", func(t *testing.T) {
|
||||
var requestedHostIds []uint
|
||||
ds.BatchExecuteScriptFunc = func(ctx context.Context, userID *uint, scriptID uint, hostIDs []uint) (string, error) {
|
||||
requestedHostIds = hostIDs
|
||||
return "", errors.New("ok")
|
||||
}
|
||||
ds.ListHostsLiteByIDsFunc = func(ctx context.Context, ids []uint) ([]*fleet.Host, error) {
|
||||
return []*fleet.Host{
|
||||
{ID: 1, TeamID: ptr.Uint(1)},
|
||||
{ID: 2, TeamID: ptr.Uint(1)},
|
||||
}, nil
|
||||
}
|
||||
ds.ScriptFunc = func(ctx context.Context, id uint) (*fleet.Script, error) {
|
||||
if id == 1 {
|
||||
return &fleet.Script{ID: id, TeamID: ptr.Uint(1)}, nil
|
||||
}
|
||||
return &fleet.Script{ID: id}, nil
|
||||
}
|
||||
ds.ListHostsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) {
|
||||
return []*fleet.Host{
|
||||
{ID: 3, TeamID: ptr.Uint(1)},
|
||||
{ID: 4, TeamID: ptr.Uint(1)},
|
||||
}, nil
|
||||
}
|
||||
|
||||
ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})
|
||||
_, err := svc.BatchScriptExecute(ctx, 1, []uint{1, 2}, nil)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "ok")
|
||||
require.Equal(t, []uint{1, 2}, requestedHostIds)
|
||||
|
||||
ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})
|
||||
_, err = svc.BatchScriptExecute(ctx, 1, nil, &map[string]interface{}{"team_id": float64(1)})
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "ok")
|
||||
require.Equal(t, []uint{3, 4}, requestedHostIds)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue