Prevent thundering herd when applying large number of policies on large number of hosts (#13552)

#13527

(Adding @mna to double check the changes in the async implementation of
policy result storage)

This PR also adds the osquery-perf changes needed to define the count of
macOS and Windows hosts.

- [X] Changes file added for user-visible changes in `changes/` or
`orbit/changes/`.
See [Changes
files](https://fleetdm.com/docs/contributing/committing-changes#changes-files)
for more information.
- ~[ ] Documented any API changes (docs/Using-Fleet/REST-API.md or
docs/Contributing/API-for-contributors.md)~
- ~[ ] Documented any permissions changes (docs/Using
Fleet/manage-access.md)~
- [X] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements)
- [X] Added support on fleet's osquery simulator `cmd/osquery-perf` for
new osquery data ingestion features.
- [X] Added/updated tests
- [X] Manual QA for all new/changed functionality
  - ~For Orbit and Fleet Desktop changes:~
- ~[ ] Manual QA must be performed in the three main OSs, macOS, Windows
and Linux.~
- ~[ ] Auto-update manual QA, from released version of component to new
version (see [tools/tuf/test](../tools/tuf/test/README.md)).~

Test with 80k hosts: 70k simulated macOS, 10k simulated Windows.
Apply Windows policies first, then apply macOS policies:
```
fleetctl apply -f ee/cis/win-10/cis-policy-queries.yml

# Leave running for some time

fleetctl apply -f ee/cis/macos-13/cis-policy-queries.yml
```

After applying CIS policies previous to these changes:
![Screenshot 2023-08-23 at 11 36
18](https://github.com/fleetdm/fleet/assets/2073526/72c1dc7d-e601-4248-be35-93c85b749f5d)

After applying these changes and applying the same policies:
![Screenshot 2023-08-28 at 15 42
57](https://github.com/fleetdm/fleet/assets/2073526/6b6d76b8-6acb-4893-a913-bf603a68f1a4)
This commit is contained in:
Lucas Manuel Rodriguez 2023-08-31 10:58:50 -03:00 committed by GitHub
parent 6637ea6517
commit 9142c5de79
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 348 additions and 83 deletions

View file

@ -0,0 +1,4 @@
* Improved performance at scale when applying hundreds of policies to thousands of hosts via `fleetctl apply`.
IMPORTANT: In previous versions of Fleet there's a performance issue (thundering herd) when applying hundreds of
policies on a large number of hosts. To avoid this, make sure to deploy this version of Fleet, and make sure Fleet
is running for at least 1h (or the configured `FLEET_OSQUERY_POLICY_UPDATE_INTERVAL`) before applying the policies.

View file

@ -1406,8 +1406,10 @@ func main() {
orbitProb = flag.Float64("orbit_prob", 0.5, "Probability of a host being identified as orbit install [0, 1]")
munkiIssueProb = flag.Float64("munki_issue_prob", 0.5, "Probability of a host having munki issues (note that ~50% of hosts have munki installed) [0, 1]")
munkiIssueCount = flag.Int("munki_issue_count", 10, "Number of munki issues reported by hosts identified to have munki issues")
osTemplates = flag.String("os_templates", "mac10.14.6", fmt.Sprintf("Comma separated list of host OS templates to use (any of %v, with or without the .tmpl extension)", allowedTemplateNames))
emptySerialProb = flag.Float64("empty_serial_prob", 0.1, "Probability of a host having no serial number [0, 1]")
// E.g. when running with `-host_count=10`, you can set host count for each template the following way:
// `-os_templates=windows_11.tmpl:3,mac10.14.6.tmpl:4,ubuntu_22.04.tmpl:3`
osTemplates = flag.String("os_templates", "mac10.14.6", fmt.Sprintf("Comma separated list of host OS templates to use and optionally their host count separated by ':' (any of %v, with or without the .tmpl extension)", allowedTemplateNames))
emptySerialProb = flag.Float64("empty_serial_prob", 0.1, "Probability of a host having no serial number [0, 1]")
mdmProb = flag.Float64("mdm_prob", 0.0, "Probability of a host enrolling via MDM (for macOS) [0, 1]")
mdmSCEPChallenge = flag.String("mdm_scep_challenge", "", "SCEP challenge to use when running MDM enroll")
@ -1434,9 +1436,20 @@ func main() {
log.Fatalf("Argument unique_software_uninstall_count cannot be bigger than unique_software_count")
}
var tmpls []*template.Template
tmplsm := make(map[*template.Template]int)
requestedTemplates := strings.Split(*osTemplates, ",")
tmplsTotalHostCount := 0
for _, nm := range requestedTemplates {
numberOfHosts := 0
if strings.Contains(nm, ":") {
parts := strings.Split(nm, ":")
nm = parts[0]
hc, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
log.Fatalf("Invalid template host count: %s", parts[1])
}
numberOfHosts = int(hc)
}
if !strings.HasSuffix(nm, ".tmpl") {
nm += ".tmpl"
}
@ -1448,7 +1461,11 @@ func main() {
if err != nil {
log.Fatal("parse templates: ", err)
}
tmpls = append(tmpls, tmpl)
tmplsm[tmpl] = numberOfHosts
tmplsTotalHostCount += numberOfHosts
}
if tmplsTotalHostCount != 0 && tmplsTotalHostCount != *hostCount {
log.Fatalf("Invalid host count in templates: total=%d vs host_count=%d", tmplsTotalHostCount, *hostCount)
}
// Spread starts over the interval to prevent thundering herd
@ -1463,8 +1480,28 @@ func main() {
nodeKeyManager.LoadKeys()
}
var tmplss []*template.Template
for tmpl := range tmplsm {
tmplss = append(tmplss, tmpl)
}
for i := 0; i < *hostCount; i++ {
tmpl := tmpls[i%len(tmpls)]
var tmpl *template.Template
if tmplsTotalHostCount > 0 {
for tmpl_, hostCount := range tmplsm {
if hostCount > 0 {
tmpl = tmpl_
tmplsm[tmpl_] = tmplsm[tmpl_] - 1
break
}
}
if tmpl == nil {
log.Fatalf("Failed to determine template for host: %d", i)
}
} else {
tmpl = tmplss[i%len(tmplss)]
}
a := newAgent(i+1,
*serverURL,
*enrollSecret,

View file

@ -135,6 +135,7 @@ func TestHosts(t *testing.T) {
{"ReplaceHostBatteries", testHostsReplaceHostBatteries},
{"CountHostsNotResponding", testCountHostsNotResponding},
{"FailingPoliciesCount", testFailingPoliciesCount},
{"HostRecordNoPolicies", testHostsRecordNoPolicies},
{"SetOrUpdateHostDisksSpace", testHostsSetOrUpdateHostDisksSpace},
{"HostIDsByOSID", testHostIDsByOSID},
{"SetOrUpdateHostDisksEncryption", testHostsSetOrUpdateHostDisksEncryption},
@ -6159,6 +6160,55 @@ func testFailingPoliciesCount(t *testing.T, ds *Datastore) {
})
}
func testHostsRecordNoPolicies(t *testing.T, ds *Datastore) {
initialTime := time.Now()
for i := 0; i < 2; i++ {
_, err := ds.NewHost(context.Background(), &fleet.Host{
DetailUpdatedAt: initialTime,
LabelUpdatedAt: initialTime,
PolicyUpdatedAt: initialTime,
SeenTime: initialTime.Add(-time.Duration(i) * time.Minute),
OsqueryHostID: ptr.String(strconv.Itoa(i)),
NodeKey: ptr.String(fmt.Sprintf("%d", i)),
UUID: fmt.Sprintf("%d", i),
Hostname: fmt.Sprintf("foo.local%d", i),
})
require.NoError(t, err)
}
filter := fleet.TeamFilter{User: test.UserAdmin}
hosts := listHostsCheckCount(t, ds, filter, fleet.HostListOptions{}, 2)
require.Len(t, hosts, 2)
h1 := hosts[0]
h2 := hosts[1]
assert.WithinDuration(t, initialTime, h1.PolicyUpdatedAt, 1*time.Second)
assert.Zero(t, h1.HostIssues.FailingPoliciesCount)
assert.Zero(t, h1.HostIssues.TotalIssuesCount)
assert.WithinDuration(t, initialTime, h2.PolicyUpdatedAt, 1*time.Second)
assert.Zero(t, h2.HostIssues.FailingPoliciesCount)
assert.Zero(t, h2.HostIssues.TotalIssuesCount)
policyUpdatedAt := initialTime.Add(1 * time.Hour)
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, nil, policyUpdatedAt, false))
hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{}, 2)
require.Len(t, hosts, 2)
h1 = hosts[0]
h2 = hosts[1]
assert.WithinDuration(t, policyUpdatedAt, h1.PolicyUpdatedAt, 1*time.Second)
assert.Zero(t, h1.HostIssues.FailingPoliciesCount)
assert.Zero(t, h1.HostIssues.TotalIssuesCount)
assert.WithinDuration(t, initialTime, h2.PolicyUpdatedAt, 1*time.Second)
assert.Zero(t, h2.HostIssues.FailingPoliciesCount)
assert.Zero(t, h2.HostIssues.TotalIssuesCount)
}
func testHostsSetOrUpdateHostDisksSpace(t *testing.T, ds *Datastore) {
host, err := ds.NewHost(context.Background(), &fleet.Host{
DetailUpdatedAt: time.Now(),

View file

@ -214,21 +214,23 @@ func filterNotExecuted(results map[uint]*bool) map[uint]bool {
}
func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool) error {
// Sort the results to have generated SQL queries ordered to minimize
// deadlocks. See https://github.com/fleetdm/fleet/issues/1146.
orderedIDs := make([]uint, 0, len(results))
for policyID := range results {
orderedIDs = append(orderedIDs, policyID)
}
sort.Slice(orderedIDs, func(i, j int) bool { return orderedIDs[i] < orderedIDs[j] })
// Loop through results, collecting which labels we need to insert/update
vals := []interface{}{}
bindvars := []string{}
for _, policyID := range orderedIDs {
matches := results[policyID]
bindvars = append(bindvars, "(?,?,?,?)")
vals = append(vals, updated, policyID, host.ID, matches)
if len(results) > 0 {
// Sort the results to have generated SQL queries ordered to minimize
// deadlocks. See https://github.com/fleetdm/fleet/issues/1146.
orderedIDs := make([]uint, 0, len(results))
for policyID := range results {
orderedIDs = append(orderedIDs, policyID)
}
sort.Slice(orderedIDs, func(i, j int) bool { return orderedIDs[i] < orderedIDs[j] })
// Loop through results, collecting which labels we need to insert/update
for _, policyID := range orderedIDs {
matches := results[policyID]
bindvars = append(bindvars, "(?,?,?,?)")
vals = append(vals, updated, policyID, host.ID, matches)
}
}
// NOTE: the insert of policy membership that follows must be kept in sync
@ -238,16 +240,17 @@ func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *flee
// semantically equivalent, even though here it processes a single host and
// in async mode it processes a batch of hosts).
query := fmt.Sprintf(
`INSERT INTO policy_membership (updated_at, policy_id, host_id, passes)
VALUES %s ON DUPLICATE KEY UPDATE updated_at=VALUES(updated_at), passes=VALUES(passes)`,
strings.Join(bindvars, ","),
)
err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(ctx, query, vals...)
if err != nil {
return ctxerr.Wrapf(ctx, err, "insert policy_membership (%v)", vals)
if len(results) > 0 {
query := fmt.Sprintf(
`INSERT INTO policy_membership (updated_at, policy_id, host_id, passes)
VALUES %s ON DUPLICATE KEY UPDATE updated_at=VALUES(updated_at), passes=VALUES(passes)`,
strings.Join(bindvars, ","),
)
_, err := tx.ExecContext(ctx, query, vals...)
if err != nil {
return ctxerr.Wrapf(ctx, err, "insert policy_membership (%v)", vals)
}
}
// if we are deferring host updates, we return at this point and do the change outside of the tx
@ -255,8 +258,7 @@ func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *flee
return nil
}
_, err = tx.ExecContext(ctx, `UPDATE hosts SET policy_updated_at = ? WHERE id=?`, updated, host.ID)
if err != nil {
if _, err := tx.ExecContext(ctx, `UPDATE hosts SET policy_updated_at = ? WHERE id=?`, updated, host.ID); err != nil {
return ctxerr.Wrap(ctx, err, "updating hosts policy updated at")
}
return nil

View file

@ -667,6 +667,7 @@ type Datastore interface {
FlippingPoliciesForHost(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing []uint, newPassing []uint, err error)
// RecordPolicyQueryExecutions records the execution results of the policies for the given host.
// Even if `results` is empty, the host's `policy_updated_at` will be updated.
RecordPolicyQueryExecutions(ctx context.Context, host *Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool) error
// RecordLabelQueryExecutions saves the results of label queries. The results map is a map of label id -> whether or

View file

@ -22,10 +22,8 @@ const (
policyPassKeysMinTTL = 7 * 24 * time.Hour // 1 week
)
var (
// redis list will be LTRIM'd if there are more policy IDs than this.
maxRedisPolicyResultsPerHost = 1000
)
// redis list will be LTRIM'd if there are more policy IDs than this.
var maxRedisPolicyResultsPerHost = 1000
func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool) error {
cfg := t.taskConfigs[config.AsyncTaskPolicyMembership]
@ -51,36 +49,57 @@ func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host
ttl = maxTTL
}
// KEYS[1]: keyList (policyPassHostKey)
// KEYS[2]: keyTs (policyPassReportedKey)
// ARGV[1]: timestamp for "reported at"
// ARGV[2]: max policy results to keep per host (list is trimmed to that size)
// ARGV[3]: ttl for both keys
// ARGV[4..]: policy_id=pass entries to LPUSH to the list
script := redigo.NewScript(2, `
redis.call('LPUSH', KEYS[1], unpack(ARGV, 4))
redis.call('LTRIM', KEYS[1], 0, ARGV[2])
redis.call('EXPIRE', KEYS[1], ARGV[3])
redis.call('SET', KEYS[2], ARGV[1])
return redis.call('EXPIRE', KEYS[2], ARGV[3])
`)
// There are two versions of the script (1) and (2).
// Script (1) is used when the are no policy results to report and
// script (2) is used when there are policy results.
// convert results to LPUSH arguments, store as policy_id=1 for pass,
// policy_id=-1 for fail, policy_id=0 for null result.
args := make(redigo.Args, 0, 5+len(results))
args = args.Add(keyList, keyTs, ts.Unix(), maxRedisPolicyResultsPerHost, int(ttl.Seconds()))
for k, v := range results {
pass := 0
if v != nil {
if *v {
pass = 1
} else {
pass = -1
// (1)
// KEYS[1]: keyTs (policyPassReportedKey)
// ARGV[1]: timestamp for "reported at"
// ARGV[2]: ttl for the key
scriptSrc := `
redis.call('SET', KEYS[1], ARGV[1])
return redis.call('EXPIRE', KEYS[1], ARGV[2])
`
keyCount := 1
args := make(redigo.Args, 0, 3)
args = args.Add(keyTs, ts.Unix(), int(ttl.Seconds()))
if len(results) > 0 {
// (2)
// KEYS[1]: keyList (policyPassHostKey)
// KEYS[2]: keyTs (policyPassReportedKey)
// ARGV[1]: timestamp for "reported at"
// ARGV[2]: max policy results to keep per host (list is trimmed to that size)
// ARGV[3]: ttl for both keys
// ARGV[4..]: policy_id=pass entries to LPUSH to the list
keyCount = 2
scriptSrc = `
redis.call('LPUSH', KEYS[1], unpack(ARGV, 4))
redis.call('LTRIM', KEYS[1], 0, ARGV[2])
redis.call('EXPIRE', KEYS[1], ARGV[3])
redis.call('SET', KEYS[2], ARGV[1])
return redis.call('EXPIRE', KEYS[2], ARGV[3])
`
// convert results to LPUSH arguments, store as policy_id=1 for pass,
// policy_id=-1 for fail, policy_id=0 for null result.
args = make(redigo.Args, 0, 5+len(results))
args = args.Add(keyList, keyTs, ts.Unix(), maxRedisPolicyResultsPerHost, int(ttl.Seconds()))
for k, v := range results {
pass := 0
if v != nil {
if *v {
pass = 1
} else {
pass = -1
}
}
args = args.Add(fmt.Sprintf("%d=%d", k, pass))
}
args = args.Add(fmt.Sprintf("%d=%d", k, pass))
}
script := redigo.NewScript(keyCount, scriptSrc)
conn := t.pool.Get()
defer conn.Close()
if err := redis.BindConn(t.pool, conn, keyList, keyTs); err != nil {

View file

@ -410,6 +410,115 @@ func testRecordPolicyQueryExecutionsAsync(t *testing.T, ds *mock.Store, pool fle
require.Equal(t, 0, count)
}
func testRecordPolicyQueryExecutionsNoPoliciesSync(t *testing.T, ds *mock.Store, pool fleet.RedisPool) {
ctx := context.Background()
now := time.Now()
lastYear := now.Add(-365 * 24 * time.Hour)
host := &fleet.Host{
ID: 1,
Platform: "linux",
PolicyUpdatedAt: lastYear,
}
var emptyResults map[uint]*bool
keyList, keyTs := fmt.Sprintf(policyPassHostKey, host.ID), fmt.Sprintf(policyPassReportedKey, host.ID)
task := NewTask(ds, pool, clock.C, config.OsqueryConfig{})
policyReportedAt := task.GetHostPolicyReportedAt(ctx, host)
require.True(t, policyReportedAt.Equal(lastYear))
err := task.RecordPolicyQueryExecutions(ctx, host, emptyResults, now, false)
require.NoError(t, err)
require.True(t, ds.RecordPolicyQueryExecutionsFuncInvoked)
ds.RecordPolicyQueryExecutionsFuncInvoked = false
conn := redis.ConfigureDoer(pool, pool.Get())
defer conn.Close()
n, err := redigo.Int(conn.Do("EXISTS", keyList))
require.NoError(t, err)
require.Equal(t, 0, n)
n, err = redigo.Int(conn.Do("EXISTS", keyTs))
require.NoError(t, err)
require.Equal(t, 0, n)
n, err = redigo.Int(conn.Do("ZCARD", policyPassHostIDsKey))
require.NoError(t, err)
require.Equal(t, 0, n)
policyReportedAt = task.GetHostPolicyReportedAt(ctx, host)
require.True(t, policyReportedAt.Equal(now))
}
func testRecordPolicyQueryExecutionsNoPoliciesAsync(t *testing.T, ds *mock.Store, pool fleet.RedisPool) {
ctx := context.Background()
now := time.Now()
lastYear := now.Add(-365 * 24 * time.Hour)
host := &fleet.Host{
ID: 1,
Platform: "linux",
PolicyUpdatedAt: lastYear,
}
var emptyResults map[uint]*bool
keyList, keyTs := fmt.Sprintf(policyPassHostKey, host.ID), fmt.Sprintf(policyPassReportedKey, host.ID)
task := NewTask(ds, pool, clock.C, config.OsqueryConfig{
EnableAsyncHostProcessing: "true",
AsyncHostInsertBatch: 3,
AsyncHostUpdateBatch: 3,
AsyncHostDeleteBatch: 3,
AsyncHostRedisPopCount: 3,
AsyncHostRedisScanKeysCount: 10,
})
policyReportedAt := task.GetHostPolicyReportedAt(ctx, host)
require.True(t, policyReportedAt.Equal(lastYear))
err := task.RecordPolicyQueryExecutions(ctx, host, emptyResults, now, false)
require.NoError(t, err)
require.False(t, ds.RecordPolicyQueryExecutionsFuncInvoked)
conn := redis.ConfigureDoer(pool, pool.Get())
defer conn.Close()
defer conn.Do("DEL", keyTs) //nolint:errcheck
n, err := redigo.Int(conn.Do("EXISTS", keyList))
require.NoError(t, err)
require.Equal(t, 0, n)
ts, err := redigo.Int64(conn.Do("GET", keyTs))
require.NoError(t, err)
require.Equal(t, now.Unix(), ts)
count, err := redigo.Int(conn.Do("ZCARD", policyPassHostIDsKey))
require.NoError(t, err)
require.Equal(t, 1, count)
tsActive, err := redigo.Int64(conn.Do("ZSCORE", policyPassHostIDsKey, host.ID))
require.NoError(t, err)
require.Equal(t, tsActive, ts)
policyReportedAt = task.GetHostPolicyReportedAt(ctx, host)
// because we transition via unix epoch (seconds), not exactly equal
require.WithinDuration(t, now, policyReportedAt, time.Second)
// host's PolicyUpdatedAt field hasn't been updated yet, because the policy
// results are in redis, not in mysql yet.
require.True(t, host.PolicyUpdatedAt.Equal(lastYear))
// running the collector removes the host from the active set
var stats collectorExecStats
err = task.collectPolicyQueryExecutions(ctx, ds, pool, &stats)
require.NoError(t, err)
require.Equal(t, 1, stats.Keys)
require.Equal(t, 0, stats.Items)
require.False(t, stats.Failed)
count, err = redigo.Int(conn.Do("ZCARD", policyPassHostIDsKey))
require.NoError(t, err)
require.Equal(t, 0, count)
}
func createPolicies(t *testing.T, ds *mysql.Datastore, count int) []uint {
ctx := context.Background()

View file

@ -126,12 +126,16 @@ func TestRecord(t *testing.T) {
pool := redistest.SetupRedis(t, "policy_pass", false, false, false)
t.Run("sync", func(t *testing.T) { testRecordPolicyQueryExecutionsSync(t, ds, pool) })
t.Run("async", func(t *testing.T) { testRecordPolicyQueryExecutionsAsync(t, ds, pool) })
t.Run("sync", func(t *testing.T) { testRecordPolicyQueryExecutionsNoPoliciesSync(t, ds, pool) })
t.Run("async", func(t *testing.T) { testRecordPolicyQueryExecutionsNoPoliciesAsync(t, ds, pool) })
})
t.Run("cluster", func(t *testing.T) {
pool := redistest.SetupRedis(t, "policy_pass", true, true, false)
t.Run("sync", func(t *testing.T) { testRecordPolicyQueryExecutionsSync(t, ds, pool) })
t.Run("async", func(t *testing.T) { testRecordPolicyQueryExecutionsAsync(t, ds, pool) })
t.Run("sync", func(t *testing.T) { testRecordPolicyQueryExecutionsNoPoliciesSync(t, ds, pool) })
t.Run("async", func(t *testing.T) { testRecordPolicyQueryExecutionsNoPoliciesAsync(t, ds, pool) })
})
})

View file

@ -614,13 +614,18 @@ func (svc *Service) GetDistributedQueries(ctx context.Context) (queries map[stri
}
}
policyQueries, err := svc.policyQueriesForHost(ctx, host)
policyQueries, noPolicies, err := svc.policyQueriesForHost(ctx, host)
if err != nil {
return nil, nil, 0, newOsqueryError(err.Error())
}
for name, query := range policyQueries {
queries[hostPolicyQueryPrefix+name] = query
}
if noPolicies {
// This is only set when it's time to re-run policies on the host,
// but the host doesn't have any policies assigned.
queries[hostNoPoliciesWildcard] = alwaysTrueQuery
}
accelerate = uint(0)
if host.Hostname == "" || host.Platform == "" {
@ -744,16 +749,22 @@ func (svc *Service) labelQueriesForHost(ctx context.Context, host *fleet.Host) (
return labelQueries, nil
}
func (svc *Service) policyQueriesForHost(ctx context.Context, host *fleet.Host) (map[string]string, error) {
// policyQueriesForHost returns policy queries if it's the time to re-run policies on the given host.
// It returns (nil, true, nil) if the interval is so that policies should be executed on the host, but there are no policies
// assigned to such host.
func (svc *Service) policyQueriesForHost(ctx context.Context, host *fleet.Host) (policyQueries map[string]string, noPoliciesForHost bool, err error) {
policyReportedAt := svc.task.GetHostPolicyReportedAt(ctx, host)
if !svc.shouldUpdate(policyReportedAt, svc.config.Osquery.PolicyUpdateInterval, host.ID) && !host.RefetchRequested {
return nil, nil
return nil, false, nil
}
policyQueries, err := svc.ds.PolicyQueriesForHost(ctx, host)
policyQueries, err = svc.ds.PolicyQueriesForHost(ctx, host)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "retrieve policy queries")
return nil, false, ctxerr.Wrap(ctx, err, "retrieve policy queries")
}
return policyQueries, nil
if len(policyQueries) == 0 {
return nil, true, nil
}
return policyQueries, false, nil
}
////////////////////////////////////////////////////////////////////////////////
@ -867,6 +878,15 @@ const (
// osqueryd writes the distributed query results.
hostPolicyQueryPrefix = "fleet_policy_query_"
// hostNoPoliciesWildcard is a query sent to hosts when it's time to run policy
// queries on a host, but such host does not have any policies assigned.
// When Fleet receives results from such query then it will update the host's
// policy_updated_at column.
//
// This is used to prevent hosts without policies assigned to continuously
// perform lookups in the policies table on every check in.
hostNoPoliciesWildcard = "fleet_no_policies_wildcard"
// hostDistributedQueryPrefix is appended before the query name when a query is
// run from a distributed query campaign
hostDistributedQueryPrefix = "fleet_distributed_query_"
@ -895,7 +915,15 @@ func (svc *Service) SubmitDistributedQueryResults(
svc.maybeDebugHost(ctx, host, results, statuses, messages)
var hostWithoutPolicies bool
for query, rows := range results {
// When receiving this query in the results, we will update the host's
// policy_updated_at column.
if query == hostNoPoliciesWildcard {
hostWithoutPolicies = true
continue
}
// osquery docs say any nonzero (string) value for status indicates a query error
status, ok := statuses[query]
failed := ok && status != fleet.StatusOK
@ -973,6 +1001,13 @@ func (svc *Service) SubmitDistributedQueryResults(
if err := svc.task.RecordPolicyQueryExecutions(ctx, host, policyResults, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost); err != nil {
logging.WithErr(ctx, err)
}
} else {
if hostWithoutPolicies {
// RecordPolicyQueryExecutions called with results=nil will still update the host's policy_updated_at column.
if err := svc.task.RecordPolicyQueryExecutions(ctx, host, nil, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost); err != nil {
logging.WithErr(ctx, err)
}
}
}
if additionalUpdated {

View file

@ -847,7 +847,8 @@ func TestLabelQueries(t *testing.T) {
// should be turned on so that we can quickly fill labels)
queries, discovery, acc, err := svc.GetDistributedQueries(ctx)
require.NoError(t, err)
require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform)), len(queries), distQueriesMapKeys(queries))
// +1 for the fleet_no_policies_wildcard query.
require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform))+1, len(queries), distQueriesMapKeys(queries))
verifyDiscovery(t, queries, discovery)
assert.NotZero(t, acc)
@ -858,7 +859,7 @@ func TestLabelQueries(t *testing.T) {
queries, discovery, acc, err = svc.GetDistributedQueries(ctx)
require.NoError(t, err)
require.Empty(t, queries)
require.Len(t, queries, 1) // fleet_no_policies_wildcard query
verifyDiscovery(t, queries, discovery)
assert.Zero(t, acc)
@ -873,7 +874,8 @@ func TestLabelQueries(t *testing.T) {
// Now we should get the label queries
queries, discovery, acc, err = svc.GetDistributedQueries(ctx)
require.NoError(t, err)
require.Len(t, queries, 3)
// +1 for the fleet_no_policies_wildcard query.
require.Len(t, queries, 3+1)
verifyDiscovery(t, queries, discovery)
assert.Zero(t, acc)
@ -928,7 +930,7 @@ func TestLabelQueries(t *testing.T) {
ctx = hostctx.NewContext(ctx, host)
queries, discovery, acc, err = svc.GetDistributedQueries(ctx)
require.NoError(t, err)
require.Empty(t, queries)
require.Len(t, queries, 1) // fleet_no_policies_wildcard query
verifyDiscovery(t, queries, discovery)
assert.Zero(t, acc)
@ -937,8 +939,8 @@ func TestLabelQueries(t *testing.T) {
ctx = hostctx.NewContext(ctx, host)
queries, discovery, acc, err = svc.GetDistributedQueries(ctx)
require.NoError(t, err)
// +3 for label queries
require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform))+3, len(queries), distQueriesMapKeys(queries))
// +3 for label queries, +1 for the fleet_no_policies_wildcard query.
require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform))+3+1, len(queries), distQueriesMapKeys(queries))
verifyDiscovery(t, queries, discovery)
assert.Zero(t, acc)
@ -967,7 +969,7 @@ func TestLabelQueries(t *testing.T) {
ctx = hostctx.NewContext(ctx, host)
queries, discovery, acc, err = svc.GetDistributedQueries(ctx)
require.NoError(t, err)
require.Empty(t, queries)
require.Len(t, queries, 1) // fleet_no_policies_wildcard query
verifyDiscovery(t, queries, discovery)
assert.Zero(t, acc)
}
@ -1006,8 +1008,8 @@ func TestDetailQueriesWithEmptyStrings(t *testing.T) {
// queries)
queries, discovery, acc, err := svc.GetDistributedQueries(ctx)
require.NoError(t, err)
// +1 due to 'windows_update_history'
if expected := expectedDetailQueriesForPlatform(host.Platform); !assert.Equal(t, len(expected)+1, len(queries)) {
// +1 due to 'windows_update_history', +1 due to fleet_no_policies_wildcard query.
if expected := expectedDetailQueriesForPlatform(host.Platform); !assert.Equal(t, len(expected)+1+1, len(queries)) {
// this is just to print the diff between the expected and actual query
// keys when the count assertion fails, to help debugging - they are not
// expected to match.
@ -1141,7 +1143,7 @@ func TestDetailQueriesWithEmptyStrings(t *testing.T) {
ctx = hostctx.NewContext(ctx, host)
queries, discovery, acc, err = svc.GetDistributedQueries(ctx)
require.NoError(t, err)
require.Empty(t, queries)
require.Len(t, queries, 1) // fleet_no_policies_wildcard query
verifyDiscovery(t, queries, discovery)
assert.Zero(t, acc)
@ -1152,7 +1154,8 @@ func TestDetailQueriesWithEmptyStrings(t *testing.T) {
require.NoError(t, err)
// somehow confusingly, the query response above changed the host's platform
// from windows to darwin
require.Equal(t, len(expectedDetailQueriesForPlatform(gotHost.Platform)), len(queries), distQueriesMapKeys(queries))
// +1 due to fleet_no_policies_wildcard query.
require.Equal(t, len(expectedDetailQueriesForPlatform(gotHost.Platform))+1, len(queries), distQueriesMapKeys(queries))
verifyDiscovery(t, queries, discovery)
assert.Zero(t, acc)
}
@ -1215,8 +1218,8 @@ func TestDetailQueries(t *testing.T) {
// queries)
queries, discovery, acc, err := svc.GetDistributedQueries(ctx)
require.NoError(t, err)
// +1 for software inventory
if expected := expectedDetailQueriesForPlatform(host.Platform); !assert.Equal(t, len(expected)+1, len(queries)) {
// +1 for software inventory, +1 for fleet_no_policies_wildcard
if expected := expectedDetailQueriesForPlatform(host.Platform); !assert.Equal(t, len(expected)+1+1, len(queries)) {
// this is just to print the diff between the expected and actual query
// keys when the count assertion fails, to help debugging - they are not
// expected to match.
@ -1468,7 +1471,7 @@ func TestDetailQueries(t *testing.T) {
ctx = hostctx.NewContext(ctx, host)
queries, discovery, acc, err = svc.GetDistributedQueries(ctx)
require.NoError(t, err)
require.Empty(t, queries)
require.Len(t, queries, 1) // fleet_no_policies_wildcard query
verifyDiscovery(t, queries, discovery)
assert.Zero(t, acc)
@ -1477,8 +1480,8 @@ func TestDetailQueries(t *testing.T) {
queries, discovery, acc, err = svc.GetDistributedQueries(ctx)
require.NoError(t, err)
// +1 software inventory
require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform))+1, len(queries), distQueriesMapKeys(queries))
// +1 software inventory, +1 fleet_no_policies_wildcard query
require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform))+1+1, len(queries), distQueriesMapKeys(queries))
verifyDiscovery(t, queries, discovery)
assert.Zero(t, acc)
}
@ -1677,8 +1680,8 @@ func TestDistributedQueryResults(t *testing.T) {
// Now we should get the active distributed query
queries, discovery, acc, err := svc.GetDistributedQueries(hostCtx)
require.NoError(t, err)
// +1 for the distributed query for campaign ID 42, +1 for windows update history
if expected := expectedDetailQueriesForPlatform(host.Platform); !assert.Equal(t, len(expected)+2, len(queries)) {
// +1 for the distributed query for campaign ID 42, +1 for windows update history, +1 for the fleet_no_policies_wildcard query.
if expected := expectedDetailQueriesForPlatform(host.Platform); !assert.Equal(t, len(expected)+3, len(queries)) {
// this is just to print the diff between the expected and actual query
// keys when the count assertion fails, to help debugging - they are not
// expected to match.
@ -2973,7 +2976,8 @@ func TestLiveQueriesFailing(t *testing.T) {
queries, discovery, _, err := svc.GetDistributedQueries(ctx)
require.NoError(t, err)
require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform)), len(queries), distQueriesMapKeys(queries))
// +1 to account for the fleet_no_policies_wildcard query.
require.Equal(t, len(expectedDetailQueriesForPlatform(host.Platform))+1, len(queries), distQueriesMapKeys(queries))
verifyDiscovery(t, queries, discovery)
logs, err := io.ReadAll(buf)