Add ability to bulk execute scripts based on filters (#29149)

for #28700 

# Checklist for submitter

- [X] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.

## Details

This PR adds the ability to use filters to select a subset of hosts to
run a script on, using the existing batch execution system. Due to the
scale limitations of the framework, we limit this to 5,000 hosts (we may
lift this limit in the future as we iterate on this feature).

The implementation follows the same basic strategy as the "transfer
hosts to team by filter" endpoint. If filters are supplied, they are
used to get host records using `ListHosts` or `ListHostsInLabel`. If IDs
are supplied, `ListHostsLiteByIDs` is used. From there, we do the same
validation as in the previous iteration, and send the host IDs to the
batch execution function.

There are many avenues for optimization here, some of which I already
have in a branch, but this is a very low-touch solution to get us larger
batch sizes right now. To do this at true scale warrants some cross-team
architecture discussions.

## Testing

**Automated:**

New automated tests were added for the existing `BatchExecuteScript`
service method, and verified that they still pass with the code updates.

**Manual testing:**

* Tested running a script on a subset of hosts on a single page (no team
and real team)
* Tested running a script on a subset of hosts using a query filter (no
team and real team)
* Tested running a script on a subset of hosts using a label filter (no
team and real team)
This commit is contained in:
Scott Gress 2025-05-22 16:44:34 -05:00 committed by GitHub
parent cbeb311b97
commit bb09925d73
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 179 additions and 9 deletions

View file

@ -0,0 +1 @@
- Added ability to execute scripts on up to 5,000 hosts at a time using filters

View file

@ -1176,7 +1176,7 @@ type Service interface {
BatchSetScripts(ctx context.Context, maybeTmID *uint, maybeTmName *string, payloads []ScriptPayload, dryRun bool) ([]ScriptResponse, error)
// BatchScriptExecute runs a script on many hosts. It creates and returns a batch execution ID
BatchScriptExecute(ctx context.Context, scriptID uint, hostIDs []uint) (string, error)
BatchScriptExecute(ctx context.Context, scriptID uint, hostIDs []uint, filters *map[string]interface{}) (string, error)
BatchScriptExecutionSummary(ctx context.Context, batchExecutionID string) (*BatchExecutionSummary, error)

View file

@ -6601,7 +6601,7 @@ func (s *integrationEnterpriseTestSuite) TestRunBatchScript() {
s.DoJSON("POST", "/api/latest/fleet/scripts/run/batch", batchScriptRunRequest{
ScriptID: script.ID,
HostIDs: []uint{host1.ID, host2.ID, host3Team1.ID},
}, http.StatusBadRequest, &batchRes)
}, http.StatusUnprocessableEntity, &batchRes)
require.Empty(t, batchRes.BatchExecutionID)
// Bad script ID

View file

@ -1094,10 +1094,10 @@ func (svc *Service) authorizeScriptByID(ctx context.Context, scriptID uint, auth
////////////////////////////////////////////////////////////////////////////////
type batchScriptRunRequest struct {
ScriptID uint `json:"script_id"`
HostIDs []uint `json:"host_ids"`
ScriptID uint `json:"script_id"`
HostIDs []uint `json:"host_ids"`
Filters *map[string]interface{} `json:"filters"`
}
type batchScriptRunResponse struct {
BatchExecutionID string `json:"batch_execution_id"`
Err error `json:"error,omitempty"`
@ -1107,14 +1107,21 @@ func (r batchScriptRunResponse) Error() error { return r.Err }
func batchScriptRunEndpoint(ctx context.Context, request any, svc fleet.Service) (fleet.Errorer, error) {
req := request.(*batchScriptRunRequest)
batchID, err := svc.BatchScriptExecute(ctx, req.ScriptID, req.HostIDs)
batchID, err := svc.BatchScriptExecute(ctx, req.ScriptID, req.HostIDs, req.Filters)
if err != nil {
return batchScriptRunResponse{Err: err}, nil
}
return batchScriptRunResponse{BatchExecutionID: batchID}, nil
}
func (svc *Service) BatchScriptExecute(ctx context.Context, scriptID uint, hostIDs []uint) (string, error) {
const MAX_BATCH_EXECUTION_HOSTS = 5000
func (svc *Service) BatchScriptExecute(ctx context.Context, scriptID uint, hostIDs []uint, filters *map[string]interface{}) (string, error) {
// If we are given both host IDs and filters, return an error
if len(hostIDs) > 0 && filters != nil {
return "", fleet.NewInvalidArgumentError("filters", "cannot specify both host_ids and filters")
}
// First check if scripts are disabled globally. If so, no need for further processing.
cfg, err := svc.ds.AppConfig(ctx)
if err != nil {
@ -1139,7 +1146,63 @@ func (svc *Service) BatchScriptExecute(ctx context.Context, scriptID uint, hostI
userId = &ctxUser.ID
}
batchID, err := svc.ds.BatchExecuteScript(ctx, userId, scriptID, hostIDs)
var hosts []*fleet.Host
// If we are given filters, we need to get the hosts matching those filters
if filters != nil {
opt, lid, err := hostListOptionsFromFilters(filters)
if err != nil {
return "", err
}
if opt == nil {
return "", fleet.NewInvalidArgumentError("filters", "filters must be a valid set of host list options")
}
if opt.TeamFilter == nil {
return "", fleet.NewInvalidArgumentError("filters", "filters must include a team filter")
}
filter := fleet.TeamFilter{User: ctxUser, IncludeObserver: true}
// Load hosts, either from label if provided or from all hosts.
if lid != nil {
hosts, err = svc.ds.ListHostsInLabel(ctx, filter, *lid, *opt)
} else {
opt.DisableIssues = true // intentionally ignore failing policies
hosts, err = svc.ds.ListHosts(ctx, filter, *opt)
}
if err != nil {
return "", err
}
} else {
// Get the hosts matching the host IDs
hosts, err = svc.ds.ListHostsLiteByIDs(ctx, hostIDs)
if err != nil {
return "", err
}
}
if len(hosts) == 0 {
return "", &fleet.BadRequestError{Message: "no hosts match the specified host IDs"}
}
if len(hosts) > MAX_BATCH_EXECUTION_HOSTS {
return "", fleet.NewInvalidArgumentError("filters", "too_many_hosts")
}
hostIDsToExecute := make([]uint, 0, len(hosts))
for _, host := range hosts {
hostIDsToExecute = append(hostIDsToExecute, host.ID)
if host.TeamID == nil && script.TeamID == nil {
continue
}
if host.TeamID == nil || script.TeamID == nil || *host.TeamID != *script.TeamID {
return "", fleet.NewInvalidArgumentError("host_ids", "all hosts must be on the same team as the script")
}
}
batchID, err := svc.ds.BatchExecuteScript(ctx, userId, scriptID, hostIDsToExecute)
if err != nil {
return "", fleet.NewUserMessageError(err, http.StatusBadRequest)
}
@ -1147,7 +1210,7 @@ func (svc *Service) BatchScriptExecute(ctx context.Context, scriptID uint, hostI
if err := svc.NewActivity(ctx, ctxUser, fleet.ActivityTypeRanScriptBatch{
ScriptName: script.Name,
BatchExeuctionID: batchID,
HostCount: uint(len(hostIDs)),
HostCount: uint(len(hostIDsToExecute)),
TeamID: script.TeamID,
}); err != nil {
return "", ctxerr.Wrap(ctx, err, "creating activity for batch run scripts")

View file

@ -2,6 +2,7 @@ package service
import (
"context"
"errors"
"strings"
"testing"
"time"
@ -862,3 +863,108 @@ func TestHostScriptDetailsAuth(t *testing.T) {
})
}
}
func TestBatchScriptExecute(t *testing.T) {
ds := new(mock.Store)
license := &fleet.LicenseInfo{Tier: fleet.TierPremium, Expiration: time.Now().Add(24 * time.Hour)}
svc, ctx := newTestService(t, ds, nil, nil, &TestServerOpts{License: license, SkipCreateTestUsers: true})
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
return &fleet.AppConfig{}, nil
}
t.Run("error if hosts do not all belong to the same team as script", func(t *testing.T) {
ds.ListHostsLiteByIDsFunc = func(ctx context.Context, ids []uint) ([]*fleet.Host, error) {
return []*fleet.Host{
{ID: 1, TeamID: ptr.Uint(1)},
{ID: 2, TeamID: ptr.Uint(1)},
{ID: 3, TeamID: ptr.Uint(2)},
}, nil
}
ds.ScriptFunc = func(ctx context.Context, id uint) (*fleet.Script, error) {
if id == 1 {
return &fleet.Script{ID: id, TeamID: ptr.Uint(1)}, nil
}
return &fleet.Script{ID: id}, nil
}
ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})
_, err := svc.BatchScriptExecute(ctx, 1, []uint{1, 2, 3}, nil)
require.Error(t, err)
require.ErrorContains(t, err, "all hosts must be on the same team as the script")
})
t.Run("error if both host_ids and filters are specified", func(t *testing.T) {
ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})
_, err := svc.BatchScriptExecute(ctx, 1, []uint{1, 2, 3}, &map[string]interface{}{"foo": "bar"})
require.Error(t, err)
require.ErrorContains(t, err, "cannot specify both host_ids and filters")
})
t.Run("error if filters are specified but no team_id", func(t *testing.T) {
ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})
_, err := svc.BatchScriptExecute(ctx, 1, nil, &map[string]interface{}{"label_id": float64(123)})
require.Error(t, err)
require.ErrorContains(t, err, "filters must include a team filter")
})
t.Run("error if filters match too many hosts", func(t *testing.T) {
hosts := make([]*fleet.Host, 5001)
for i := 0; i < 5001; i++ {
hosts[i] = &fleet.Host{ID: uint(i + 1), TeamID: ptr.Uint(1)} // nolint:gosec // ignore G115
}
ds.ListHostsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) {
return hosts, nil
}
ds.ListHostsLiteByIDsFunc = func(ctx context.Context, ids []uint) ([]*fleet.Host, error) {
return hosts, nil
}
ds.ScriptFunc = func(ctx context.Context, id uint) (*fleet.Script, error) {
if id == 1 {
return &fleet.Script{ID: id, TeamID: ptr.Uint(1)}, nil
}
return &fleet.Script{ID: id}, nil
}
ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})
_, err := svc.BatchScriptExecute(ctx, 1, nil, &map[string]interface{}{"team_id": float64(1)})
require.Error(t, err)
require.ErrorContains(t, err, "too_many_hosts")
})
t.Run("happy path", func(t *testing.T) {
var requestedHostIds []uint
ds.BatchExecuteScriptFunc = func(ctx context.Context, userID *uint, scriptID uint, hostIDs []uint) (string, error) {
requestedHostIds = hostIDs
return "", errors.New("ok")
}
ds.ListHostsLiteByIDsFunc = func(ctx context.Context, ids []uint) ([]*fleet.Host, error) {
return []*fleet.Host{
{ID: 1, TeamID: ptr.Uint(1)},
{ID: 2, TeamID: ptr.Uint(1)},
}, nil
}
ds.ScriptFunc = func(ctx context.Context, id uint) (*fleet.Script, error) {
if id == 1 {
return &fleet.Script{ID: id, TeamID: ptr.Uint(1)}, nil
}
return &fleet.Script{ID: id}, nil
}
ds.ListHostsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) {
return []*fleet.Host{
{ID: 3, TeamID: ptr.Uint(1)},
{ID: 4, TeamID: ptr.Uint(1)},
}, nil
}
ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})
_, err := svc.BatchScriptExecute(ctx, 1, []uint{1, 2}, nil)
require.Error(t, err)
require.ErrorContains(t, err, "ok")
require.Equal(t, []uint{1, 2}, requestedHostIds)
ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})
_, err = svc.BatchScriptExecute(ctx, 1, nil, &map[string]interface{}{"team_id": float64(1)})
require.Error(t, err)
require.ErrorContains(t, err, "ok")
require.Equal(t, []uint{3, 4}, requestedHostIds)
})
}