fleet/server/service/async/async_policy.go
Lucas Manuel Rodriguez 9142c5de79
Prevent thundering herd when applying large number of policies on large number of hosts (#13552)
#13527

(Adding @mna to double check the changes in the async implementation of
policy result storage)

This PR also adds the osquery-perf changes needed to define the count of
macOS and Windows hosts.

- [X] Changes file added for user-visible changes in `changes/` or
`orbit/changes/`.
See [Changes
files](https://fleetdm.com/docs/contributing/committing-changes#changes-files)
for more information.
- ~[ ] Documented any API changes (docs/Using-Fleet/REST-API.md or
docs/Contributing/API-for-contributors.md)~
- ~[ ] Documented any permissions changes (docs/Using
Fleet/manage-access.md)~
- [X] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements)
- [X] Added support on fleet's osquery simulator `cmd/osquery-perf` for
new osquery data ingestion features.
- [X] Added/updated tests
- [X] Manual QA for all new/changed functionality
  - ~For Orbit and Fleet Desktop changes:~
- ~[ ] Manual QA must be performed in the three main OSs, macOS, Windows
and Linux.~
- ~[ ] Auto-update manual QA, from released version of component to new
version (see [tools/tuf/test](../tools/tuf/test/README.md)).~

Test with 80k hosts: 70k simulated macOS, 10k simulated Windows.
Apply Windows policies first, then apply macOS policies:
```
fleetctl apply -f ee/cis/win-10/cis-policy-queries.yml

# Leave running for some time

fleetctl apply -f ee/cis/macos-13/cis-policy-queries.yml
```

After applying CIS policies previous to these changes:
![Screenshot 2023-08-23 at 11 36
18](https://github.com/fleetdm/fleet/assets/2073526/72c1dc7d-e601-4248-be35-93c85b749f5d)

After applying these changes and applying the same policies:
![Screenshot 2023-08-28 at 15 42
57](https://github.com/fleetdm/fleet/assets/2073526/6b6d76b8-6acb-4893-a913-bf603a68f1a4)
2023-08-31 10:58:50 -03:00

262 lines
7.9 KiB
Go

package async
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/datastore/redis"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
redigo "github.com/gomodule/redigo/redis"
)
const (
policyPassHostIDsKey = "policy_pass:active_host_ids"
policyPassHostKey = "policy_pass:{%d}"
policyPassReportedKey = "policy_pass_reported:{%d}"
policyPassKeysMinTTL = 7 * 24 * time.Hour // 1 week
)
// redis list will be LTRIM'd if there are more policy IDs than this.
var maxRedisPolicyResultsPerHost = 1000
func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool) error {
cfg := t.taskConfigs[config.AsyncTaskPolicyMembership]
if !cfg.Enabled {
host.PolicyUpdatedAt = ts
return t.datastore.RecordPolicyQueryExecutions(ctx, host, results, ts, deferred)
}
keyList := fmt.Sprintf(policyPassHostKey, host.ID)
keyTs := fmt.Sprintf(policyPassReportedKey, host.ID)
// set an expiration on both keys (list and ts), ensuring that a deleted host
// (eventually) does not use any redis space. Ensure that TTL is reasonably
// big to avoid deleting information that hasn't been collected yet - 1 week
// or 10 * the collector interval, whichever is biggest.
//
// This means that it will only expire if that host hasn't reported policies
// during that (TTL) time (each time it does report, the TTL is reset), and
// the collector will have plenty of time to run (multiple times) to try to
// persist all the data in mysql.
ttl := policyPassKeysMinTTL
if maxTTL := 10 * cfg.CollectInterval; maxTTL > ttl {
ttl = maxTTL
}
// There are two versions of the script (1) and (2).
// Script (1) is used when the are no policy results to report and
// script (2) is used when there are policy results.
// (1)
// KEYS[1]: keyTs (policyPassReportedKey)
// ARGV[1]: timestamp for "reported at"
// ARGV[2]: ttl for the key
scriptSrc := `
redis.call('SET', KEYS[1], ARGV[1])
return redis.call('EXPIRE', KEYS[1], ARGV[2])
`
keyCount := 1
args := make(redigo.Args, 0, 3)
args = args.Add(keyTs, ts.Unix(), int(ttl.Seconds()))
if len(results) > 0 {
// (2)
// KEYS[1]: keyList (policyPassHostKey)
// KEYS[2]: keyTs (policyPassReportedKey)
// ARGV[1]: timestamp for "reported at"
// ARGV[2]: max policy results to keep per host (list is trimmed to that size)
// ARGV[3]: ttl for both keys
// ARGV[4..]: policy_id=pass entries to LPUSH to the list
keyCount = 2
scriptSrc = `
redis.call('LPUSH', KEYS[1], unpack(ARGV, 4))
redis.call('LTRIM', KEYS[1], 0, ARGV[2])
redis.call('EXPIRE', KEYS[1], ARGV[3])
redis.call('SET', KEYS[2], ARGV[1])
return redis.call('EXPIRE', KEYS[2], ARGV[3])
`
// convert results to LPUSH arguments, store as policy_id=1 for pass,
// policy_id=-1 for fail, policy_id=0 for null result.
args = make(redigo.Args, 0, 5+len(results))
args = args.Add(keyList, keyTs, ts.Unix(), maxRedisPolicyResultsPerHost, int(ttl.Seconds()))
for k, v := range results {
pass := 0
if v != nil {
if *v {
pass = 1
} else {
pass = -1
}
}
args = args.Add(fmt.Sprintf("%d=%d", k, pass))
}
}
script := redigo.NewScript(keyCount, scriptSrc)
conn := t.pool.Get()
defer conn.Close()
if err := redis.BindConn(t.pool, conn, keyList, keyTs); err != nil {
return ctxerr.Wrap(ctx, err, "bind redis connection")
}
if _, err := script.Do(conn, args...); err != nil {
return ctxerr.Wrap(ctx, err, "run redis script")
}
// Storing the host id in the set of active host IDs for policy membership
// outside of the redis script because in Redis Cluster mode the key may not
// live on the same node as the host's keys. At the same time, purge any
// entry in the set that is older than now - TTL.
if _, err := storePurgeActiveHostID(t.pool, policyPassHostIDsKey, host.ID, ts, ts.Add(-ttl)); err != nil {
return ctxerr.Wrap(ctx, err, "store active host id")
}
return nil
}
func (t *Task) collectPolicyQueryExecutions(ctx context.Context, ds fleet.Datastore, pool fleet.RedisPool, stats *collectorExecStats) error {
cfg := t.taskConfigs[config.AsyncTaskPolicyMembership]
hosts, err := loadActiveHostIDs(pool, policyPassHostIDsKey, cfg.RedisScanKeysCount)
if err != nil {
return ctxerr.Wrap(ctx, err, "load active host ids")
}
stats.Keys = len(hosts)
// need to use a script as the RPOP command only supports a COUNT since
// 6.2. Because we use LTRIM when inserting, we know the total number
// of results is at most maxRedisPolicyResultsPerHost, so it is capped
// and can be returned in one go.
script := redigo.NewScript(1, `
local res = redis.call('LRANGE', KEYS[1], 0, -1)
redis.call('DEL', KEYS[1])
return res
`)
getKeyTuples := func(hostID uint) (inserts []fleet.PolicyMembershipResult, err error) {
keyList := fmt.Sprintf(policyPassHostKey, hostID)
conn := redis.ConfigureDoer(pool, pool.Get())
defer conn.Close()
stats.RedisCmds++
res, err := redigo.Strings(script.Do(conn, keyList))
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "redis LRANGE script")
}
inserts = make([]fleet.PolicyMembershipResult, 0, len(res))
stats.Items += len(res)
for _, item := range res {
parts := strings.Split(item, "=")
if len(parts) != 2 {
continue
}
var tup fleet.PolicyMembershipResult
if id, _ := strconv.ParseUint(parts[0], 10, 32); id > 0 {
tup.HostID = hostID
tup.PolicyID = uint(id)
switch parts[1] {
case "1":
tup.Passes = ptr.Bool(true)
case "-1":
tup.Passes = ptr.Bool(false)
case "0":
tup.Passes = nil
default:
continue
}
inserts = append(inserts, tup)
}
}
return inserts, nil
}
runInsertBatch := func(batch []fleet.PolicyMembershipResult) error {
stats.Inserts++
return ds.AsyncBatchInsertPolicyMembership(ctx, batch)
}
runUpdateBatch := func(ids []uint, ts time.Time) error {
stats.Updates++
return ds.AsyncBatchUpdatePolicyTimestamp(ctx, ids, ts)
}
insertBatch := make([]fleet.PolicyMembershipResult, 0, cfg.InsertBatch)
for _, host := range hosts {
hid := host.HostID
ins, err := getKeyTuples(hid)
if err != nil {
return err
}
insertBatch = append(insertBatch, ins...)
if len(insertBatch) >= cfg.InsertBatch {
if err := runInsertBatch(insertBatch); err != nil {
return err
}
insertBatch = insertBatch[:0]
}
}
// process any remaining batch that did not reach the batchSize limit in the
// loop.
if len(insertBatch) > 0 {
if err := runInsertBatch(insertBatch); err != nil {
return err
}
}
if len(hosts) > 0 {
hostIDs := make([]uint, len(hosts))
for i, host := range hosts {
hostIDs[i] = host.HostID
}
ts := t.clock.Now()
updateBatch := make([]uint, cfg.UpdateBatch)
for {
n := copy(updateBatch, hostIDs)
if n == 0 {
break
}
if err := runUpdateBatch(updateBatch[:n], ts); err != nil {
return err
}
hostIDs = hostIDs[n:]
}
// batch-remove any host ID from the active set that still has its score to
// the initial value, so that the active set does not keep all (potentially
// 100K+) host IDs to process at all times - only those with reported
// results to process.
if _, err := removeProcessedHostIDs(pool, policyPassHostIDsKey, hosts); err != nil {
return ctxerr.Wrap(ctx, err, "remove processed host ids")
}
}
return nil
}
func (t *Task) GetHostPolicyReportedAt(ctx context.Context, host *fleet.Host) time.Time {
cfg := t.taskConfigs[config.AsyncTaskPolicyMembership]
if cfg.Enabled {
conn := redis.ConfigureDoer(t.pool, t.pool.Get())
defer conn.Close()
key := fmt.Sprintf(policyPassReportedKey, host.ID)
epoch, err := redigo.Int64(conn.Do("GET", key))
if err == nil {
if reported := time.Unix(epoch, 0); reported.After(host.PolicyUpdatedAt) {
return reported
}
}
}
return host.PolicyUpdatedAt
}