Refactor async host processing to avoid redis SCAN keys (for labels only) (#3639)

This commit is contained in:
Martin Angers 2022-01-17 14:53:59 -05:00 committed by GitHub
parent a3553c4cc7
commit 1f185a7a8b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 230 additions and 75 deletions

View file

@ -0,0 +1 @@
* Refactor async host processing (`--enable_async_host_processing`) to avoid relying on the slow Redis `SCAN` command to enumerate keys.

View file

@ -278,6 +278,7 @@ the way that the Fleet server works.
UpdateBatch: config.Osquery.AsyncHostUpdateBatch,
RedisPopCount: config.Osquery.AsyncHostRedisPopCount,
RedisScanKeysCount: config.Osquery.AsyncHostRedisScanKeysCount,
CollectorInterval: config.Osquery.AsyncHostCollectInterval,
}
svc, err := service.NewService(ds, task, resultStore, logger, osqueryLogger, config, mailService, clock.C, ssoSessionStore, liveQueryStore, carveStore, *license, failingPolicySet)
if err != nil {
@ -534,8 +535,7 @@ func runCrons(ds fleet.Datastore, task *async.Task, logger kitlog.Logger, config
}
// StartCollectors starts a goroutine per collector, using ctx to cancel.
task.StartCollectors(ctx, config.Osquery.AsyncHostCollectInterval,
config.Osquery.AsyncHostCollectMaxJitterPercent, kitlog.With(logger, "cron", "async_task"))
task.StartCollectors(ctx, config.Osquery.AsyncHostCollectMaxJitterPercent, kitlog.With(logger, "cron", "async_task"))
go cronCleanups(ctx, ds, kitlog.With(logger, "cron", "cleanups"), ourIdentifier, license)
go cronVulnerabilities(

View file

@ -3,9 +3,6 @@ package async
import (
"context"
"fmt"
"math"
"regexp"
"strconv"
"time"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
@ -17,10 +14,11 @@ import (
)
const (
labelMembershipHostKeyPattern = "label_membership:{*}"
labelMembershipHostKey = "label_membership:{%d}"
labelMembershipReportedKey = "label_membership_reported:{%d}"
collectorLockKey = "locks:async_collector:{%s}"
labelMembershipActiveHostIDsKey = "label_membership:active_host_ids"
labelMembershipHostKey = "label_membership:{%d}"
labelMembershipReportedKey = "label_membership_reported:{%d}"
labelMembershipKeysMinTTL = 7 * 24 * time.Hour // 1 week
collectorLockKey = "locks:async_collector:{%s}"
)
type Task struct {
@ -37,23 +35,24 @@ type Task struct {
UpdateBatch int
RedisPopCount int
RedisScanKeysCount int
CollectorInterval time.Duration
}
// StartCollectors runs the various collectors as distinct background
// goroutines if async processing is enabled. Each collector will stop
// processing when ctx is done.
func (t *Task) StartCollectors(ctx context.Context, interval time.Duration, jitterPct int, logger kitlog.Logger) {
func (t *Task) StartCollectors(ctx context.Context, jitterPct int, logger kitlog.Logger) {
if !t.AsyncEnabled {
level.Debug(logger).Log("task", "async disabled, not starting collectors")
return
}
level.Debug(logger).Log("task", "async enabled, starting collectors", "interval", interval, "jitter", jitterPct)
level.Debug(logger).Log("task", "async enabled, starting collectors", "interval", t.CollectorInterval, "jitter", jitterPct)
labelColl := &collector{
name: "collect_labels",
pool: t.Pool,
ds: t.Datastore,
execInterval: interval,
execInterval: t.CollectorInterval,
jitterPct: jitterPct,
lockTimeout: t.LockTimeout,
handler: t.collectLabelQueryExecutions,
@ -89,14 +88,36 @@ func (t *Task) RecordLabelQueryExecutions(ctx context.Context, host *fleet.Host,
keySet := fmt.Sprintf(labelMembershipHostKey, host.ID)
keyTs := fmt.Sprintf(labelMembershipReportedKey, host.ID)
// set an expiration on both keys (set and ts), ensuring that a deleted host
// (eventually) does not use any redis space. Ensure that TTL is reasonably
// big to avoid deleting information that hasn't been collected yet - 1 week
// or 10 * the collector interval, whichever is biggest.
//
// This means that it will only expire if that host hasn't reported labels
// during that (TTL) time (each time it does report, the TTL is reset), and
// the collector will have plenty of time to run (multiple times) to try to
// persist all the data in mysql.
ttl := labelMembershipKeysMinTTL
if maxTTL := 10 * t.CollectorInterval; maxTTL > ttl {
ttl = maxTTL
}
// keys and arguments passed to the script are:
// KEYS[1]: keySet (labelMembershipHostKey)
// KEYS[2]: keyTs (labelMembershipReportedKey)
// ARGV[1]: timestamp for "reported at"
// ARGV[2]: ttl for both keys
// ARGV[3..]: the arguments to ZADD to keySet
script := redigo.NewScript(2, `
redis.call('ZADD', KEYS[1], unpack(ARGV, 2))
return redis.call('SET', KEYS[2], ARGV[1])
redis.call('ZADD', KEYS[1], unpack(ARGV, 3))
redis.call('EXPIRE', KEYS[1], ARGV[2])
redis.call('SET', KEYS[2], ARGV[1])
return redis.call('EXPIRE', KEYS[2], ARGV[2])
`)
// convert results to ZADD arguments, store as -1 for delete, +1 for insert
args := make(redigo.Args, 0, 3+(len(results)*2))
args = args.Add(keySet, keyTs, ts.Unix())
args := make(redigo.Args, 0, 4+(len(results)*2))
args = args.Add(keySet, keyTs, ts.Unix(), int(ttl.Seconds()))
for k, v := range results {
score := -1
if v != nil && *v {
@ -114,40 +135,35 @@ func (t *Task) RecordLabelQueryExecutions(ctx context.Context, host *fleet.Host,
if _, err := script.Do(conn, args...); err != nil {
return ctxerr.Wrap(ctx, err, "run redis script")
}
// Storing the host id in the set of active host IDs for label membership
// outside of the redis script because in Redis Cluster mode the key may not
// live on the same node as the host's keys. At the same time, purge any
// entry in the set that is older than now - TTL.
if err := storePurgeActiveHostID(t.Pool, host.ID, ts, ts.Add(-ttl)); err != nil {
return ctxerr.Wrap(ctx, err, "store active host id")
}
return nil
}
var reHostFromKey = regexp.MustCompile(`\{(\d+)\}$`)
func (t *Task) collectLabelQueryExecutions(ctx context.Context, ds fleet.Datastore, pool fleet.RedisPool, stats *collectorExecStats) error {
keys, err := redis.ScanKeys(pool, labelMembershipHostKeyPattern, t.RedisScanKeysCount)
hosts, err := loadActiveHostIDs(pool, t.RedisScanKeysCount)
if err != nil {
return err
return ctxerr.Wrap(ctx, err, "load active host ids")
}
stats.Keys = len(keys)
getKeyTuples := func(key string) (hostID uint, inserts, deletes [][2]uint, err error) {
if matches := reHostFromKey.FindStringSubmatch(key); matches != nil {
id, err := strconv.ParseInt(matches[1], 10, 64)
if err == nil && id > 0 && id <= math.MaxUint32 { // required for CodeQL vulnerability scanning in CI
hostID = uint(id)
}
}
// just ignore if there is no valid host id
if hostID == 0 {
return hostID, nil, nil, nil
}
stats.Keys = len(hosts)
getKeyTuples := func(hostID uint) (inserts, deletes [][2]uint, err error) {
keySet := fmt.Sprintf(labelMembershipHostKey, hostID)
conn := redis.ConfigureDoer(pool, pool.Get())
defer conn.Close()
for {
stats.RedisCmds++
vals, err := redigo.Ints(conn.Do("ZPOPMIN", key, t.RedisPopCount))
vals, err := redigo.Ints(conn.Do("ZPOPMIN", keySet, t.RedisPopCount))
if err != nil {
return hostID, nil, nil, ctxerr.Wrap(ctx, err, "redis ZPOPMIN")
return nil, nil, ctxerr.Wrap(ctx, err, "redis ZPOPMIN")
}
items := len(vals) / 2 // each item has the label id and the score (-1=delete, +1=insert)
stats.Items += items
@ -168,7 +184,7 @@ func (t *Task) collectLabelQueryExecutions(ctx context.Context, ds fleet.Datasto
}
}
if items < t.RedisPopCount {
return hostID, inserts, deletes, nil
return inserts, deletes, nil
}
}
}
@ -204,9 +220,9 @@ func (t *Task) collectLabelQueryExecutions(ctx context.Context, ds fleet.Datasto
insertBatch := make([][2]uint, 0, t.InsertBatch)
deleteBatch := make([][2]uint, 0, t.DeleteBatch)
hostIDs := make([]uint, 0, len(keys))
for _, key := range keys {
hostID, ins, del, err := getKeyTuples(key)
for _, host := range hosts {
hid := host.HostID
ins, del, err := getKeyTuples(hid)
if err != nil {
return err
}
@ -225,9 +241,6 @@ func (t *Task) collectLabelQueryExecutions(ctx context.Context, ds fleet.Datasto
}
deleteBatch = deleteBatch[:0]
}
if hostID > 0 {
hostIDs = append(hostIDs, hostID)
}
}
// process any remaining batch that did not reach the batchSize limit in the
@ -242,7 +255,12 @@ func (t *Task) collectLabelQueryExecutions(ctx context.Context, ds fleet.Datasto
return err
}
}
if len(hostIDs) > 0 {
if len(hosts) > 0 {
hostIDs := make([]uint, len(hosts))
for i, host := range hosts {
hostIDs[i] = host.HostID
}
ts := time.Now()
updateBatch := make([]uint, t.UpdateBatch)
for {
@ -255,6 +273,14 @@ func (t *Task) collectLabelQueryExecutions(ctx context.Context, ds fleet.Datasto
}
hostIDs = hostIDs[n:]
}
// batch-remove from the active set any host ID whose score is still the
// value that was read when the active host IDs were loaded, so that the
// active set does not keep all (potentially 100K+) host IDs to process at
// all times - only those with reported results to process.
if err := removeProcessedHostIDs(pool, hosts); err != nil {
return ctxerr.Wrap(ctx, err, "remove processed host ids")
}
}
return nil
@ -275,3 +301,112 @@ func (t *Task) GetHostLabelReportedAt(ctx context.Context, host *fleet.Host) tim
}
return host.LabelUpdatedAt
}
// storePurgeActiveHostID adds host hid to the sorted set stored at
// labelMembershipActiveHostIDsKey, using the unix timestamp of reportedAt as
// its score. In the same redis script, it removes every member of that set
// whose score is lower than or equal to the unix timestamp of purgeOlder.
//
// It returns an error if the connection cannot be bound to the key or if the
// script fails to run.
func storePurgeActiveHostID(pool fleet.RedisPool, hid uint, reportedAt, purgeOlder time.Time) error {
	const activeKey = labelMembershipActiveHostIDsKey

	// Get the connection and bind it to the set's key before preparing the
	// script that will run on it.
	c := pool.Get()
	defer c.Close()
	if bindErr := redis.BindConn(pool, c, activeKey); bindErr != nil {
		return fmt.Errorf("bind redis connection: %w", bindErr)
	}

	// Script inputs:
	//   KEYS[1]: labelMembershipActiveHostIDsKey
	//   ARGV[1]: the host ID to add
	//   ARGV[2]: the added host's reported-at timestamp (used as score)
	//   ARGV[3]: upper bound (inclusive) of the scores to purge
	addThenPurge := redigo.NewScript(1, `
redis.call('ZADD', KEYS[1], ARGV[2], ARGV[1])
return redis.call('ZREMRANGEBYSCORE', KEYS[1], '-inf', ARGV[3])
`)

	_, runErr := addThenPurge.Do(c, activeKey, hid, reportedAt.Unix(), purgeOlder.Unix())
	if runErr != nil {
		return fmt.Errorf("run redis script: %w", runErr)
	}
	return nil
}
// removeProcessedHostIDs removes from the sorted set stored at
// labelMembershipActiveHostIDsKey every host in batch whose current score is
// still equal to the LastReported value recorded in batch. Hosts whose score
// changed in the meantime are left in the set.
//
// An empty batch is a no-op: no connection is requested and no script is run.
// It returns an error if the connection cannot be bound to the key or if the
// script fails to run.
func removeProcessedHostIDs(pool fleet.RedisPool, batch []hostIDLastReported) error {
	if len(batch) == 0 {
		// nothing to remove, avoid a useless round-trip to redis
		return nil
	}

	// This script removes from the set of active hosts for label membership all
	// those that still have the same score as when the batch was read (via
	// loadActiveHostIDs). This is so that any host that would've reported new
	// data since the call to loadActiveHostIDs would *not* get deleted (as the
	// score would change if that was the case).
	//
	// Note that this approach is correct - in that it is safe and won't delete
	// any host that has unsaved reported data - but it is potentially slow, as
	// it needs to check the score of each member before deleting it. Should that
	// become too slow, we have some options:
	//
	// * split the batch in smaller, capped ones (that would be if the redis
	//   server gets blocked for too long processing a single batch)
	// * use ZREMRANGEBYSCORE to remove in one command all members with a score
	//   (reported-at timestamp) lower than the maximum timestamp in batch.
	//   While this would be almost certainly faster, it might be incorrect as
	//   new data could be reported with timestamps older than the maximum one,
	//   e.g. if the clocks are not exactly in sync between fleet instances, or
	//   if hosts report new data while the ZSCAN is going on and don't get picked
	//   up by the SCAN (this is possible, as part of the guarantees of SCAN).

	// KEYS[1]: labelMembershipActiveHostIDsKey
	// ARGV...: the list of host ID-last reported timestamp pairs
	script := redigo.NewScript(1, `
local count = 0
for i = 1, #ARGV, 2 do
local member, ts = ARGV[i], ARGV[i+1]
if redis.call('ZSCORE', KEYS[1], member) == ts then
count = count + 1
redis.call('ZREM', KEYS[1], member)
end
end
return count
`)

	conn := pool.Get()
	defer conn.Close()
	if err := redis.BindConn(pool, conn, labelMembershipActiveHostIDsKey); err != nil {
		return fmt.Errorf("bind redis connection: %w", err)
	}

	// the final size is known (the key + 2 values per host), allocate it once
	// instead of growing the slice repeatedly for large batches.
	args := make(redigo.Args, 0, 1+2*len(batch))
	args = args.Add(labelMembershipActiveHostIDsKey)
	for _, host := range batch {
		args = args.Add(host.HostID, host.LastReported)
	}
	if _, err := script.Do(conn, args...); err != nil {
		return fmt.Errorf("run redis script: %w", err)
	}
	return nil
}
// hostIDLastReported pairs a host ID with the score stored for that host in
// the sorted set of active host IDs (labelMembershipActiveHostIDsKey).
type hostIDLastReported struct {
	// HostID is the member read from the sorted set.
	HostID       uint
	// LastReported is the score read along with the member.
	LastReported int64 // timestamp in unix epoch
}
// loadActiveHostIDs reads all members and scores of the sorted set stored at
// labelMembershipActiveHostIDsKey and returns them as host ID/last-reported
// pairs, in the order they were first returned by redis.
//
// The set is iterated with ZSCAN, passing scanCount as the COUNT hint for
// each call. Each host ID is returned at most once: if the iteration yields
// the same member more than once, only the first occurrence is kept.
//
// It returns an error if a ZSCAN call fails or if its reply cannot be
// converted to a cursor followed by member/score pairs.
func loadActiveHostIDs(pool fleet.RedisPool, scanCount int) ([]hostIDLastReported, error) {
	conn := redis.ConfigureDoer(pool, pool.Get())
	defer conn.Close()

	// using ZSCAN instead of fetching in one shot, as there may be 100K+ hosts
	// and we don't want to block the redis server too long.
	var hosts []hostIDLastReported

	// ZSCAN may return the same member multiple times over a full iteration,
	// so keep track of the host IDs already collected.
	seen := make(map[uint]struct{})

	cursor := 0
	for {
		res, err := redigo.Values(conn.Do("ZSCAN", labelMembershipActiveHostIDsKey, cursor, "COUNT", scanCount))
		if err != nil {
			return nil, fmt.Errorf("scan active host ids: %w", err)
		}
		var hostVals []uint
		if _, err := redigo.Scan(res, &cursor, &hostVals); err != nil {
			return nil, fmt.Errorf("convert scan results: %w", err)
		}
		// the values must come in member/score pairs, fail explicitly instead
		// of panicking on an out-of-range index below.
		if len(hostVals)%2 != 0 {
			return nil, fmt.Errorf("convert scan results: odd number of values (%d) in ZSCAN reply", len(hostVals))
		}
		for i := 0; i < len(hostVals); i += 2 {
			hid := hostVals[i]
			if _, dup := seen[hid]; dup {
				continue
			}
			seen[hid] = struct{}{}
			hosts = append(hosts, hostIDLastReported{HostID: hid, LastReported: int64(hostVals[i+1])})
		}
		if cursor == 0 {
			// iteration completed
			return hosts, nil
		}
	}
}

View file

@ -43,9 +43,6 @@ func testCollectLabelQueryExecutions(t *testing.T, ds *mysql.Datastore, pool fle
hostIDs := createHosts(t, ds, 4, time.Now().Add(-24*time.Hour))
hid := func(id int) int {
if id < 0 {
return id
}
return int(hostIDs[id-1])
}
@ -53,9 +50,7 @@ func testCollectLabelQueryExecutions(t *testing.T, ds *mysql.Datastore, pool fle
// previous one's state, so they are not run as distinct sub-tests.
cases := []struct {
name string
// map of host ID to label IDs to insert (true) or delete (false), a
// negative host id is stored as an invalid redis key that should be
// ignored.
// map of host ID to label IDs to insert (true) or delete (false)
reported map[int]map[int]bool
want []labelMembership
}{
@ -135,23 +130,12 @@ func testCollectLabelQueryExecutions(t *testing.T, ds *mysql.Datastore, pool fle
},
},
{
"report host -99 labels 1, ignored",
map[int]map[int]bool{hid(-99): {1: true}},
[]labelMembership{
{HostID: hid(1), LabelID: 1},
{HostID: hid(1), LabelID: 2},
{HostID: hid(2), LabelID: 2},
{HostID: hid(2), LabelID: 3},
},
},
{
"report hosts 1, 2, 3, 4, -99 labels 1, 2, -3, 4",
"report hosts 1, 2, 3, 4 labels 1, 2, -3, 4",
map[int]map[int]bool{
hid(1): {1: true, 2: true, 3: false, 4: true},
hid(2): {1: true, 2: true, 3: false, 4: true},
hid(3): {1: true, 2: true, 3: false, 4: true},
hid(4): {1: true, 2: true, 3: false, 4: true},
hid(-99): {1: true, 2: true, 3: false, 4: true},
hid(1): {1: true, 2: true, 3: false, 4: true},
hid(2): {1: true, 2: true, 3: false, 4: true},
hid(3): {1: true, 2: true, 3: false, 4: true},
hid(4): {1: true, 2: true, 3: false, 4: true},
},
[]labelMembership{
{HostID: hid(1), LabelID: 1},
@ -192,13 +176,16 @@ func testCollectLabelQueryExecutions(t *testing.T, ds *mysql.Datastore, pool fle
}
_, err := conn.Do("ZADD", args...)
require.NoError(t, err)
_, err = conn.Do("ZADD", labelMembershipActiveHostIDsKey, time.Now().Unix(), hostID)
require.NoError(t, err)
}
wantStats.Keys++
if hostID >= 0 {
wantStats.Items += len(res)
wantStats.RedisCmds++
wantStats.RedisCmds += len(res) / batchSizes
}
cnt, err := redigo.Int(conn.Do("ZCARD", labelMembershipActiveHostIDsKey))
require.NoError(t, err)
wantStats.Keys = cnt
wantStats.Items += len(res)
wantStats.RedisCmds++
wantStats.RedisCmds += len(res) / batchSizes
}
return wantStats
}
@ -305,6 +292,9 @@ func TestRecordLabelQueryExecutions(t *testing.T) {
ds.RecordLabelQueryExecutionsFunc = func(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool) error {
return nil
}
ds.AsyncBatchUpdateLabelTimestampFunc = func(ctx context.Context, ids []uint, ts time.Time) error {
return nil
}
t.Run("standalone", func(t *testing.T) {
pool := redistest.SetupRedis(t, false, false, false)
@ -313,7 +303,7 @@ func TestRecordLabelQueryExecutions(t *testing.T) {
})
t.Run("cluster", func(t *testing.T) {
pool := redistest.SetupRedis(t, true, false, false)
pool := redistest.SetupRedis(t, true, true, false)
t.Run("sync", func(t *testing.T) { testRecordLabelQueryExecutionsSync(t, ds, pool) })
t.Run("async", func(t *testing.T) { testRecordLabelQueryExecutionsAsync(t, ds, pool) })
})
@ -347,7 +337,7 @@ func testRecordLabelQueryExecutionsSync(t *testing.T, ds *mock.Store, pool fleet
require.True(t, ds.RecordLabelQueryExecutionsFuncInvoked)
ds.RecordLabelQueryExecutionsFuncInvoked = false
conn := pool.Get()
conn := redis.ConfigureDoer(pool, pool.Get())
defer conn.Close()
defer conn.Do("DEL", keySet, keyTs)
@ -359,6 +349,10 @@ func testRecordLabelQueryExecutionsSync(t *testing.T, ds *mock.Store, pool fleet
require.NoError(t, err)
require.Equal(t, 0, n)
n, err = redigo.Int(conn.Do("ZCARD", labelMembershipActiveHostIDsKey))
require.NoError(t, err)
require.Equal(t, 0, n)
labelReportedAt = task.GetHostLabelReportedAt(ctx, host)
require.True(t, labelReportedAt.Equal(now))
}
@ -380,6 +374,12 @@ func testRecordLabelQueryExecutionsAsync(t *testing.T, ds *mock.Store, pool flee
Datastore: ds,
Pool: pool,
AsyncEnabled: true,
InsertBatch: 3,
UpdateBatch: 3,
DeleteBatch: 3,
RedisPopCount: 3,
RedisScanKeysCount: 10,
}
labelReportedAt := task.GetHostLabelReportedAt(ctx, host)
@ -389,7 +389,7 @@ func testRecordLabelQueryExecutionsAsync(t *testing.T, ds *mock.Store, pool flee
require.NoError(t, err)
require.False(t, ds.RecordLabelQueryExecutionsFuncInvoked)
conn := pool.Get()
conn := redis.ConfigureDoer(pool, pool.Get())
defer conn.Close()
defer conn.Do("DEL", keySet, keyTs)
@ -402,12 +402,31 @@ func testRecordLabelQueryExecutionsAsync(t *testing.T, ds *mock.Store, pool flee
require.NoError(t, err)
require.Equal(t, now.Unix(), ts)
count, err := redigo.Int(conn.Do("ZCARD", labelMembershipActiveHostIDsKey))
require.NoError(t, err)
require.Equal(t, 1, count)
tsActive, err := redigo.Int64(conn.Do("ZSCORE", labelMembershipActiveHostIDsKey, host.ID))
require.NoError(t, err)
require.Equal(t, tsActive, ts)
labelReportedAt = task.GetHostLabelReportedAt(ctx, host)
// because we transition via unix epoch (seconds), not exactly equal
require.WithinDuration(t, now, labelReportedAt, time.Second)
// host's LabelUpdatedAt field hasn't been updated yet, because the label
// results are in redis, not in mysql yet.
require.True(t, host.LabelUpdatedAt.Equal(lastYear))
// running the collector removes the host from the active set
var stats collectorExecStats
err = task.collectLabelQueryExecutions(ctx, ds, pool, &stats)
require.NoError(t, err)
require.Equal(t, 1, stats.Keys)
require.Equal(t, 0, stats.Items) // zero because we cleared the host's set with ZPOPMIN above
require.False(t, stats.Failed)
count, err = redigo.Int(conn.Do("ZCARD", labelMembershipActiveHostIDsKey))
require.NoError(t, err)
require.Equal(t, 0, count)
}
func createHosts(t *testing.T, ds fleet.Datastore, count int, ts time.Time) []uint {