mirror of
https://github.com/fleetdm/fleet
synced 2026-05-23 17:08:53 +00:00
Prevent thundering herd when applying large number of policies on large number of hosts (#13552)
#13527 (Adding @mna to double check the changes in the async implementation of policy result storage) This PR also adds the osquery-perf changes needed to define the count of macOS and Windows hosts. - [X] Changes file added for user-visible changes in `changes/` or `orbit/changes/`. See [Changes files](https://fleetdm.com/docs/contributing/committing-changes#changes-files) for more information. - ~[ ] Documented any API changes (docs/Using-Fleet/REST-API.md or docs/Contributing/API-for-contributors.md)~ - ~[ ] Documented any permissions changes (docs/Using Fleet/manage-access.md)~ - [X] Input data is properly validated, `SELECT *` is avoided, SQL injection is prevented (using placeholders for values in statements) - [X] Added support on fleet's osquery simulator `cmd/osquery-perf` for new osquery data ingestion features. - [X] Added/updated tests - [X] Manual QA for all new/changed functionality - ~For Orbit and Fleet Desktop changes:~ - ~[ ] Manual QA must be performed in the three main OSs, macOS, Windows and Linux.~ - ~[ ] Auto-update manual QA, from released version of component to new version (see [tools/tuf/test](../tools/tuf/test/README.md)).~ Test with 80k hosts: 70k simulated macOS, 10k simulated Windows. Apply Windows policies first, then apply macOS policies: ``` fleetctl apply -f ee/cis/win-10/cis-policy-queries.yml # Leave running for some time fleetctl apply -f ee/cis/macos-13/cis-policy-queries.yml ``` After applying CIS policies previous to these changes:  After applying these changes and applying the same policies: 
This commit is contained in:
parent
6637ea6517
commit
9142c5de79
10 changed files with 348 additions and 83 deletions
4
changes/13527-applying-policies-at-scale
Normal file
4
changes/13527-applying-policies-at-scale
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
* Improved performance at scale when applying hundreds of policies to thousands of hosts via `fleetctl apply`.
|
||||
IMPORTANT: In previous versions of Fleet there's a performance issue (thundering herd) when applying hundreds of
|
||||
policies on a large number of hosts. To avoid this, make sure to deploy this version of Fleet, and make sure Fleet
|
||||
is running for at least 1h (or the configured `FLEET_OSQUERY_POLICY_UPDATE_INTERVAL`) before applying the policies.
|
||||
|
|
@ -1406,8 +1406,10 @@ func main() {
|
|||
orbitProb = flag.Float64("orbit_prob", 0.5, "Probability of a host being identified as orbit install [0, 1]")
|
||||
munkiIssueProb = flag.Float64("munki_issue_prob", 0.5, "Probability of a host having munki issues (note that ~50% of hosts have munki installed) [0, 1]")
|
||||
munkiIssueCount = flag.Int("munki_issue_count", 10, "Number of munki issues reported by hosts identified to have munki issues")
|
||||
osTemplates = flag.String("os_templates", "mac10.14.6", fmt.Sprintf("Comma separated list of host OS templates to use (any of %v, with or without the .tmpl extension)", allowedTemplateNames))
|
||||
emptySerialProb = flag.Float64("empty_serial_prob", 0.1, "Probability of a host having no serial number [0, 1]")
|
||||
// E.g. when running with `-host_count=10`, you can set host count for each template the following way:
|
||||
// `-os_templates=windows_11.tmpl:3,mac10.14.6.tmpl:4,ubuntu_22.04.tmpl:3`
|
||||
osTemplates = flag.String("os_templates", "mac10.14.6", fmt.Sprintf("Comma separated list of host OS templates to use and optionally their host count separated by ':' (any of %v, with or without the .tmpl extension)", allowedTemplateNames))
|
||||
emptySerialProb = flag.Float64("empty_serial_prob", 0.1, "Probability of a host having no serial number [0, 1]")
|
||||
|
||||
mdmProb = flag.Float64("mdm_prob", 0.0, "Probability of a host enrolling via MDM (for macOS) [0, 1]")
|
||||
mdmSCEPChallenge = flag.String("mdm_scep_challenge", "", "SCEP challenge to use when running MDM enroll")
|
||||
|
|
@ -1434,9 +1436,20 @@ func main() {
|
|||
log.Fatalf("Argument unique_software_uninstall_count cannot be bigger than unique_software_count")
|
||||
}
|
||||
|
||||
var tmpls []*template.Template
|
||||
tmplsm := make(map[*template.Template]int)
|
||||
requestedTemplates := strings.Split(*osTemplates, ",")
|
||||
tmplsTotalHostCount := 0
|
||||
for _, nm := range requestedTemplates {
|
||||
numberOfHosts := 0
|
||||
if strings.Contains(nm, ":") {
|
||||
parts := strings.Split(nm, ":")
|
||||
nm = parts[0]
|
||||
hc, err := strconv.ParseInt(parts[1], 10, 64)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid template host count: %s", parts[1])
|
||||
}
|
||||
numberOfHosts = int(hc)
|
||||
}
|
||||
if !strings.HasSuffix(nm, ".tmpl") {
|
||||
nm += ".tmpl"
|
||||
}
|
||||
|
|
@ -1448,7 +1461,11 @@ func main() {
|
|||
if err != nil {
|
||||
log.Fatal("parse templates: ", err)
|
||||
}
|
||||
tmpls = append(tmpls, tmpl)
|
||||
tmplsm[tmpl] = numberOfHosts
|
||||
tmplsTotalHostCount += numberOfHosts
|
||||
}
|
||||
if tmplsTotalHostCount != 0 && tmplsTotalHostCount != *hostCount {
|
||||
log.Fatalf("Invalid host count in templates: total=%d vs host_count=%d", tmplsTotalHostCount, *hostCount)
|
||||
}
|
||||
|
||||
// Spread starts over the interval to prevent thundering herd
|
||||
|
|
@ -1463,8 +1480,28 @@ func main() {
|
|||
nodeKeyManager.LoadKeys()
|
||||
}
|
||||
|
||||
var tmplss []*template.Template
|
||||
for tmpl := range tmplsm {
|
||||
tmplss = append(tmplss, tmpl)
|
||||
}
|
||||
|
||||
for i := 0; i < *hostCount; i++ {
|
||||
tmpl := tmpls[i%len(tmpls)]
|
||||
var tmpl *template.Template
|
||||
if tmplsTotalHostCount > 0 {
|
||||
for tmpl_, hostCount := range tmplsm {
|
||||
if hostCount > 0 {
|
||||
tmpl = tmpl_
|
||||
tmplsm[tmpl_] = tmplsm[tmpl_] - 1
|
||||
break
|
||||
}
|
||||
}
|
||||
if tmpl == nil {
|
||||
log.Fatalf("Failed to determine template for host: %d", i)
|
||||
}
|
||||
} else {
|
||||
tmpl = tmplss[i%len(tmplss)]
|
||||
}
|
||||
|
||||
a := newAgent(i+1,
|
||||
*serverURL,
|
||||
*enrollSecret,
|
||||
|
|
|
|||
|
|
@ -135,6 +135,7 @@ func TestHosts(t *testing.T) {
|
|||
{"ReplaceHostBatteries", testHostsReplaceHostBatteries},
|
||||
{"CountHostsNotResponding", testCountHostsNotResponding},
|
||||
{"FailingPoliciesCount", testFailingPoliciesCount},
|
||||
{"HostRecordNoPolicies", testHostsRecordNoPolicies},
|
||||
{"SetOrUpdateHostDisksSpace", testHostsSetOrUpdateHostDisksSpace},
|
||||
{"HostIDsByOSID", testHostIDsByOSID},
|
||||
{"SetOrUpdateHostDisksEncryption", testHostsSetOrUpdateHostDisksEncryption},
|
||||
|
|
@ -6159,6 +6160,55 @@ func testFailingPoliciesCount(t *testing.T, ds *Datastore) {
|
|||
})
|
||||
}
|
||||
|
||||
func testHostsRecordNoPolicies(t *testing.T, ds *Datastore) {
|
||||
initialTime := time.Now()
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
_, err := ds.NewHost(context.Background(), &fleet.Host{
|
||||
DetailUpdatedAt: initialTime,
|
||||
LabelUpdatedAt: initialTime,
|
||||
PolicyUpdatedAt: initialTime,
|
||||
SeenTime: initialTime.Add(-time.Duration(i) * time.Minute),
|
||||
OsqueryHostID: ptr.String(strconv.Itoa(i)),
|
||||
NodeKey: ptr.String(fmt.Sprintf("%d", i)),
|
||||
UUID: fmt.Sprintf("%d", i),
|
||||
Hostname: fmt.Sprintf("foo.local%d", i),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
filter := fleet.TeamFilter{User: test.UserAdmin}
|
||||
|
||||
hosts := listHostsCheckCount(t, ds, filter, fleet.HostListOptions{}, 2)
|
||||
require.Len(t, hosts, 2)
|
||||
|
||||
h1 := hosts[0]
|
||||
h2 := hosts[1]
|
||||
|
||||
assert.WithinDuration(t, initialTime, h1.PolicyUpdatedAt, 1*time.Second)
|
||||
assert.Zero(t, h1.HostIssues.FailingPoliciesCount)
|
||||
assert.Zero(t, h1.HostIssues.TotalIssuesCount)
|
||||
assert.WithinDuration(t, initialTime, h2.PolicyUpdatedAt, 1*time.Second)
|
||||
assert.Zero(t, h2.HostIssues.FailingPoliciesCount)
|
||||
assert.Zero(t, h2.HostIssues.TotalIssuesCount)
|
||||
|
||||
policyUpdatedAt := initialTime.Add(1 * time.Hour)
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, nil, policyUpdatedAt, false))
|
||||
|
||||
hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{}, 2)
|
||||
require.Len(t, hosts, 2)
|
||||
|
||||
h1 = hosts[0]
|
||||
h2 = hosts[1]
|
||||
|
||||
assert.WithinDuration(t, policyUpdatedAt, h1.PolicyUpdatedAt, 1*time.Second)
|
||||
assert.Zero(t, h1.HostIssues.FailingPoliciesCount)
|
||||
assert.Zero(t, h1.HostIssues.TotalIssuesCount)
|
||||
assert.WithinDuration(t, initialTime, h2.PolicyUpdatedAt, 1*time.Second)
|
||||
assert.Zero(t, h2.HostIssues.FailingPoliciesCount)
|
||||
assert.Zero(t, h2.HostIssues.TotalIssuesCount)
|
||||
}
|
||||
|
||||
func testHostsSetOrUpdateHostDisksSpace(t *testing.T, ds *Datastore) {
|
||||
host, err := ds.NewHost(context.Background(), &fleet.Host{
|
||||
DetailUpdatedAt: time.Now(),
|
||||
|
|
|
|||
|
|
@ -214,21 +214,23 @@ func filterNotExecuted(results map[uint]*bool) map[uint]bool {
|
|||
}
|
||||
|
||||
func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool) error {
|
||||
// Sort the results to have generated SQL queries ordered to minimize
|
||||
// deadlocks. See https://github.com/fleetdm/fleet/issues/1146.
|
||||
orderedIDs := make([]uint, 0, len(results))
|
||||
for policyID := range results {
|
||||
orderedIDs = append(orderedIDs, policyID)
|
||||
}
|
||||
sort.Slice(orderedIDs, func(i, j int) bool { return orderedIDs[i] < orderedIDs[j] })
|
||||
|
||||
// Loop through results, collecting which labels we need to insert/update
|
||||
vals := []interface{}{}
|
||||
bindvars := []string{}
|
||||
for _, policyID := range orderedIDs {
|
||||
matches := results[policyID]
|
||||
bindvars = append(bindvars, "(?,?,?,?)")
|
||||
vals = append(vals, updated, policyID, host.ID, matches)
|
||||
if len(results) > 0 {
|
||||
// Sort the results to have generated SQL queries ordered to minimize
|
||||
// deadlocks. See https://github.com/fleetdm/fleet/issues/1146.
|
||||
orderedIDs := make([]uint, 0, len(results))
|
||||
for policyID := range results {
|
||||
orderedIDs = append(orderedIDs, policyID)
|
||||
}
|
||||
sort.Slice(orderedIDs, func(i, j int) bool { return orderedIDs[i] < orderedIDs[j] })
|
||||
|
||||
// Loop through results, collecting which labels we need to insert/update
|
||||
for _, policyID := range orderedIDs {
|
||||
matches := results[policyID]
|
||||
bindvars = append(bindvars, "(?,?,?,?)")
|
||||
vals = append(vals, updated, policyID, host.ID, matches)
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: the insert of policy membership that follows must be kept in sync
|
||||
|
|
@ -238,16 +240,17 @@ func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *flee
|
|||
// semantically equivalent, even though here it processes a single host and
|
||||
// in async mode it processes a batch of hosts).
|
||||
|
||||
query := fmt.Sprintf(
|
||||
`INSERT INTO policy_membership (updated_at, policy_id, host_id, passes)
|
||||
VALUES %s ON DUPLICATE KEY UPDATE updated_at=VALUES(updated_at), passes=VALUES(passes)`,
|
||||
strings.Join(bindvars, ","),
|
||||
)
|
||||
|
||||
err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
||||
_, err := tx.ExecContext(ctx, query, vals...)
|
||||
if err != nil {
|
||||
return ctxerr.Wrapf(ctx, err, "insert policy_membership (%v)", vals)
|
||||
if len(results) > 0 {
|
||||
query := fmt.Sprintf(
|
||||
`INSERT INTO policy_membership (updated_at, policy_id, host_id, passes)
|
||||
VALUES %s ON DUPLICATE KEY UPDATE updated_at=VALUES(updated_at), passes=VALUES(passes)`,
|
||||
strings.Join(bindvars, ","),
|
||||
)
|
||||
_, err := tx.ExecContext(ctx, query, vals...)
|
||||
if err != nil {
|
||||
return ctxerr.Wrapf(ctx, err, "insert policy_membership (%v)", vals)
|
||||
}
|
||||
}
|
||||
|
||||
// if we are deferring host updates, we return at this point and do the change outside of the tx
|
||||
|
|
@ -255,8 +258,7 @@ func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *flee
|
|||
return nil
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, `UPDATE hosts SET policy_updated_at = ? WHERE id=?`, updated, host.ID)
|
||||
if err != nil {
|
||||
if _, err := tx.ExecContext(ctx, `UPDATE hosts SET policy_updated_at = ? WHERE id=?`, updated, host.ID); err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "updating hosts policy updated at")
|
||||
}
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -667,6 +667,7 @@ type Datastore interface {
|
|||
FlippingPoliciesForHost(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing []uint, newPassing []uint, err error)
|
||||
|
||||
// RecordPolicyQueryExecutions records the execution results of the policies for the given host.
|
||||
// Even if `results` is empty, the host's `policy_updated_at` will be updated.
|
||||
RecordPolicyQueryExecutions(ctx context.Context, host *Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool) error
|
||||
|
||||
// RecordLabelQueryExecutions saves the results of label queries. The results map is a map of label id -> whether or
|
||||
|
|
|
|||
|
|
@ -22,10 +22,8 @@ const (
|
|||
policyPassKeysMinTTL = 7 * 24 * time.Hour // 1 week
|
||||
)
|
||||
|
||||
var (
|
||||
// redis list will be LTRIM'd if there are more policy IDs than this.
|
||||
maxRedisPolicyResultsPerHost = 1000
|
||||
)
|
||||
// redis list will be LTRIM'd if there are more policy IDs than this.
|
||||
var maxRedisPolicyResultsPerHost = 1000
|
||||
|
||||
func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool) error {
|
||||
cfg := t.taskConfigs[config.AsyncTaskPolicyMembership]
|
||||
|
|
@ -51,36 +49,57 @@ func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host
|
|||
ttl = maxTTL
|
||||
}
|
||||
|
||||
// KEYS[1]: keyList (policyPassHostKey)
|
||||
// KEYS[2]: keyTs (policyPassReportedKey)
|
||||
// ARGV[1]: timestamp for "reported at"
|
||||
// ARGV[2]: max policy results to keep per host (list is trimmed to that size)
|
||||
// ARGV[3]: ttl for both keys
|
||||
// ARGV[4..]: policy_id=pass entries to LPUSH to the list
|
||||
script := redigo.NewScript(2, `
|
||||
redis.call('LPUSH', KEYS[1], unpack(ARGV, 4))
|
||||
redis.call('LTRIM', KEYS[1], 0, ARGV[2])
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[3])
|
||||
redis.call('SET', KEYS[2], ARGV[1])
|
||||
return redis.call('EXPIRE', KEYS[2], ARGV[3])
|
||||
`)
|
||||
// There are two versions of the script (1) and (2).
|
||||
// Script (1) is used when the are no policy results to report and
|
||||
// script (2) is used when there are policy results.
|
||||
|
||||
// convert results to LPUSH arguments, store as policy_id=1 for pass,
|
||||
// policy_id=-1 for fail, policy_id=0 for null result.
|
||||
args := make(redigo.Args, 0, 5+len(results))
|
||||
args = args.Add(keyList, keyTs, ts.Unix(), maxRedisPolicyResultsPerHost, int(ttl.Seconds()))
|
||||
for k, v := range results {
|
||||
pass := 0
|
||||
if v != nil {
|
||||
if *v {
|
||||
pass = 1
|
||||
} else {
|
||||
pass = -1
|
||||
// (1)
|
||||
// KEYS[1]: keyTs (policyPassReportedKey)
|
||||
// ARGV[1]: timestamp for "reported at"
|
||||
// ARGV[2]: ttl for the key
|
||||
scriptSrc := `
|
||||
redis.call('SET', KEYS[1], ARGV[1])
|
||||
return redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
`
|
||||
keyCount := 1
|
||||
args := make(redigo.Args, 0, 3)
|
||||
args = args.Add(keyTs, ts.Unix(), int(ttl.Seconds()))
|
||||
|
||||
if len(results) > 0 {
|
||||
// (2)
|
||||
// KEYS[1]: keyList (policyPassHostKey)
|
||||
// KEYS[2]: keyTs (policyPassReportedKey)
|
||||
// ARGV[1]: timestamp for "reported at"
|
||||
// ARGV[2]: max policy results to keep per host (list is trimmed to that size)
|
||||
// ARGV[3]: ttl for both keys
|
||||
// ARGV[4..]: policy_id=pass entries to LPUSH to the list
|
||||
keyCount = 2
|
||||
scriptSrc = `
|
||||
redis.call('LPUSH', KEYS[1], unpack(ARGV, 4))
|
||||
redis.call('LTRIM', KEYS[1], 0, ARGV[2])
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[3])
|
||||
redis.call('SET', KEYS[2], ARGV[1])
|
||||
return redis.call('EXPIRE', KEYS[2], ARGV[3])
|
||||
`
|
||||
// convert results to LPUSH arguments, store as policy_id=1 for pass,
|
||||
// policy_id=-1 for fail, policy_id=0 for null result.
|
||||
args = make(redigo.Args, 0, 5+len(results))
|
||||
args = args.Add(keyList, keyTs, ts.Unix(), maxRedisPolicyResultsPerHost, int(ttl.Seconds()))
|
||||
for k, v := range results {
|
||||
pass := 0
|
||||
if v != nil {
|
||||
if *v {
|
||||
pass = 1
|
||||
} else {
|
||||
pass = -1
|
||||
}
|
||||
}
|
||||
args = args.Add(fmt.Sprintf("%d=%d", k, pass))
|
||||
}
|
||||
args = args.Add(fmt.Sprintf("%d=%d", k, pass))
|
||||
}
|
||||
|
||||
script := redigo.NewScript(keyCount, scriptSrc)
|
||||
|
||||
conn := t.pool.Get()
|
||||
defer conn.Close()
|
||||
if err := redis.BindConn(t.pool, conn, keyList, keyTs); err != nil {
|
||||
|
|
|
|||
|
|
@ -410,6 +410,115 @@ func testRecordPolicyQueryExecutionsAsync(t *testing.T, ds *mock.Store, pool fle
|
|||
require.Equal(t, 0, count)
|
||||
}
|
||||
|
||||
func testRecordPolicyQueryExecutionsNoPoliciesSync(t *testing.T, ds *mock.Store, pool fleet.RedisPool) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
lastYear := now.Add(-365 * 24 * time.Hour)
|
||||
host := &fleet.Host{
|
||||
ID: 1,
|
||||
Platform: "linux",
|
||||
PolicyUpdatedAt: lastYear,
|
||||
}
|
||||
|
||||
var emptyResults map[uint]*bool
|
||||
keyList, keyTs := fmt.Sprintf(policyPassHostKey, host.ID), fmt.Sprintf(policyPassReportedKey, host.ID)
|
||||
|
||||
task := NewTask(ds, pool, clock.C, config.OsqueryConfig{})
|
||||
|
||||
policyReportedAt := task.GetHostPolicyReportedAt(ctx, host)
|
||||
require.True(t, policyReportedAt.Equal(lastYear))
|
||||
|
||||
err := task.RecordPolicyQueryExecutions(ctx, host, emptyResults, now, false)
|
||||
require.NoError(t, err)
|
||||
require.True(t, ds.RecordPolicyQueryExecutionsFuncInvoked)
|
||||
ds.RecordPolicyQueryExecutionsFuncInvoked = false
|
||||
|
||||
conn := redis.ConfigureDoer(pool, pool.Get())
|
||||
defer conn.Close()
|
||||
|
||||
n, err := redigo.Int(conn.Do("EXISTS", keyList))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, n)
|
||||
|
||||
n, err = redigo.Int(conn.Do("EXISTS", keyTs))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, n)
|
||||
|
||||
n, err = redigo.Int(conn.Do("ZCARD", policyPassHostIDsKey))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, n)
|
||||
|
||||
policyReportedAt = task.GetHostPolicyReportedAt(ctx, host)
|
||||
require.True(t, policyReportedAt.Equal(now))
|
||||
}
|
||||
|
||||
func testRecordPolicyQueryExecutionsNoPoliciesAsync(t *testing.T, ds *mock.Store, pool fleet.RedisPool) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
lastYear := now.Add(-365 * 24 * time.Hour)
|
||||
host := &fleet.Host{
|
||||
ID: 1,
|
||||
Platform: "linux",
|
||||
PolicyUpdatedAt: lastYear,
|
||||
}
|
||||
var emptyResults map[uint]*bool
|
||||
keyList, keyTs := fmt.Sprintf(policyPassHostKey, host.ID), fmt.Sprintf(policyPassReportedKey, host.ID)
|
||||
|
||||
task := NewTask(ds, pool, clock.C, config.OsqueryConfig{
|
||||
EnableAsyncHostProcessing: "true",
|
||||
AsyncHostInsertBatch: 3,
|
||||
AsyncHostUpdateBatch: 3,
|
||||
AsyncHostDeleteBatch: 3,
|
||||
AsyncHostRedisPopCount: 3,
|
||||
AsyncHostRedisScanKeysCount: 10,
|
||||
})
|
||||
|
||||
policyReportedAt := task.GetHostPolicyReportedAt(ctx, host)
|
||||
require.True(t, policyReportedAt.Equal(lastYear))
|
||||
|
||||
err := task.RecordPolicyQueryExecutions(ctx, host, emptyResults, now, false)
|
||||
require.NoError(t, err)
|
||||
require.False(t, ds.RecordPolicyQueryExecutionsFuncInvoked)
|
||||
|
||||
conn := redis.ConfigureDoer(pool, pool.Get())
|
||||
defer conn.Close()
|
||||
defer conn.Do("DEL", keyTs) //nolint:errcheck
|
||||
|
||||
n, err := redigo.Int(conn.Do("EXISTS", keyList))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, n)
|
||||
|
||||
ts, err := redigo.Int64(conn.Do("GET", keyTs))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, now.Unix(), ts)
|
||||
|
||||
count, err := redigo.Int(conn.Do("ZCARD", policyPassHostIDsKey))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
tsActive, err := redigo.Int64(conn.Do("ZSCORE", policyPassHostIDsKey, host.ID))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tsActive, ts)
|
||||
|
||||
policyReportedAt = task.GetHostPolicyReportedAt(ctx, host)
|
||||
// because we transition via unix epoch (seconds), not exactly equal
|
||||
require.WithinDuration(t, now, policyReportedAt, time.Second)
|
||||
// host's PolicyUpdatedAt field hasn't been updated yet, because the policy
|
||||
// results are in redis, not in mysql yet.
|
||||
require.True(t, host.PolicyUpdatedAt.Equal(lastYear))
|
||||
|
||||
// running the collector removes the host from the active set
|
||||
var stats collectorExecStats
|
||||
err = task.collectPolicyQueryExecutions(ctx, ds, pool, &stats)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, stats.Keys)
|
||||
require.Equal(t, 0, stats.Items)
|
||||
require.False(t, stats.Failed)
|
||||
|
||||
count, err = redigo.Int(conn.Do("ZCARD", policyPassHostIDsKey))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, count)
|
||||
}
|
||||
|
||||
func createPolicies(t *testing.T, ds *mysql.Datastore, count int) []uint {
|
||||
ctx := context.Background()
|
||||
|
||||
|
|
|
|||
|
|
@ -126,12 +126,16 @@ func TestRecord(t *testing.T) {
|
|||
pool := redistest.SetupRedis(t, "policy_pass", false, false, false)
|
||||
t.Run("sync", func(t *testing.T) { testRecordPolicyQueryExecutionsSync(t, ds, pool) })
|
||||
t.Run("async", func(t *testing.T) { testRecordPolicyQueryExecutionsAsync(t, ds, pool) })
|
||||
t.Run("sync", func(t *testing.T) { testRecordPolicyQueryExecutionsNoPoliciesSync(t, ds, pool) })
|
||||
t.Run("async", func(t *testing.T) { testRecordPolicyQueryExecutionsNoPoliciesAsync(t, ds, pool) })
|
||||
})
|
||||
|
||||
t.Run("cluster", func(t *testing.T) {
|
||||
pool := redistest.SetupRedis(t, "policy_pass", true, true, false)
|
||||
t.Run("sync", func(t *testing.T) { testRecordPolicyQueryExecutionsSync(t, ds, pool) })
|
||||
t.Run("async", func(t *testing.T) { testRecordPolicyQueryExecutionsAsync(t, ds, pool) })
|
||||
t.Run("sync", func(t *testing.T) { testRecordPolicyQueryExecutionsNoPoliciesSync(t, ds, pool) })
|
||||
t.Run("async", func(t *testing.T) { testRecordPolicyQueryExecutionsNoPoliciesAsync(t, ds, pool) })
|
||||
})
|
||||
})
|
||||
|
||||
|
|
|
|||
|
|
@ -614,13 +614,18 @@ func (svc *Service) GetDistributedQueries(ctx context.Context) (queries map[stri
|
|||
}
|
||||
}
|
||||
|
||||
policyQueries, err := svc.policyQueriesForHost(ctx, host)
|
||||
policyQueries, noPolicies, err := svc.policyQueriesForHost(ctx, host)
|
||||
if err != nil {
|
||||
return nil, nil, 0, newOsqueryError(err.Error())
|
||||
}
|
||||
for name, query := range policyQueries {
|
||||
queries[hostPolicyQueryPrefix+name] = query
|
||||
}
|
||||
if noPolicies {
|
||||
// This is only set when it's time to re-run policies on the host,
|
||||
// but the host doesn't have any policies assigned.
|
||||
queries[hostNoPoliciesWildcard] = alwaysTrueQuery
|
||||
}
|
||||
|
||||
accelerate = uint(0)
|
||||
if host.Hostname == "" || host.Platform == "" {
|
||||
|
|
@ -744,16 +749,22 @@ func (svc *Service) labelQueriesForHost(ctx context.Context, host *fleet.Host) (
|
|||
return labelQueries, nil
|
||||
}
|
||||
|
||||
func (svc *Service) policyQueriesForHost(ctx context.Context, host *fleet.Host) (map[string]string, error) {
|
||||
// policyQueriesForHost returns policy queries if it's the time to re-run policies on the given host.
|
||||
// It returns (nil, true, nil) if the interval is so that policies should be executed on the host, but there are no policies
|
||||
// assigned to such host.
|
||||
func (svc *Service) policyQueriesForHost(ctx context.Context, host *fleet.Host) (policyQueries map[string]string, noPoliciesForHost bool, err error) {
|
||||
policyReportedAt := svc.task.GetHostPolicyReportedAt(ctx, host)
|
||||
if !svc.shouldUpdate(policyReportedAt, svc.config.Osquery.PolicyUpdateInterval, host.ID) && !host.RefetchRequested {
|
||||
return nil, nil
|
||||
return nil, false, nil
|
||||
}
|
||||
policyQueries, err := svc.ds.PolicyQueriesForHost(ctx, host)
|
||||
policyQueries, err = svc.ds.PolicyQueriesForHost(ctx, host)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "retrieve policy queries")
|
||||
return nil, false, ctxerr.Wrap(ctx, err, "retrieve policy queries")
|
||||
}
|
||||
return policyQueries, nil
|
||||
if len(policyQueries) == 0 {
|
||||
return nil, true, nil
|
||||
}
|
||||
return policyQueries, false, nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
@ -867,6 +878,15 @@ const (
|
|||
// osqueryd writes the distributed query results.
|
||||
hostPolicyQueryPrefix = "fleet_policy_query_"
|
||||
|
||||
// hostNoPoliciesWildcard is a query sent to hosts when it's time to run policy
|
||||
// queries on a host, but such host does not have any policies assigned.
|
||||
// When Fleet receives results from such query then it will update the host's
|
||||
// policy_updated_at column.
|
||||
//
|
||||
// This is used to prevent hosts without policies assigned to continuously
|
||||
// perform lookups in the policies table on every check in.
|
||||
hostNoPoliciesWildcard = "fleet_no_policies_wildcard"
|
||||
|
||||
// hostDistributedQueryPrefix is appended before the query name when a query is
|
||||
// run from a distributed query campaign
|
||||
hostDistributedQueryPrefix = "fleet_distributed_query_"
|
||||
|
|
@ -895,7 +915,15 @@ func (svc *Service) SubmitDistributedQueryResults(
|
|||
|
||||
svc.maybeDebugHost(ctx, host, results, statuses, messages)
|
||||
|
||||
var hostWithoutPolicies bool
|
||||
for query, rows := range results {
|
||||
// When receiving this query in the results, we will update the host's
|
||||
// policy_updated_at column.
|
||||
if query == hostNoPoliciesWildcard {
|
||||
hostWithoutPolicies = true
|
||||
continue
|
||||
}
|
||||
|
||||
// osquery docs say any nonzero (string) value for status indicates a query error
|
||||
status, ok := statuses[query]
|
||||
failed := ok && status != fleet.StatusOK
|
||||
|
|
@ -973,6 +1001,13 @@ func (svc *Service) SubmitDistributedQueryResults(
|
|||
if err := svc.task.RecordPolicyQueryExecutions(ctx, host, policyResults, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost); err != nil {
|
||||
logging.WithErr(ctx, err)
|
||||
}
|
||||
} else {
|
||||
if hostWithoutPolicies {
|
||||
// RecordPolicyQueryExecutions called with results=nil will still update the host's policy_updated_at column.
|
||||
if err := svc.task.RecordPolicyQueryExecutions(ctx, host, nil, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost); err != nil {
|
||||
logging.WithErr(ctx, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if additionalUpdated {
|
||||
|
|
|
|||
|
|
@ -847,7 +847,8 @@ func TestLabelQueries(t *testing.T) {
|
|||
// should be turned on so that we can quickly fill labels)
|
||||
queries, discovery, acc, err := svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform)), len(queries), distQueriesMapKeys(queries))
|
||||
// +1 for the fleet_no_policies_wildcard query.
|
||||
require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform))+1, len(queries), distQueriesMapKeys(queries))
|
||||
verifyDiscovery(t, queries, discovery)
|
||||
assert.NotZero(t, acc)
|
||||
|
||||
|
|
@ -858,7 +859,7 @@ func TestLabelQueries(t *testing.T) {
|
|||
|
||||
queries, discovery, acc, err = svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, queries)
|
||||
require.Len(t, queries, 1) // fleet_no_policies_wildcard query
|
||||
verifyDiscovery(t, queries, discovery)
|
||||
assert.Zero(t, acc)
|
||||
|
||||
|
|
@ -873,7 +874,8 @@ func TestLabelQueries(t *testing.T) {
|
|||
// Now we should get the label queries
|
||||
queries, discovery, acc, err = svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queries, 3)
|
||||
// +1 for the fleet_no_policies_wildcard query.
|
||||
require.Len(t, queries, 3+1)
|
||||
verifyDiscovery(t, queries, discovery)
|
||||
assert.Zero(t, acc)
|
||||
|
||||
|
|
@ -928,7 +930,7 @@ func TestLabelQueries(t *testing.T) {
|
|||
ctx = hostctx.NewContext(ctx, host)
|
||||
queries, discovery, acc, err = svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, queries)
|
||||
require.Len(t, queries, 1) // fleet_no_policies_wildcard query
|
||||
verifyDiscovery(t, queries, discovery)
|
||||
assert.Zero(t, acc)
|
||||
|
||||
|
|
@ -937,8 +939,8 @@ func TestLabelQueries(t *testing.T) {
|
|||
ctx = hostctx.NewContext(ctx, host)
|
||||
queries, discovery, acc, err = svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
// +3 for label queries
|
||||
require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform))+3, len(queries), distQueriesMapKeys(queries))
|
||||
// +3 for label queries, +1 for the fleet_no_policies_wildcard query.
|
||||
require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform))+3+1, len(queries), distQueriesMapKeys(queries))
|
||||
verifyDiscovery(t, queries, discovery)
|
||||
assert.Zero(t, acc)
|
||||
|
||||
|
|
@ -967,7 +969,7 @@ func TestLabelQueries(t *testing.T) {
|
|||
ctx = hostctx.NewContext(ctx, host)
|
||||
queries, discovery, acc, err = svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, queries)
|
||||
require.Len(t, queries, 1) // fleet_no_policies_wildcard query
|
||||
verifyDiscovery(t, queries, discovery)
|
||||
assert.Zero(t, acc)
|
||||
}
|
||||
|
|
@ -1006,8 +1008,8 @@ func TestDetailQueriesWithEmptyStrings(t *testing.T) {
|
|||
// queries)
|
||||
queries, discovery, acc, err := svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
// +1 due to 'windows_update_history'
|
||||
if expected := expectedDetailQueriesForPlatform(host.Platform); !assert.Equal(t, len(expected)+1, len(queries)) {
|
||||
// +1 due to 'windows_update_history', +1 due to fleet_no_policies_wildcard query.
|
||||
if expected := expectedDetailQueriesForPlatform(host.Platform); !assert.Equal(t, len(expected)+1+1, len(queries)) {
|
||||
// this is just to print the diff between the expected and actual query
|
||||
// keys when the count assertion fails, to help debugging - they are not
|
||||
// expected to match.
|
||||
|
|
@ -1141,7 +1143,7 @@ func TestDetailQueriesWithEmptyStrings(t *testing.T) {
|
|||
ctx = hostctx.NewContext(ctx, host)
|
||||
queries, discovery, acc, err = svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, queries)
|
||||
require.Len(t, queries, 1) // fleet_no_policies_wildcard query
|
||||
verifyDiscovery(t, queries, discovery)
|
||||
assert.Zero(t, acc)
|
||||
|
||||
|
|
@ -1152,7 +1154,8 @@ func TestDetailQueriesWithEmptyStrings(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
// somehow confusingly, the query response above changed the host's platform
|
||||
// from windows to darwin
|
||||
require.Equal(t, len(expectedDetailQueriesForPlatform(gotHost.Platform)), len(queries), distQueriesMapKeys(queries))
|
||||
// +1 due to fleet_no_policies_wildcard query.
|
||||
require.Equal(t, len(expectedDetailQueriesForPlatform(gotHost.Platform))+1, len(queries), distQueriesMapKeys(queries))
|
||||
verifyDiscovery(t, queries, discovery)
|
||||
assert.Zero(t, acc)
|
||||
}
|
||||
|
|
@ -1215,8 +1218,8 @@ func TestDetailQueries(t *testing.T) {
|
|||
// queries)
|
||||
queries, discovery, acc, err := svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
// +1 for software inventory
|
||||
if expected := expectedDetailQueriesForPlatform(host.Platform); !assert.Equal(t, len(expected)+1, len(queries)) {
|
||||
// +1 for software inventory, +1 for fleet_no_policies_wildcard
|
||||
if expected := expectedDetailQueriesForPlatform(host.Platform); !assert.Equal(t, len(expected)+1+1, len(queries)) {
|
||||
// this is just to print the diff between the expected and actual query
|
||||
// keys when the count assertion fails, to help debugging - they are not
|
||||
// expected to match.
|
||||
|
|
@ -1468,7 +1471,7 @@ func TestDetailQueries(t *testing.T) {
|
|||
ctx = hostctx.NewContext(ctx, host)
|
||||
queries, discovery, acc, err = svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, queries)
|
||||
require.Len(t, queries, 1) // fleet_no_policies_wildcard query
|
||||
verifyDiscovery(t, queries, discovery)
|
||||
assert.Zero(t, acc)
|
||||
|
||||
|
|
@ -1477,8 +1480,8 @@ func TestDetailQueries(t *testing.T) {
|
|||
|
||||
queries, discovery, acc, err = svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
// +1 software inventory
|
||||
require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform))+1, len(queries), distQueriesMapKeys(queries))
|
||||
// +1 software inventory, +1 fleet_no_policies_wildcard query
|
||||
require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform))+1+1, len(queries), distQueriesMapKeys(queries))
|
||||
verifyDiscovery(t, queries, discovery)
|
||||
assert.Zero(t, acc)
|
||||
}
|
||||
|
|
@ -1677,8 +1680,8 @@ func TestDistributedQueryResults(t *testing.T) {
|
|||
// Now we should get the active distributed query
|
||||
queries, discovery, acc, err := svc.GetDistributedQueries(hostCtx)
|
||||
require.NoError(t, err)
|
||||
// +1 for the distributed query for campaign ID 42, +1 for windows update history
|
||||
if expected := expectedDetailQueriesForPlatform(host.Platform); !assert.Equal(t, len(expected)+2, len(queries)) {
|
||||
// +1 for the distributed query for campaign ID 42, +1 for windows update history, +1 for the fleet_no_policies_wildcard query.
|
||||
if expected := expectedDetailQueriesForPlatform(host.Platform); !assert.Equal(t, len(expected)+3, len(queries)) {
|
||||
// this is just to print the diff between the expected and actual query
|
||||
// keys when the count assertion fails, to help debugging - they are not
|
||||
// expected to match.
|
||||
|
|
@ -2973,7 +2976,8 @@ func TestLiveQueriesFailing(t *testing.T) {
|
|||
|
||||
queries, discovery, _, err := svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform)), len(queries), distQueriesMapKeys(queries))
|
||||
// +1 to account for the fleet_no_policies_wildcard query.
|
||||
require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform))+1, len(queries), distQueriesMapKeys(queries))
|
||||
verifyDiscovery(t, queries, discovery)
|
||||
|
||||
logs, err := io.ReadAll(buf)
|
||||
|
|
|
|||
Loading…
Reference in a new issue