diff --git a/changes/28700-add-bulk-execute-by-script b/changes/28700-add-bulk-execute-by-script new file mode 100644 index 0000000000..bbff1ce9a0 --- /dev/null +++ b/changes/28700-add-bulk-execute-by-script @@ -0,0 +1 @@ +- Added ability to execute scripts on up to 5,000 hosts at a time using filters diff --git a/server/fleet/service.go b/server/fleet/service.go index 6ba5c38801..456ac87669 100644 --- a/server/fleet/service.go +++ b/server/fleet/service.go @@ -1176,7 +1176,7 @@ type Service interface { BatchSetScripts(ctx context.Context, maybeTmID *uint, maybeTmName *string, payloads []ScriptPayload, dryRun bool) ([]ScriptResponse, error) // BatchScriptExecute runs a script on many hosts. It creates and returns a batch execution ID - BatchScriptExecute(ctx context.Context, scriptID uint, hostIDs []uint) (string, error) + BatchScriptExecute(ctx context.Context, scriptID uint, hostIDs []uint, filters *map[string]interface{}) (string, error) BatchScriptExecutionSummary(ctx context.Context, batchExecutionID string) (*BatchExecutionSummary, error) diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index 7b5bcfddef..ae61e7eb3f 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -6601,7 +6601,7 @@ func (s *integrationEnterpriseTestSuite) TestRunBatchScript() { s.DoJSON("POST", "/api/latest/fleet/scripts/run/batch", batchScriptRunRequest{ ScriptID: script.ID, HostIDs: []uint{host1.ID, host2.ID, host3Team1.ID}, - }, http.StatusBadRequest, &batchRes) + }, http.StatusUnprocessableEntity, &batchRes) require.Empty(t, batchRes.BatchExecutionID) // Bad script ID diff --git a/server/service/scripts.go b/server/service/scripts.go index d0ae32a42b..d18a87dfb7 100644 --- a/server/service/scripts.go +++ b/server/service/scripts.go @@ -1094,10 +1094,10 @@ func (svc *Service) authorizeScriptByID(ctx context.Context, scriptID uint, auth //////////////////////////////////////////////////////////////////////////////// type batchScriptRunRequest struct { - ScriptID uint `json:"script_id"` - HostIDs []uint `json:"host_ids"` + ScriptID uint `json:"script_id"` + HostIDs []uint `json:"host_ids"` + Filters *map[string]interface{} `json:"filters"` } - type batchScriptRunResponse struct { BatchExecutionID string `json:"batch_execution_id"` Err error `json:"error,omitempty"` @@ -1107,14 +1107,21 @@ func (r batchScriptRunResponse) Error() error { return r.Err } func batchScriptRunEndpoint(ctx context.Context, request any, svc fleet.Service) (fleet.Errorer, error) { req := request.(*batchScriptRunRequest) - batchID, err := svc.BatchScriptExecute(ctx, req.ScriptID, req.HostIDs) + batchID, err := svc.BatchScriptExecute(ctx, req.ScriptID, req.HostIDs, req.Filters) if err != nil { return batchScriptRunResponse{Err: err}, nil } return batchScriptRunResponse{BatchExecutionID: batchID}, nil } -func (svc *Service) BatchScriptExecute(ctx context.Context, scriptID uint, hostIDs []uint) (string, error) { +const MAX_BATCH_EXECUTION_HOSTS = 5000 + +func (svc *Service) BatchScriptExecute(ctx context.Context, scriptID uint, hostIDs []uint, filters *map[string]interface{}) (string, error) { + // If we are given both host IDs and filters, return an error + if len(hostIDs) > 0 && filters != nil { + return "", fleet.NewInvalidArgumentError("filters", "cannot specify both host_ids and filters") + } + // First check if scripts are disabled globally. If so, no need for further processing. cfg, err := svc.ds.AppConfig(ctx) if err != nil { @@ -1139,7 +1146,63 @@ func (svc *Service) BatchScriptExecute(ctx context.Context, scriptID uint, hostI userId = &ctxUser.ID } - batchID, err := svc.ds.BatchExecuteScript(ctx, userId, scriptID, hostIDs) + var hosts []*fleet.Host + + // If we are given filters, we need to get the hosts matching those filters + if filters != nil { + opt, lid, err := hostListOptionsFromFilters(filters) + if err != nil { + return "", err + } + + if opt == nil { + return "", fleet.NewInvalidArgumentError("filters", "filters must be a valid set of host list options") + } + + if opt.TeamFilter == nil { + return "", fleet.NewInvalidArgumentError("filters", "filters must include a team filter") + } + + filter := fleet.TeamFilter{User: ctxUser, IncludeObserver: true} + + // Load hosts, either from label if provided or from all hosts. + if lid != nil { + hosts, err = svc.ds.ListHostsInLabel(ctx, filter, *lid, *opt) + } else { + opt.DisableIssues = true // intentionally ignore failing policies + hosts, err = svc.ds.ListHosts(ctx, filter, *opt) + } + + if err != nil { + return "", err + } + } else { + // Get the hosts matching the host IDs + hosts, err = svc.ds.ListHostsLiteByIDs(ctx, hostIDs) + if err != nil { + return "", err + } + } + if len(hosts) == 0 { + return "", &fleet.BadRequestError{Message: "no hosts match the specified host IDs"} + } + + if len(hosts) > MAX_BATCH_EXECUTION_HOSTS { + return "", fleet.NewInvalidArgumentError("filters", "too_many_hosts") + } + + hostIDsToExecute := make([]uint, 0, len(hosts)) + for _, host := range hosts { + hostIDsToExecute = append(hostIDsToExecute, host.ID) + if host.TeamID == nil && script.TeamID == nil { + continue + } + if host.TeamID == nil || script.TeamID == nil || *host.TeamID != *script.TeamID { + return "", fleet.NewInvalidArgumentError("host_ids", "all hosts must be on the same team as the script") + } + } + + batchID, err := svc.ds.BatchExecuteScript(ctx, userId, scriptID, hostIDsToExecute) if err != nil { return "", fleet.NewUserMessageError(err, http.StatusBadRequest) } @@ -1147,7 +1210,7 @@ func (svc *Service) BatchScriptExecute(ctx context.Context, scriptID uint, hostI if err := svc.NewActivity(ctx, ctxUser, fleet.ActivityTypeRanScriptBatch{ ScriptName: script.Name, BatchExeuctionID: batchID, - HostCount: uint(len(hostIDs)), + HostCount: uint(len(hostIDsToExecute)), TeamID: script.TeamID, }); err != nil { return "", ctxerr.Wrap(ctx, err, "creating activity for batch run scripts") diff --git a/server/service/scripts_test.go b/server/service/scripts_test.go index 5051192400..3fae585ce3 100644 --- a/server/service/scripts_test.go +++ b/server/service/scripts_test.go @@ -2,6 +2,7 @@ package service import ( "context" + "errors" "strings" "testing" "time" @@ -862,3 +863,108 @@ func TestHostScriptDetailsAuth(t *testing.T) { }) } } + +func TestBatchScriptExecute(t *testing.T) { + ds := new(mock.Store) + license := &fleet.LicenseInfo{Tier: fleet.TierPremium, Expiration: time.Now().Add(24 * time.Hour)} + svc, ctx := newTestService(t, ds, nil, nil, &TestServerOpts{License: license, SkipCreateTestUsers: true}) + + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{}, nil + } + + t.Run("error if hosts do not all belong to the same team as script", func(t *testing.T) { + ds.ListHostsLiteByIDsFunc = func(ctx context.Context, ids []uint) ([]*fleet.Host, error) { + return []*fleet.Host{ + {ID: 1, TeamID: ptr.Uint(1)}, + {ID: 2, TeamID: ptr.Uint(1)}, + {ID: 3, TeamID: ptr.Uint(2)}, + }, nil + } + ds.ScriptFunc = func(ctx context.Context, id uint) (*fleet.Script, error) { + if id == 1 { + return &fleet.Script{ID: id, TeamID: ptr.Uint(1)}, nil + } + return &fleet.Script{ID: id}, nil + } + ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) + _, err := svc.BatchScriptExecute(ctx, 1, []uint{1, 2, 3}, nil) + require.Error(t, err) + require.ErrorContains(t, err, "all hosts must be on the same team as the script") + }) + + t.Run("error if both host_ids and filters are specified", func(t *testing.T) { + ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) + _, err := svc.BatchScriptExecute(ctx, 1, []uint{1, 2, 3}, &map[string]interface{}{"foo": "bar"}) + require.Error(t, err) + require.ErrorContains(t, err, "cannot specify both host_ids and filters") + }) + + t.Run("error if filters are specified but no team_id", func(t *testing.T) { + ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) + _, err := svc.BatchScriptExecute(ctx, 1, nil, &map[string]interface{}{"label_id": float64(123)}) + require.Error(t, err) + require.ErrorContains(t, err, "filters must include a team filter") + }) + + t.Run("error if filters match too many hosts", func(t *testing.T) { + hosts := make([]*fleet.Host, 5001) + for i := 0; i < 5001; i++ { + hosts[i] = &fleet.Host{ID: uint(i + 1), TeamID: ptr.Uint(1)} // nolint:gosec // ignore G115 + } + ds.ListHostsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) { + return hosts, nil + } + ds.ListHostsLiteByIDsFunc = func(ctx context.Context, ids []uint) ([]*fleet.Host, error) { + return hosts, nil + } + ds.ScriptFunc = func(ctx context.Context, id uint) (*fleet.Script, error) { + if id == 1 { + return &fleet.Script{ID: id, TeamID: ptr.Uint(1)}, nil + } + return &fleet.Script{ID: id}, nil + } + ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) + _, err := svc.BatchScriptExecute(ctx, 1, nil, &map[string]interface{}{"team_id": float64(1)}) + require.Error(t, err) + require.ErrorContains(t, err, "too_many_hosts") + }) + + t.Run("happy path", func(t *testing.T) { + var requestedHostIds []uint + ds.BatchExecuteScriptFunc = func(ctx context.Context, userID *uint, scriptID uint, hostIDs []uint) (string, error) { + requestedHostIds = hostIDs + return "", errors.New("ok") + } + ds.ListHostsLiteByIDsFunc = func(ctx context.Context, ids []uint) ([]*fleet.Host, error) { + return []*fleet.Host{ + {ID: 1, TeamID: ptr.Uint(1)}, + {ID: 2, TeamID: ptr.Uint(1)}, + }, nil + } + ds.ScriptFunc = func(ctx context.Context, id uint) (*fleet.Script, error) { + if id == 1 { + return &fleet.Script{ID: id, TeamID: ptr.Uint(1)}, nil + } + return &fleet.Script{ID: id}, nil + } + ds.ListHostsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) { + return []*fleet.Host{ + {ID: 3, TeamID: ptr.Uint(1)}, + {ID: 4, TeamID: ptr.Uint(1)}, + }, nil + } + + ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) + _, err := svc.BatchScriptExecute(ctx, 1, []uint{1, 2}, nil) + require.Error(t, err) + require.ErrorContains(t, err, "ok") + require.Equal(t, []uint{1, 2}, requestedHostIds) + + ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) + _, err = svc.BatchScriptExecute(ctx, 1, nil, &map[string]interface{}{"team_id": float64(1)}) + require.Error(t, err) + require.ErrorContains(t, err, "ok") + require.Equal(t, []uint{3, 4}, requestedHostIds) + }) +}