From 1f185a7a8baa8cd3e026b095b72db5ff28184942 Mon Sep 17 00:00:00 2001 From: Martin Angers Date: Mon, 17 Jan 2022 14:53:59 -0500 Subject: [PATCH] Refactor async host processing to avoid redis SCAN keys (for labels only) (#3639) --- changes/issue-3422-async-labels-avoid-scan | 1 + cmd/fleet/serve.go | 4 +- server/service/async/async.go | 219 +++++++++++++++++---- server/service/async/async_test.go | 81 +++++--- 4 files changed, 230 insertions(+), 75 deletions(-) create mode 100644 changes/issue-3422-async-labels-avoid-scan diff --git a/changes/issue-3422-async-labels-avoid-scan b/changes/issue-3422-async-labels-avoid-scan new file mode 100644 index 0000000000..85017639be --- /dev/null +++ b/changes/issue-3422-async-labels-avoid-scan @@ -0,0 +1 @@ +* Refactor async host processing (`--enable_async_host_processing`) to avoid relying on slow redis SCAN keys. diff --git a/cmd/fleet/serve.go b/cmd/fleet/serve.go index 6f6bcd45f0..9d5c9cc22e 100644 --- a/cmd/fleet/serve.go +++ b/cmd/fleet/serve.go @@ -278,6 +278,7 @@ the way that the Fleet server works. UpdateBatch: config.Osquery.AsyncHostUpdateBatch, RedisPopCount: config.Osquery.AsyncHostRedisPopCount, RedisScanKeysCount: config.Osquery.AsyncHostRedisScanKeysCount, + CollectorInterval: config.Osquery.AsyncHostCollectInterval, } svc, err := service.NewService(ds, task, resultStore, logger, osqueryLogger, config, mailService, clock.C, ssoSessionStore, liveQueryStore, carveStore, *license, failingPolicySet) if err != nil { @@ -534,8 +535,7 @@ func runCrons(ds fleet.Datastore, task *async.Task, logger kitlog.Logger, config } // StartCollectors starts a goroutine per collector, using ctx to cancel. - task.StartCollectors(ctx, config.Osquery.AsyncHostCollectInterval, - config.Osquery.AsyncHostCollectMaxJitterPercent, kitlog.With(logger, "cron", "async_task")) + task.StartCollectors(ctx, config.Osquery.AsyncHostCollectMaxJitterPercent, kitlog.With(logger, "cron", "async_task")) go cronCleanups(ctx, ds, kitlog.With(logger, "cron", "cleanups"), ourIdentifier, license) go cronVulnerabilities( diff --git a/server/service/async/async.go b/server/service/async/async.go index 320441a8d3..979dfc61cd 100644 --- a/server/service/async/async.go +++ b/server/service/async/async.go @@ -3,9 +3,6 @@ package async import ( "context" "fmt" - "math" - "regexp" - "strconv" "time" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" @@ -17,10 +14,11 @@ import ( ) const ( - labelMembershipHostKeyPattern = "label_membership:{*}" - labelMembershipHostKey = "label_membership:{%d}" - labelMembershipReportedKey = "label_membership_reported:{%d}" - collectorLockKey = "locks:async_collector:{%s}" + labelMembershipActiveHostIDsKey = "label_membership:active_host_ids" + labelMembershipHostKey = "label_membership:{%d}" + labelMembershipReportedKey = "label_membership_reported:{%d}" + labelMembershipKeysMinTTL = 7 * 24 * time.Hour // 1 week + collectorLockKey = "locks:async_collector:{%s}" ) type Task struct { @@ -37,23 +35,24 @@ type Task struct { UpdateBatch int RedisPopCount int RedisScanKeysCount int + CollectorInterval time.Duration } // Collect runs the various collectors as distinct background goroutines if // async processing is enabled. Each collector will stop processing when ctx // is done. -func (t *Task) StartCollectors(ctx context.Context, interval time.Duration, jitterPct int, logger kitlog.Logger) { +func (t *Task) StartCollectors(ctx context.Context, jitterPct int, logger kitlog.Logger) { if !t.AsyncEnabled { level.Debug(logger).Log("task", "async disabled, not starting collectors") return } - level.Debug(logger).Log("task", "async enabled, starting collectors", "interval", interval, "jitter", jitterPct) + level.Debug(logger).Log("task", "async enabled, starting collectors", "interval", t.CollectorInterval, "jitter", jitterPct) labelColl := &collector{ name: "collect_labels", pool: t.Pool, ds: t.Datastore, - execInterval: interval, + execInterval: t.CollectorInterval, jitterPct: jitterPct, lockTimeout: t.LockTimeout, handler: t.collectLabelQueryExecutions, @@ -89,14 +88,36 @@ func (t *Task) RecordLabelQueryExecutions(ctx context.Context, host *fleet.Host, keySet := fmt.Sprintf(labelMembershipHostKey, host.ID) keyTs := fmt.Sprintf(labelMembershipReportedKey, host.ID) + // set an expiration on both keys (set and ts), ensuring that a deleted host + // (eventually) does not use any redis space. Ensure that TTL is reasonably + // big to avoid deleting information that hasn't been collected yet - 1 week + // or 10 * the collector interval, whichever is biggest. + // + // This means that it will only expire if that host hasn't reported labels + // during that (TTL) time (each time it does report, the TTL is reset), and + // the collector will have plenty of time to run (multiple times) to try to + // persist all the data in mysql. + ttl := labelMembershipKeysMinTTL + if maxTTL := 10 * t.CollectorInterval; maxTTL > ttl { + ttl = maxTTL + } + + // keys and arguments passed to the script are: + // KEYS[1]: keySet (labelMembershipHostKey) + // KEYS[2]: keyTs (labelMembershipReportedKey) + // ARGV[1]: timestamp for "reported at" + // ARGV[2]: ttl for both keys + // ARGV[3..]: the arguments to ZADD to keySet script := redigo.NewScript(2, ` - redis.call('ZADD', KEYS[1], unpack(ARGV, 2)) - return redis.call('SET', KEYS[2], ARGV[1]) + redis.call('ZADD', KEYS[1], unpack(ARGV, 3)) + redis.call('EXPIRE', KEYS[1], ARGV[2]) + redis.call('SET', KEYS[2], ARGV[1]) + return redis.call('EXPIRE', KEYS[2], ARGV[2]) `) // convert results to ZADD arguments, store as -1 for delete, +1 for insert - args := make(redigo.Args, 0, 3+(len(results)*2)) - args = args.Add(keySet, keyTs, ts.Unix()) + args := make(redigo.Args, 0, 4+(len(results)*2)) + args = args.Add(keySet, keyTs, ts.Unix(), int(ttl.Seconds())) for k, v := range results { score := -1 if v != nil && *v { @@ -114,40 +135,35 @@ func (t *Task) RecordLabelQueryExecutions(ctx context.Context, host *fleet.Host, if _, err := script.Do(conn, args...); err != nil { return ctxerr.Wrap(ctx, err, "run redis script") } + + // Storing the host id in the set of active host IDs for label membership + // outside of the redis script because in Redis Cluster mode the key may not + // live on the same node as the host's keys. At the same time, purge any + // entry in the set that is older than now - TTL. + if err := storePurgeActiveHostID(t.Pool, host.ID, ts, ts.Add(-ttl)); err != nil { + return ctxerr.Wrap(ctx, err, "store active host id") + } return nil } -var reHostFromKey = regexp.MustCompile(`\{(\d+)\}$`) - func (t *Task) collectLabelQueryExecutions(ctx context.Context, ds fleet.Datastore, pool fleet.RedisPool, stats *collectorExecStats) error { - keys, err := redis.ScanKeys(pool, labelMembershipHostKeyPattern, t.RedisScanKeysCount) + hosts, err := loadActiveHostIDs(pool, t.RedisScanKeysCount) if err != nil { - return err + return ctxerr.Wrap(ctx, err, "load active host ids") } - stats.Keys = len(keys) - - getKeyTuples := func(key string) (hostID uint, inserts, deletes [][2]uint, err error) { - if matches := reHostFromKey.FindStringSubmatch(key); matches != nil { - id, err := strconv.ParseInt(matches[1], 10, 64) - if err == nil && id > 0 && id <= math.MaxUint32 { // required for CodeQL vulnerability scanning in CI - hostID = uint(id) - } - } - - // just ignore if there is no valid host id - if hostID == 0 { - return hostID, nil, nil, nil - } + stats.Keys = len(hosts) + getKeyTuples := func(hostID uint) (inserts, deletes [][2]uint, err error) { + keySet := fmt.Sprintf(labelMembershipHostKey, hostID) conn := redis.ConfigureDoer(pool, pool.Get()) defer conn.Close() for { stats.RedisCmds++ - vals, err := redigo.Ints(conn.Do("ZPOPMIN", key, t.RedisPopCount)) + vals, err := redigo.Ints(conn.Do("ZPOPMIN", keySet, t.RedisPopCount)) if err != nil { - return hostID, nil, nil, ctxerr.Wrap(ctx, err, "redis ZPOPMIN") + return nil, nil, ctxerr.Wrap(ctx, err, "redis ZPOPMIN") } items := len(vals) / 2 // each item has the label id and the score (-1=delete, +1=insert) stats.Items += items @@ -168,7 +184,7 @@ func (t *Task) collectLabelQueryExecutions(ctx context.Context, ds fleet.Datasto } } if items < t.RedisPopCount { - return hostID, inserts, deletes, nil + return inserts, deletes, nil } } } @@ -204,9 +220,9 @@ func (t *Task) collectLabelQueryExecutions(ctx context.Context, ds fleet.Datasto insertBatch := make([][2]uint, 0, t.InsertBatch) deleteBatch := make([][2]uint, 0, t.DeleteBatch) - hostIDs := make([]uint, 0, len(keys)) - for _, key := range keys { - hostID, ins, del, err := getKeyTuples(key) + for _, host := range hosts { + hid := host.HostID + ins, del, err := getKeyTuples(hid) if err != nil { return err } @@ -225,9 +241,6 @@ func (t *Task) collectLabelQueryExecutions(ctx context.Context, ds fleet.Datasto } deleteBatch = deleteBatch[:0] } - if hostID > 0 { - hostIDs = append(hostIDs, hostID) - } } // process any remaining batch that did not reach the batchSize limit in the @@ -242,7 +255,12 @@ func (t *Task) collectLabelQueryExecutions(ctx context.Context, ds fleet.Datasto return err } } - if len(hostIDs) > 0 { + if len(hosts) > 0 { + hostIDs := make([]uint, len(hosts)) + for i, host := range hosts { + hostIDs[i] = host.HostID + } + ts := time.Now() updateBatch := make([]uint, t.UpdateBatch) for { @@ -255,6 +273,14 @@ func (t *Task) collectLabelQueryExecutions(ctx context.Context, ds fleet.Datasto } hostIDs = hostIDs[n:] } + + // batch-remove any host ID from the active set that still has its score to + // the initial value, so that the active set does not keep all (potentially + // 100K+) host IDs to process at all times - only those with reported + // results to process. + if err := removeProcessedHostIDs(pool, hosts); err != nil { + return ctxerr.Wrap(ctx, err, "remove processed host ids") + } } return nil @@ -275,3 +301,112 @@ func (t *Task) GetHostLabelReportedAt(ctx context.Context, host *fleet.Host) tim } return host.LabelUpdatedAt } + +func storePurgeActiveHostID(pool fleet.RedisPool, hid uint, reportedAt, purgeOlder time.Time) error { + // KEYS[1]: labelMembershipActiveHostIDsKey + // ARGV[1]: the host ID to add + // ARGV[2]: the added host's reported-at timestamp + // ARGV[3]: purge any entry with score older than this (purgeOlder timestamp) + script := redigo.NewScript(1, ` + redis.call('ZADD', KEYS[1], ARGV[2], ARGV[1]) + return redis.call('ZREMRANGEBYSCORE', KEYS[1], '-inf', ARGV[3]) + `) + + conn := pool.Get() + defer conn.Close() + + if err := redis.BindConn(pool, conn, labelMembershipActiveHostIDsKey); err != nil { + return fmt.Errorf("bind redis connection: %w", err) + } + + if _, err := script.Do(conn, labelMembershipActiveHostIDsKey, hid, reportedAt.Unix(), purgeOlder.Unix()); err != nil { + return fmt.Errorf("run redis script: %w", err) + } + return nil +} + +func removeProcessedHostIDs(pool fleet.RedisPool, batch []hostIDLastReported) error { + // This script removes from the set of active hosts for label membership all + // those that still have the same score as when the batch was read (via + // loadActiveHostIDs). This is so that any host that would've reported new + // data since the call to loadActiveHostIDs would *not* get deleted (as the + // score would change if that was the case). + // + // Note that this approach is correct - in that it is safe and won't delete + // any host that has unsaved reported data - but it is potentially slow, as + // it needs to check the score of each member before deleting it. Should that + // become too slow, we have some options: + // + // * split the batch in smaller, capped ones (that would be if the redis + // server gets blocked for too long processing a single batch) + // * use ZREMRANGEBYSCORE to remove in one command all members with a score + // (reported-at timestamp) lower than the maximum timestamp in batch. + // While this would be almost certainly faster, it might be incorrect as + // new data could be reported with timestamps older than the maximum one, + // e.g. if the clocks are not exactly in sync between fleet instances, or + // if hosts report new data while the ZSCAN is going on and don't get picked + // up by the SCAN (this is possible, as part of the guarantees of SCAN). + + // KEYS[1]: labelMembershipActiveHostIDsKey + // ARGV...: the list of host ID-last reported timestamp pairs + script := redigo.NewScript(1, ` + local count = 0 + for i = 1, #ARGV, 2 do + local member, ts = ARGV[i], ARGV[i+1] + if redis.call('ZSCORE', KEYS[1], member) == ts then + count = count + 1 + redis.call('ZREM', KEYS[1], member) + end + end + return count + `) + + conn := pool.Get() + defer conn.Close() + + if err := redis.BindConn(pool, conn, labelMembershipActiveHostIDsKey); err != nil { + return fmt.Errorf("bind redis connection: %w", err) + } + + args := redigo.Args{labelMembershipActiveHostIDsKey} + for _, host := range batch { + args = args.Add(host.HostID, host.LastReported) + } + if _, err := script.Do(conn, args...); err != nil { + return fmt.Errorf("run redis script: %w", err) + } + return nil +} + +type hostIDLastReported struct { + HostID uint + LastReported int64 // timestamp in unix epoch +} + +func loadActiveHostIDs(pool fleet.RedisPool, scanCount int) ([]hostIDLastReported, error) { + conn := redis.ConfigureDoer(pool, pool.Get()) + defer conn.Close() + + // using ZSCAN instead of fetching in one shot, as there may be 100K+ hosts + // and we don't want to block the redis server too long. + var hosts []hostIDLastReported + cursor := 0 + for { + res, err := redigo.Values(conn.Do("ZSCAN", labelMembershipActiveHostIDsKey, cursor, "COUNT", scanCount)) + if err != nil { + return nil, fmt.Errorf("scan active host ids: %w", err) + } + var hostVals []uint + if _, err := redigo.Scan(res, &cursor, &hostVals); err != nil { + return nil, fmt.Errorf("convert scan results: %w", err) + } + for i := 0; i < len(hostVals); i += 2 { + hosts = append(hosts, hostIDLastReported{HostID: hostVals[i], LastReported: int64(hostVals[i+1])}) + } + + if cursor == 0 { + // iteration completed + return hosts, nil + } + } +} diff --git a/server/service/async/async_test.go b/server/service/async/async_test.go index 9805adae7e..71f6e6dcc8 100644 --- a/server/service/async/async_test.go +++ b/server/service/async/async_test.go @@ -43,9 +43,6 @@ func testCollectLabelQueryExecutions(t *testing.T, ds *mysql.Datastore, pool fle hostIDs := createHosts(t, ds, 4, time.Now().Add(-24*time.Hour)) hid := func(id int) int { - if id < 0 { - return id - } return int(hostIDs[id-1]) } @@ -53,9 +50,7 @@ func testCollectLabelQueryExecutions(t *testing.T, ds *mysql.Datastore, pool fle // previous one's state, so they are not run as distinct sub-tests. cases := []struct { name string - // map of host ID to label IDs to insert (true) or delete (false), a - // negative host id is stored as an invalid redis key that should be - // ignored. + // map of host ID to label IDs to insert (true) or delete (false) reported map[int]map[int]bool want []labelMembership }{ @@ -135,23 +130,12 @@ func testCollectLabelQueryExecutions(t *testing.T, ds *mysql.Datastore, pool fle }, }, { - "report host -99 labels 1, ignored", - map[int]map[int]bool{hid(-99): {1: true}}, - []labelMembership{ - {HostID: hid(1), LabelID: 1}, - {HostID: hid(1), LabelID: 2}, - {HostID: hid(2), LabelID: 2}, - {HostID: hid(2), LabelID: 3}, - }, - }, - { - "report hosts 1, 2, 3, 4, -99 labels 1, 2, -3, 4", + "report hosts 1, 2, 3, 4 labels 1, 2, -3, 4", map[int]map[int]bool{ - hid(1): {1: true, 2: true, 3: false, 4: true}, - hid(2): {1: true, 2: true, 3: false, 4: true}, - hid(3): {1: true, 2: true, 3: false, 4: true}, - hid(4): {1: true, 2: true, 3: false, 4: true}, - hid(-99): {1: true, 2: true, 3: false, 4: true}, + hid(1): {1: true, 2: true, 3: false, 4: true}, + hid(2): {1: true, 2: true, 3: false, 4: true}, + hid(3): {1: true, 2: true, 3: false, 4: true}, + hid(4): {1: true, 2: true, 3: false, 4: true}, }, []labelMembership{ {HostID: hid(1), LabelID: 1}, @@ -192,13 +176,16 @@ func testCollectLabelQueryExecutions(t *testing.T, ds *mysql.Datastore, pool fle } _, err := conn.Do("ZADD", args...) require.NoError(t, err) + _, err = conn.Do("ZADD", labelMembershipActiveHostIDsKey, time.Now().Unix(), hostID) + require.NoError(t, err) } - wantStats.Keys++ - if hostID >= 0 { - wantStats.Items += len(res) - wantStats.RedisCmds++ - wantStats.RedisCmds += len(res) / batchSizes - } + + cnt, err := redigo.Int(conn.Do("ZCARD", labelMembershipActiveHostIDsKey)) + require.NoError(t, err) + wantStats.Keys = cnt + wantStats.Items += len(res) + wantStats.RedisCmds++ + wantStats.RedisCmds += len(res) / batchSizes } return wantStats } @@ -305,6 +292,9 @@ func TestRecordLabelQueryExecutions(t *testing.T) { ds.RecordLabelQueryExecutionsFunc = func(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool) error { return nil } + ds.AsyncBatchUpdateLabelTimestampFunc = func(ctx context.Context, ids []uint, ts time.Time) error { + return nil + } t.Run("standalone", func(t *testing.T) { pool := redistest.SetupRedis(t, false, false, false) @@ -313,7 +303,7 @@ func TestRecordLabelQueryExecutions(t *testing.T) { }) t.Run("cluster", func(t *testing.T) { - pool := redistest.SetupRedis(t, true, false, false) + pool := redistest.SetupRedis(t, true, true, false) t.Run("sync", func(t *testing.T) { testRecordLabelQueryExecutionsSync(t, ds, pool) }) t.Run("async", func(t *testing.T) { testRecordLabelQueryExecutionsAsync(t, ds, pool) }) }) @@ -347,7 +337,7 @@ func testRecordLabelQueryExecutionsSync(t *testing.T, ds *mock.Store, pool fleet require.True(t, ds.RecordLabelQueryExecutionsFuncInvoked) ds.RecordLabelQueryExecutionsFuncInvoked = false - conn := pool.Get() + conn := redis.ConfigureDoer(pool, pool.Get()) defer conn.Close() defer conn.Do("DEL", keySet, keyTs) @@ -359,6 +349,10 @@ func testRecordLabelQueryExecutionsSync(t *testing.T, ds *mock.Store, pool fleet require.NoError(t, err) require.Equal(t, 0, n) + n, err = redigo.Int(conn.Do("ZCARD", labelMembershipActiveHostIDsKey)) + require.NoError(t, err) + require.Equal(t, 0, n) + labelReportedAt = task.GetHostLabelReportedAt(ctx, host) require.True(t, labelReportedAt.Equal(now)) } @@ -380,6 +374,12 @@ func testRecordLabelQueryExecutionsAsync(t *testing.T, ds *mock.Store, pool flee Datastore: ds, Pool: pool, AsyncEnabled: true, + + InsertBatch: 3, + UpdateBatch: 3, + DeleteBatch: 3, + RedisPopCount: 3, + RedisScanKeysCount: 10, } labelReportedAt := task.GetHostLabelReportedAt(ctx, host) @@ -389,7 +389,7 @@ func testRecordLabelQueryExecutionsAsync(t *testing.T, ds *mock.Store, pool flee require.NoError(t, err) require.False(t, ds.RecordLabelQueryExecutionsFuncInvoked) - conn := pool.Get() + conn := redis.ConfigureDoer(pool, pool.Get()) defer conn.Close() defer conn.Do("DEL", keySet, keyTs) @@ -402,12 +402,31 @@ func testRecordLabelQueryExecutionsAsync(t *testing.T, ds *mock.Store, pool flee require.NoError(t, err) require.Equal(t, now.Unix(), ts) + count, err := redigo.Int(conn.Do("ZCARD", labelMembershipActiveHostIDsKey)) + require.NoError(t, err) + require.Equal(t, 1, count) + tsActive, err := redigo.Int64(conn.Do("ZSCORE", labelMembershipActiveHostIDsKey, host.ID)) + require.NoError(t, err) + require.Equal(t, tsActive, ts) + labelReportedAt = task.GetHostLabelReportedAt(ctx, host) // because we transition via unix epoch (seconds), not exactly equal require.WithinDuration(t, now, labelReportedAt, time.Second) // host's LabelUpdatedAt field hasn't been updated yet, because the label // results are in redis, not in mysql yet. require.True(t, host.LabelUpdatedAt.Equal(lastYear)) + + // running the collector removes the host from the active set + var stats collectorExecStats + err = task.collectLabelQueryExecutions(ctx, ds, pool, &stats) + require.NoError(t, err) + require.Equal(t, 1, stats.Keys) + require.Equal(t, 0, stats.Items) // zero because we cleared the host's set with ZPOPMIN above + require.False(t, stats.Failed) + + count, err = redigo.Int(conn.Do("ZCARD", labelMembershipActiveHostIDsKey)) + require.NoError(t, err) + require.Equal(t, 0, count) } func createHosts(t *testing.T, ds fleet.Datastore, count int, ts time.Time) []uint {