From 9142c5de7938b09ee1004ee02609deb0731639bf Mon Sep 17 00:00:00 2001 From: Lucas Manuel Rodriguez Date: Thu, 31 Aug 2023 10:58:50 -0300 Subject: [PATCH] Prevent thundering herd when applying large number of policies on large number of hosts (#13552) #13527 (Adding @mna to double check the changes in the async implementation of policy result storage) This PR also adds the osquery-perf changes needed to define the count of macOS and Windows hosts. - [X] Changes file added for user-visible changes in `changes/` or `orbit/changes/`. See [Changes files](https://fleetdm.com/docs/contributing/committing-changes#changes-files) for more information. - ~[ ] Documented any API changes (docs/Using-Fleet/REST-API.md or docs/Contributing/API-for-contributors.md)~ - ~[ ] Documented any permissions changes (docs/Using Fleet/manage-access.md)~ - [X] Input data is properly validated, `SELECT *` is avoided, SQL injection is prevented (using placeholders for values in statements) - [X] Added support on fleet's osquery simulator `cmd/osquery-perf` for new osquery data ingestion features. - [X] Added/updated tests - [X] Manual QA for all new/changed functionality - ~For Orbit and Fleet Desktop changes:~ - ~[ ] Manual QA must be performed in the three main OSs, macOS, Windows and Linux.~ - ~[ ] Auto-update manual QA, from released version of component to new version (see [tools/tuf/test](../tools/tuf/test/README.md)).~ Test with 80k hosts: 70k simulated macOS, 10k simulated Windows. 
Apply Windows policies first, then apply macOS policies: ``` fleetctl apply -f ee/cis/win-10/cis-policy-queries.yml # Leave running for some time fleetctl apply -f ee/cis/macos-13/cis-policy-queries.yml ``` After applying CIS policies previous to these changes: ![Screenshot 2023-08-23 at 11 36 18](https://github.com/fleetdm/fleet/assets/2073526/72c1dc7d-e601-4248-be35-93c85b749f5d) After applying these changes and applying the same policies: ![Screenshot 2023-08-28 at 15 42 57](https://github.com/fleetdm/fleet/assets/2073526/6b6d76b8-6acb-4893-a913-bf603a68f1a4) --- changes/13527-applying-policies-at-scale | 4 + cmd/osquery-perf/agent.go | 47 +++++++++- server/datastore/mysql/hosts_test.go | 50 ++++++++++ server/datastore/mysql/policies.go | 50 +++++----- server/fleet/datastore.go | 1 + server/service/async/async_policy.go | 77 +++++++++------ server/service/async/async_policy_test.go | 109 ++++++++++++++++++++++ server/service/async/async_test.go | 4 + server/service/osquery.go | 47 ++++++++-- server/service/osquery_test.go | 42 +++++---- 10 files changed, 348 insertions(+), 83 deletions(-) create mode 100644 changes/13527-applying-policies-at-scale diff --git a/changes/13527-applying-policies-at-scale b/changes/13527-applying-policies-at-scale new file mode 100644 index 0000000000..afbb2d448d --- /dev/null +++ b/changes/13527-applying-policies-at-scale @@ -0,0 +1,4 @@ +* Improved performance at scale when applying hundreds of policies to thousands of hosts via `fleetctl apply`. +IMPORTANT: In previous versions of Fleet there's a performance issue (thundering herd) when applying hundreds of +policies on a large number of hosts. To avoid this, make sure to deploy this version of Fleet, and make sure Fleet +is running for at least 1h (or the configured `FLEET_OSQUERY_POLICY_UPDATE_INTERVAL`) before applying the policies. 
diff --git a/cmd/osquery-perf/agent.go b/cmd/osquery-perf/agent.go index ae247b4699..2cd746725d 100644 --- a/cmd/osquery-perf/agent.go +++ b/cmd/osquery-perf/agent.go @@ -1406,8 +1406,10 @@ func main() { orbitProb = flag.Float64("orbit_prob", 0.5, "Probability of a host being identified as orbit install [0, 1]") munkiIssueProb = flag.Float64("munki_issue_prob", 0.5, "Probability of a host having munki issues (note that ~50% of hosts have munki installed) [0, 1]") munkiIssueCount = flag.Int("munki_issue_count", 10, "Number of munki issues reported by hosts identified to have munki issues") - osTemplates = flag.String("os_templates", "mac10.14.6", fmt.Sprintf("Comma separated list of host OS templates to use (any of %v, with or without the .tmpl extension)", allowedTemplateNames)) - emptySerialProb = flag.Float64("empty_serial_prob", 0.1, "Probability of a host having no serial number [0, 1]") + // E.g. when running with `-host_count=10`, you can set host count for each template the following way: + // `-os_templates=windows_11.tmpl:3,mac10.14.6.tmpl:4,ubuntu_22.04.tmpl:3` + osTemplates = flag.String("os_templates", "mac10.14.6", fmt.Sprintf("Comma separated list of host OS templates to use and optionally their host count separated by ':' (any of %v, with or without the .tmpl extension)", allowedTemplateNames)) + emptySerialProb = flag.Float64("empty_serial_prob", 0.1, "Probability of a host having no serial number [0, 1]") mdmProb = flag.Float64("mdm_prob", 0.0, "Probability of a host enrolling via MDM (for macOS) [0, 1]") mdmSCEPChallenge = flag.String("mdm_scep_challenge", "", "SCEP challenge to use when running MDM enroll") @@ -1434,9 +1436,20 @@ func main() { log.Fatalf("Argument unique_software_uninstall_count cannot be bigger than unique_software_count") } - var tmpls []*template.Template + tmplsm := make(map[*template.Template]int) requestedTemplates := strings.Split(*osTemplates, ",") + tmplsTotalHostCount := 0 for _, nm := range requestedTemplates { + 
numberOfHosts := 0 + if strings.Contains(nm, ":") { + parts := strings.Split(nm, ":") + nm = parts[0] + hc, err := strconv.ParseInt(parts[1], 10, 64) + if err != nil { + log.Fatalf("Invalid template host count: %s", parts[1]) + } + numberOfHosts = int(hc) + } if !strings.HasSuffix(nm, ".tmpl") { nm += ".tmpl" } @@ -1448,7 +1461,11 @@ func main() { if err != nil { log.Fatal("parse templates: ", err) } - tmpls = append(tmpls, tmpl) + tmplsm[tmpl] = numberOfHosts + tmplsTotalHostCount += numberOfHosts + } + if tmplsTotalHostCount != 0 && tmplsTotalHostCount != *hostCount { + log.Fatalf("Invalid host count in templates: total=%d vs host_count=%d", tmplsTotalHostCount, *hostCount) } // Spread starts over the interval to prevent thundering herd @@ -1463,8 +1480,28 @@ func main() { nodeKeyManager.LoadKeys() } + var tmplss []*template.Template + for tmpl := range tmplsm { + tmplss = append(tmplss, tmpl) + } + for i := 0; i < *hostCount; i++ { - tmpl := tmpls[i%len(tmpls)] + var tmpl *template.Template + if tmplsTotalHostCount > 0 { + for tmpl_, hostCount := range tmplsm { + if hostCount > 0 { + tmpl = tmpl_ + tmplsm[tmpl_] = tmplsm[tmpl_] - 1 + break + } + } + if tmpl == nil { + log.Fatalf("Failed to determine template for host: %d", i) + } + } else { + tmpl = tmplss[i%len(tmplss)] + } + a := newAgent(i+1, *serverURL, *enrollSecret, diff --git a/server/datastore/mysql/hosts_test.go b/server/datastore/mysql/hosts_test.go index c675e7a48a..5bc9f460d2 100644 --- a/server/datastore/mysql/hosts_test.go +++ b/server/datastore/mysql/hosts_test.go @@ -135,6 +135,7 @@ func TestHosts(t *testing.T) { {"ReplaceHostBatteries", testHostsReplaceHostBatteries}, {"CountHostsNotResponding", testCountHostsNotResponding}, {"FailingPoliciesCount", testFailingPoliciesCount}, + {"HostRecordNoPolicies", testHostsRecordNoPolicies}, {"SetOrUpdateHostDisksSpace", testHostsSetOrUpdateHostDisksSpace}, {"HostIDsByOSID", testHostIDsByOSID}, {"SetOrUpdateHostDisksEncryption", 
testHostsSetOrUpdateHostDisksEncryption}, @@ -6159,6 +6160,55 @@ func testFailingPoliciesCount(t *testing.T, ds *Datastore) { }) } +func testHostsRecordNoPolicies(t *testing.T, ds *Datastore) { + initialTime := time.Now() + + for i := 0; i < 2; i++ { + _, err := ds.NewHost(context.Background(), &fleet.Host{ + DetailUpdatedAt: initialTime, + LabelUpdatedAt: initialTime, + PolicyUpdatedAt: initialTime, + SeenTime: initialTime.Add(-time.Duration(i) * time.Minute), + OsqueryHostID: ptr.String(strconv.Itoa(i)), + NodeKey: ptr.String(fmt.Sprintf("%d", i)), + UUID: fmt.Sprintf("%d", i), + Hostname: fmt.Sprintf("foo.local%d", i), + }) + require.NoError(t, err) + } + + filter := fleet.TeamFilter{User: test.UserAdmin} + + hosts := listHostsCheckCount(t, ds, filter, fleet.HostListOptions{}, 2) + require.Len(t, hosts, 2) + + h1 := hosts[0] + h2 := hosts[1] + + assert.WithinDuration(t, initialTime, h1.PolicyUpdatedAt, 1*time.Second) + assert.Zero(t, h1.HostIssues.FailingPoliciesCount) + assert.Zero(t, h1.HostIssues.TotalIssuesCount) + assert.WithinDuration(t, initialTime, h2.PolicyUpdatedAt, 1*time.Second) + assert.Zero(t, h2.HostIssues.FailingPoliciesCount) + assert.Zero(t, h2.HostIssues.TotalIssuesCount) + + policyUpdatedAt := initialTime.Add(1 * time.Hour) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, nil, policyUpdatedAt, false)) + + hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{}, 2) + require.Len(t, hosts, 2) + + h1 = hosts[0] + h2 = hosts[1] + + assert.WithinDuration(t, policyUpdatedAt, h1.PolicyUpdatedAt, 1*time.Second) + assert.Zero(t, h1.HostIssues.FailingPoliciesCount) + assert.Zero(t, h1.HostIssues.TotalIssuesCount) + assert.WithinDuration(t, initialTime, h2.PolicyUpdatedAt, 1*time.Second) + assert.Zero(t, h2.HostIssues.FailingPoliciesCount) + assert.Zero(t, h2.HostIssues.TotalIssuesCount) +} + func testHostsSetOrUpdateHostDisksSpace(t *testing.T, ds *Datastore) { host, err := ds.NewHost(context.Background(), 
&fleet.Host{ DetailUpdatedAt: time.Now(), diff --git a/server/datastore/mysql/policies.go b/server/datastore/mysql/policies.go index a7ad822222..df3e484282 100644 --- a/server/datastore/mysql/policies.go +++ b/server/datastore/mysql/policies.go @@ -214,21 +214,23 @@ func filterNotExecuted(results map[uint]*bool) map[uint]bool { } func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool) error { - // Sort the results to have generated SQL queries ordered to minimize - // deadlocks. See https://github.com/fleetdm/fleet/issues/1146. - orderedIDs := make([]uint, 0, len(results)) - for policyID := range results { - orderedIDs = append(orderedIDs, policyID) - } - sort.Slice(orderedIDs, func(i, j int) bool { return orderedIDs[i] < orderedIDs[j] }) - - // Loop through results, collecting which labels we need to insert/update vals := []interface{}{} bindvars := []string{} - for _, policyID := range orderedIDs { - matches := results[policyID] - bindvars = append(bindvars, "(?,?,?,?)") - vals = append(vals, updated, policyID, host.ID, matches) + if len(results) > 0 { + // Sort the results to have generated SQL queries ordered to minimize + // deadlocks. See https://github.com/fleetdm/fleet/issues/1146. 
+ orderedIDs := make([]uint, 0, len(results)) + for policyID := range results { + orderedIDs = append(orderedIDs, policyID) + } + sort.Slice(orderedIDs, func(i, j int) bool { return orderedIDs[i] < orderedIDs[j] }) + + // Loop through results, collecting which labels we need to insert/update + for _, policyID := range orderedIDs { + matches := results[policyID] + bindvars = append(bindvars, "(?,?,?,?)") + vals = append(vals, updated, policyID, host.ID, matches) + } } // NOTE: the insert of policy membership that follows must be kept in sync @@ -238,16 +240,17 @@ func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *flee // semantically equivalent, even though here it processes a single host and // in async mode it processes a batch of hosts). - query := fmt.Sprintf( - `INSERT INTO policy_membership (updated_at, policy_id, host_id, passes) - VALUES %s ON DUPLICATE KEY UPDATE updated_at=VALUES(updated_at), passes=VALUES(passes)`, - strings.Join(bindvars, ","), - ) - err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { - _, err := tx.ExecContext(ctx, query, vals...) - if err != nil { - return ctxerr.Wrapf(ctx, err, "insert policy_membership (%v)", vals) + if len(results) > 0 { + query := fmt.Sprintf( + `INSERT INTO policy_membership (updated_at, policy_id, host_id, passes) + VALUES %s ON DUPLICATE KEY UPDATE updated_at=VALUES(updated_at), passes=VALUES(passes)`, + strings.Join(bindvars, ","), + ) + _, err := tx.ExecContext(ctx, query, vals...) + if err != nil { + return ctxerr.Wrapf(ctx, err, "insert policy_membership (%v)", vals) + } } // if we are deferring host updates, we return at this point and do the change outside of the tx @@ -255,8 +258,7 @@ func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *flee return nil } - _, err = tx.ExecContext(ctx, `UPDATE hosts SET policy_updated_at = ? WHERE id=?`, updated, host.ID) - if err != nil { + if _, err := tx.ExecContext(ctx, `UPDATE hosts SET policy_updated_at = ? 
WHERE id=?`, updated, host.ID); err != nil { return ctxerr.Wrap(ctx, err, "updating hosts policy updated at") } return nil diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index cda44ed4d5..3d3a999a77 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -667,6 +667,7 @@ type Datastore interface { FlippingPoliciesForHost(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing []uint, newPassing []uint, err error) // RecordPolicyQueryExecutions records the execution results of the policies for the given host. + // Even if `results` is empty, the host's `policy_updated_at` will be updated. RecordPolicyQueryExecutions(ctx context.Context, host *Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool) error // RecordLabelQueryExecutions saves the results of label queries. The results map is a map of label id -> whether or diff --git a/server/service/async/async_policy.go b/server/service/async/async_policy.go index 5dc5445272..a7a0e5b50d 100644 --- a/server/service/async/async_policy.go +++ b/server/service/async/async_policy.go @@ -22,10 +22,8 @@ const ( policyPassKeysMinTTL = 7 * 24 * time.Hour // 1 week ) -var ( - // redis list will be LTRIM'd if there are more policy IDs than this. - maxRedisPolicyResultsPerHost = 1000 -) +// redis list will be LTRIM'd if there are more policy IDs than this. 
+var maxRedisPolicyResultsPerHost = 1000 func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool) error { cfg := t.taskConfigs[config.AsyncTaskPolicyMembership] @@ -51,36 +49,57 @@ func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host ttl = maxTTL } - // KEYS[1]: keyList (policyPassHostKey) - // KEYS[2]: keyTs (policyPassReportedKey) - // ARGV[1]: timestamp for "reported at" - // ARGV[2]: max policy results to keep per host (list is trimmed to that size) - // ARGV[3]: ttl for both keys - // ARGV[4..]: policy_id=pass entries to LPUSH to the list - script := redigo.NewScript(2, ` - redis.call('LPUSH', KEYS[1], unpack(ARGV, 4)) - redis.call('LTRIM', KEYS[1], 0, ARGV[2]) - redis.call('EXPIRE', KEYS[1], ARGV[3]) - redis.call('SET', KEYS[2], ARGV[1]) - return redis.call('EXPIRE', KEYS[2], ARGV[3]) - `) + // There are two versions of the script (1) and (2). + // Script (1) is used when there are no policy results to report and + // script (2) is used when there are policy results. - // convert results to LPUSH arguments, store as policy_id=1 for pass, - // policy_id=-1 for fail, policy_id=0 for null result. 
- args := make(redigo.Args, 0, 5+len(results)) - args = args.Add(keyList, keyTs, ts.Unix(), maxRedisPolicyResultsPerHost, int(ttl.Seconds())) - for k, v := range results { - pass := 0 - if v != nil { - if *v { - pass = 1 - } else { - pass = -1 + // (1) + // KEYS[1]: keyTs (policyPassReportedKey) + // ARGV[1]: timestamp for "reported at" + // ARGV[2]: ttl for the key + scriptSrc := ` + redis.call('SET', KEYS[1], ARGV[1]) + return redis.call('EXPIRE', KEYS[1], ARGV[2]) + ` + keyCount := 1 + args := make(redigo.Args, 0, 3) + args = args.Add(keyTs, ts.Unix(), int(ttl.Seconds())) + + if len(results) > 0 { + // (2) + // KEYS[1]: keyList (policyPassHostKey) + // KEYS[2]: keyTs (policyPassReportedKey) + // ARGV[1]: timestamp for "reported at" + // ARGV[2]: max policy results to keep per host (list is trimmed to that size) + // ARGV[3]: ttl for both keys + // ARGV[4..]: policy_id=pass entries to LPUSH to the list + keyCount = 2 + scriptSrc = ` + redis.call('LPUSH', KEYS[1], unpack(ARGV, 4)) + redis.call('LTRIM', KEYS[1], 0, ARGV[2]) + redis.call('EXPIRE', KEYS[1], ARGV[3]) + redis.call('SET', KEYS[2], ARGV[1]) + return redis.call('EXPIRE', KEYS[2], ARGV[3]) + ` + // convert results to LPUSH arguments, store as policy_id=1 for pass, + // policy_id=-1 for fail, policy_id=0 for null result. 
+ args = make(redigo.Args, 0, 5+len(results)) + args = args.Add(keyList, keyTs, ts.Unix(), maxRedisPolicyResultsPerHost, int(ttl.Seconds())) + for k, v := range results { + pass := 0 + if v != nil { + if *v { + pass = 1 + } else { + pass = -1 + } } + args = args.Add(fmt.Sprintf("%d=%d", k, pass)) } - args = args.Add(fmt.Sprintf("%d=%d", k, pass)) } + script := redigo.NewScript(keyCount, scriptSrc) + conn := t.pool.Get() defer conn.Close() if err := redis.BindConn(t.pool, conn, keyList, keyTs); err != nil { diff --git a/server/service/async/async_policy_test.go b/server/service/async/async_policy_test.go index 9f227bade0..3962a9f6cb 100644 --- a/server/service/async/async_policy_test.go +++ b/server/service/async/async_policy_test.go @@ -410,6 +410,115 @@ func testRecordPolicyQueryExecutionsAsync(t *testing.T, ds *mock.Store, pool fle require.Equal(t, 0, count) } +func testRecordPolicyQueryExecutionsNoPoliciesSync(t *testing.T, ds *mock.Store, pool fleet.RedisPool) { + ctx := context.Background() + now := time.Now() + lastYear := now.Add(-365 * 24 * time.Hour) + host := &fleet.Host{ + ID: 1, + Platform: "linux", + PolicyUpdatedAt: lastYear, + } + + var emptyResults map[uint]*bool + keyList, keyTs := fmt.Sprintf(policyPassHostKey, host.ID), fmt.Sprintf(policyPassReportedKey, host.ID) + + task := NewTask(ds, pool, clock.C, config.OsqueryConfig{}) + + policyReportedAt := task.GetHostPolicyReportedAt(ctx, host) + require.True(t, policyReportedAt.Equal(lastYear)) + + err := task.RecordPolicyQueryExecutions(ctx, host, emptyResults, now, false) + require.NoError(t, err) + require.True(t, ds.RecordPolicyQueryExecutionsFuncInvoked) + ds.RecordPolicyQueryExecutionsFuncInvoked = false + + conn := redis.ConfigureDoer(pool, pool.Get()) + defer conn.Close() + + n, err := redigo.Int(conn.Do("EXISTS", keyList)) + require.NoError(t, err) + require.Equal(t, 0, n) + + n, err = redigo.Int(conn.Do("EXISTS", keyTs)) + require.NoError(t, err) + require.Equal(t, 0, n) + + n, err = 
redigo.Int(conn.Do("ZCARD", policyPassHostIDsKey)) + require.NoError(t, err) + require.Equal(t, 0, n) + + policyReportedAt = task.GetHostPolicyReportedAt(ctx, host) + require.True(t, policyReportedAt.Equal(now)) +} + +func testRecordPolicyQueryExecutionsNoPoliciesAsync(t *testing.T, ds *mock.Store, pool fleet.RedisPool) { + ctx := context.Background() + now := time.Now() + lastYear := now.Add(-365 * 24 * time.Hour) + host := &fleet.Host{ + ID: 1, + Platform: "linux", + PolicyUpdatedAt: lastYear, + } + var emptyResults map[uint]*bool + keyList, keyTs := fmt.Sprintf(policyPassHostKey, host.ID), fmt.Sprintf(policyPassReportedKey, host.ID) + + task := NewTask(ds, pool, clock.C, config.OsqueryConfig{ + EnableAsyncHostProcessing: "true", + AsyncHostInsertBatch: 3, + AsyncHostUpdateBatch: 3, + AsyncHostDeleteBatch: 3, + AsyncHostRedisPopCount: 3, + AsyncHostRedisScanKeysCount: 10, + }) + + policyReportedAt := task.GetHostPolicyReportedAt(ctx, host) + require.True(t, policyReportedAt.Equal(lastYear)) + + err := task.RecordPolicyQueryExecutions(ctx, host, emptyResults, now, false) + require.NoError(t, err) + require.False(t, ds.RecordPolicyQueryExecutionsFuncInvoked) + + conn := redis.ConfigureDoer(pool, pool.Get()) + defer conn.Close() + defer conn.Do("DEL", keyTs) //nolint:errcheck + + n, err := redigo.Int(conn.Do("EXISTS", keyList)) + require.NoError(t, err) + require.Equal(t, 0, n) + + ts, err := redigo.Int64(conn.Do("GET", keyTs)) + require.NoError(t, err) + require.Equal(t, now.Unix(), ts) + + count, err := redigo.Int(conn.Do("ZCARD", policyPassHostIDsKey)) + require.NoError(t, err) + require.Equal(t, 1, count) + tsActive, err := redigo.Int64(conn.Do("ZSCORE", policyPassHostIDsKey, host.ID)) + require.NoError(t, err) + require.Equal(t, tsActive, ts) + + policyReportedAt = task.GetHostPolicyReportedAt(ctx, host) + // because we transition via unix epoch (seconds), not exactly equal + require.WithinDuration(t, now, policyReportedAt, time.Second) + // host's 
PolicyUpdatedAt field hasn't been updated yet, because the policy + // results are in redis, not in mysql yet. + require.True(t, host.PolicyUpdatedAt.Equal(lastYear)) + + // running the collector removes the host from the active set + var stats collectorExecStats + err = task.collectPolicyQueryExecutions(ctx, ds, pool, &stats) + require.NoError(t, err) + require.Equal(t, 1, stats.Keys) + require.Equal(t, 0, stats.Items) + require.False(t, stats.Failed) + + count, err = redigo.Int(conn.Do("ZCARD", policyPassHostIDsKey)) + require.NoError(t, err) + require.Equal(t, 0, count) +} + func createPolicies(t *testing.T, ds *mysql.Datastore, count int) []uint { ctx := context.Background() diff --git a/server/service/async/async_test.go b/server/service/async/async_test.go index 6291d79a74..3974f7d14a 100644 --- a/server/service/async/async_test.go +++ b/server/service/async/async_test.go @@ -126,12 +126,16 @@ func TestRecord(t *testing.T) { pool := redistest.SetupRedis(t, "policy_pass", false, false, false) t.Run("sync", func(t *testing.T) { testRecordPolicyQueryExecutionsSync(t, ds, pool) }) t.Run("async", func(t *testing.T) { testRecordPolicyQueryExecutionsAsync(t, ds, pool) }) + t.Run("sync", func(t *testing.T) { testRecordPolicyQueryExecutionsNoPoliciesSync(t, ds, pool) }) + t.Run("async", func(t *testing.T) { testRecordPolicyQueryExecutionsNoPoliciesAsync(t, ds, pool) }) }) t.Run("cluster", func(t *testing.T) { pool := redistest.SetupRedis(t, "policy_pass", true, true, false) t.Run("sync", func(t *testing.T) { testRecordPolicyQueryExecutionsSync(t, ds, pool) }) t.Run("async", func(t *testing.T) { testRecordPolicyQueryExecutionsAsync(t, ds, pool) }) + t.Run("sync", func(t *testing.T) { testRecordPolicyQueryExecutionsNoPoliciesSync(t, ds, pool) }) + t.Run("async", func(t *testing.T) { testRecordPolicyQueryExecutionsNoPoliciesAsync(t, ds, pool) }) }) }) diff --git a/server/service/osquery.go b/server/service/osquery.go index f4572f43e2..cc47a39cd4 100644 --- 
a/server/service/osquery.go +++ b/server/service/osquery.go @@ -614,13 +614,18 @@ func (svc *Service) GetDistributedQueries(ctx context.Context) (queries map[stri } } - policyQueries, err := svc.policyQueriesForHost(ctx, host) + policyQueries, noPolicies, err := svc.policyQueriesForHost(ctx, host) if err != nil { return nil, nil, 0, newOsqueryError(err.Error()) } for name, query := range policyQueries { queries[hostPolicyQueryPrefix+name] = query } + if noPolicies { + // This is only set when it's time to re-run policies on the host, + // but the host doesn't have any policies assigned. + queries[hostNoPoliciesWildcard] = alwaysTrueQuery + } accelerate = uint(0) if host.Hostname == "" || host.Platform == "" { @@ -744,16 +749,22 @@ func (svc *Service) labelQueriesForHost(ctx context.Context, host *fleet.Host) ( return labelQueries, nil } -func (svc *Service) policyQueriesForHost(ctx context.Context, host *fleet.Host) (map[string]string, error) { +// policyQueriesForHost returns policy queries if it's time to re-run policies on the given host. +// It returns (nil, true, nil) if policies are due to be executed on the host, but there are no policies +// assigned to that host. 
+func (svc *Service) policyQueriesForHost(ctx context.Context, host *fleet.Host) (policyQueries map[string]string, noPoliciesForHost bool, err error) { policyReportedAt := svc.task.GetHostPolicyReportedAt(ctx, host) if !svc.shouldUpdate(policyReportedAt, svc.config.Osquery.PolicyUpdateInterval, host.ID) && !host.RefetchRequested { - return nil, nil + return nil, false, nil } - policyQueries, err := svc.ds.PolicyQueriesForHost(ctx, host) + policyQueries, err = svc.ds.PolicyQueriesForHost(ctx, host) if err != nil { - return nil, ctxerr.Wrap(ctx, err, "retrieve policy queries") + return nil, false, ctxerr.Wrap(ctx, err, "retrieve policy queries") } - return policyQueries, nil + if len(policyQueries) == 0 { + return nil, true, nil + } + return policyQueries, false, nil } //////////////////////////////////////////////////////////////////////////////// @@ -867,6 +878,15 @@ const ( // osqueryd writes the distributed query results. hostPolicyQueryPrefix = "fleet_policy_query_" + // hostNoPoliciesWildcard is a query sent to hosts when it's time to run policy + // queries on a host, but such host does not have any policies assigned. + // When Fleet receives results from such query then it will update the host's + // policy_updated_at column. + // + // This is used to prevent hosts without policies assigned to continuously + // perform lookups in the policies table on every check in. + hostNoPoliciesWildcard = "fleet_no_policies_wildcard" + // hostDistributedQueryPrefix is appended before the query name when a query is // run from a distributed query campaign hostDistributedQueryPrefix = "fleet_distributed_query_" @@ -895,7 +915,15 @@ func (svc *Service) SubmitDistributedQueryResults( svc.maybeDebugHost(ctx, host, results, statuses, messages) + var hostWithoutPolicies bool for query, rows := range results { + // When receiving this query in the results, we will update the host's + // policy_updated_at column. 
+ if query == hostNoPoliciesWildcard { + hostWithoutPolicies = true + continue + } + // osquery docs say any nonzero (string) value for status indicates a query error status, ok := statuses[query] failed := ok && status != fleet.StatusOK @@ -973,6 +1001,13 @@ func (svc *Service) SubmitDistributedQueryResults( if err := svc.task.RecordPolicyQueryExecutions(ctx, host, policyResults, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost); err != nil { logging.WithErr(ctx, err) } + } else { + if hostWithoutPolicies { + // RecordPolicyQueryExecutions called with results=nil will still update the host's policy_updated_at column. + if err := svc.task.RecordPolicyQueryExecutions(ctx, host, nil, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost); err != nil { + logging.WithErr(ctx, err) + } + } } if additionalUpdated { diff --git a/server/service/osquery_test.go b/server/service/osquery_test.go index e01e0b4256..55989974c7 100644 --- a/server/service/osquery_test.go +++ b/server/service/osquery_test.go @@ -847,7 +847,8 @@ func TestLabelQueries(t *testing.T) { // should be turned on so that we can quickly fill labels) queries, discovery, acc, err := svc.GetDistributedQueries(ctx) require.NoError(t, err) - require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform)), len(queries), distQueriesMapKeys(queries)) + // +1 for the fleet_no_policies_wildcard query. 
+ require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform))+1, len(queries), distQueriesMapKeys(queries)) verifyDiscovery(t, queries, discovery) assert.NotZero(t, acc) @@ -858,7 +859,7 @@ func TestLabelQueries(t *testing.T) { queries, discovery, acc, err = svc.GetDistributedQueries(ctx) require.NoError(t, err) - require.Empty(t, queries) + require.Len(t, queries, 1) // fleet_no_policies_wildcard query verifyDiscovery(t, queries, discovery) assert.Zero(t, acc) @@ -873,7 +874,8 @@ func TestLabelQueries(t *testing.T) { // Now we should get the label queries queries, discovery, acc, err = svc.GetDistributedQueries(ctx) require.NoError(t, err) - require.Len(t, queries, 3) + // +1 for the fleet_no_policies_wildcard query. + require.Len(t, queries, 3+1) verifyDiscovery(t, queries, discovery) assert.Zero(t, acc) @@ -928,7 +930,7 @@ func TestLabelQueries(t *testing.T) { ctx = hostctx.NewContext(ctx, host) queries, discovery, acc, err = svc.GetDistributedQueries(ctx) require.NoError(t, err) - require.Empty(t, queries) + require.Len(t, queries, 1) // fleet_no_policies_wildcard query verifyDiscovery(t, queries, discovery) assert.Zero(t, acc) @@ -937,8 +939,8 @@ func TestLabelQueries(t *testing.T) { ctx = hostctx.NewContext(ctx, host) queries, discovery, acc, err = svc.GetDistributedQueries(ctx) require.NoError(t, err) - // +3 for label queries - require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform))+3, len(queries), distQueriesMapKeys(queries)) + // +3 for label queries, +1 for the fleet_no_policies_wildcard query. 
+ require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform))+3+1, len(queries), distQueriesMapKeys(queries)) verifyDiscovery(t, queries, discovery) assert.Zero(t, acc) @@ -967,7 +969,7 @@ func TestLabelQueries(t *testing.T) { ctx = hostctx.NewContext(ctx, host) queries, discovery, acc, err = svc.GetDistributedQueries(ctx) require.NoError(t, err) - require.Empty(t, queries) + require.Len(t, queries, 1) // fleet_no_policies_wildcard query verifyDiscovery(t, queries, discovery) assert.Zero(t, acc) } @@ -1006,8 +1008,8 @@ func TestDetailQueriesWithEmptyStrings(t *testing.T) { // queries) queries, discovery, acc, err := svc.GetDistributedQueries(ctx) require.NoError(t, err) - // +1 due to 'windows_update_history' - if expected := expectedDetailQueriesForPlatform(host.Platform); !assert.Equal(t, len(expected)+1, len(queries)) { + // +1 due to 'windows_update_history', +1 due to fleet_no_policies_wildcard query. + if expected := expectedDetailQueriesForPlatform(host.Platform); !assert.Equal(t, len(expected)+1+1, len(queries)) { // this is just to print the diff between the expected and actual query // keys when the count assertion fails, to help debugging - they are not // expected to match. @@ -1141,7 +1143,7 @@ func TestDetailQueriesWithEmptyStrings(t *testing.T) { ctx = hostctx.NewContext(ctx, host) queries, discovery, acc, err = svc.GetDistributedQueries(ctx) require.NoError(t, err) - require.Empty(t, queries) + require.Len(t, queries, 1) // fleet_no_policies_wildcard query verifyDiscovery(t, queries, discovery) assert.Zero(t, acc) @@ -1152,7 +1154,8 @@ func TestDetailQueriesWithEmptyStrings(t *testing.T) { require.NoError(t, err) // somehow confusingly, the query response above changed the host's platform // from windows to darwin - require.Equal(t, len(expectedDetailQueriesForPlatform(gotHost.Platform)), len(queries), distQueriesMapKeys(queries)) + // +1 due to fleet_no_policies_wildcard query. 
+ require.Equal(t, len(expectedDetailQueriesForPlatform(gotHost.Platform))+1, len(queries), distQueriesMapKeys(queries)) verifyDiscovery(t, queries, discovery) assert.Zero(t, acc) } @@ -1215,8 +1218,8 @@ func TestDetailQueries(t *testing.T) { // queries) queries, discovery, acc, err := svc.GetDistributedQueries(ctx) require.NoError(t, err) - // +1 for software inventory - if expected := expectedDetailQueriesForPlatform(host.Platform); !assert.Equal(t, len(expected)+1, len(queries)) { + // +1 for software inventory, +1 for fleet_no_policies_wildcard + if expected := expectedDetailQueriesForPlatform(host.Platform); !assert.Equal(t, len(expected)+1+1, len(queries)) { // this is just to print the diff between the expected and actual query // keys when the count assertion fails, to help debugging - they are not // expected to match. @@ -1468,7 +1471,7 @@ func TestDetailQueries(t *testing.T) { ctx = hostctx.NewContext(ctx, host) queries, discovery, acc, err = svc.GetDistributedQueries(ctx) require.NoError(t, err) - require.Empty(t, queries) + require.Len(t, queries, 1) // fleet_no_policies_wildcard query verifyDiscovery(t, queries, discovery) assert.Zero(t, acc) @@ -1477,8 +1480,8 @@ func TestDetailQueries(t *testing.T) { queries, discovery, acc, err = svc.GetDistributedQueries(ctx) require.NoError(t, err) - // +1 software inventory - require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform))+1, len(queries), distQueriesMapKeys(queries)) + // +1 software inventory, +1 fleet_no_policies_wildcard query + require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform))+1+1, len(queries), distQueriesMapKeys(queries)) verifyDiscovery(t, queries, discovery) assert.Zero(t, acc) } @@ -1677,8 +1680,8 @@ func TestDistributedQueryResults(t *testing.T) { // Now we should get the active distributed query queries, discovery, acc, err := svc.GetDistributedQueries(hostCtx) require.NoError(t, err) - // +1 for the distributed query for campaign ID 42, +1 for windows update 
history - if expected := expectedDetailQueriesForPlatform(host.Platform); !assert.Equal(t, len(expected)+2, len(queries)) { + // +1 for the distributed query for campaign ID 42, +1 for windows update history, +1 for the fleet_no_policies_wildcard query. + if expected := expectedDetailQueriesForPlatform(host.Platform); !assert.Equal(t, len(expected)+3, len(queries)) { // this is just to print the diff between the expected and actual query // keys when the count assertion fails, to help debugging - they are not // expected to match. @@ -2973,7 +2976,8 @@ func TestLiveQueriesFailing(t *testing.T) { queries, discovery, _, err := svc.GetDistributedQueries(ctx) require.NoError(t, err) - require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform)), len(queries), distQueriesMapKeys(queries)) + // +1 to account for the fleet_no_policies_wildcard query. + require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform))+1, len(queries), distQueriesMapKeys(queries)) verifyDiscovery(t, queries, discovery) logs, err := io.ReadAll(buf)