<!-- Add the related story/sub-task/bug number, like Resolves #123, or remove if NA -->

**Related issue:** Resolves #42836

This is another hot path optimization.

## Before

When a host submits policy results via `SubmitDistributedQueryResults`, the system needed to determine which policies "flipped" (changed from passing to failing or vice versa). Each consumer computed this independently:

```
SubmitDistributedQueryResults(policyResults)
  |
  +-- processScriptsForNewlyFailingPolicies
  |     filter to failing policies with scripts
  |     BUILD SUBSET of results
  |     CALL FlippingPoliciesForHost(subset)       <-- DB query #1
  |     convert result to set, filter, queue scripts
  |
  +-- processSoftwareForNewlyFailingPolicies
  |     filter to failing policies with installers
  |     BUILD SUBSET of results
  |     CALL FlippingPoliciesForHost(subset)       <-- DB query #2
  |     convert result to set, filter, queue installs
  |
  +-- processVPPForNewlyFailingPolicies
  |     filter to failing policies with VPP apps
  |     BUILD SUBSET of results
  |     CALL FlippingPoliciesForHost(subset)       <-- DB query #3
  |     convert result to set, filter, queue VPP
  |
  +-- webhook filtering
  |     filter to webhook-enabled policies
  |     CALL FlippingPoliciesForHost(subset)       <-- DB query #4
  |     register flipped policies in Redis
  |
  +-- RecordPolicyQueryExecutions
        CALL FlippingPoliciesForHost(all results)  <-- DB query #5
        reset attempt counters for newly passing
        INSERT/UPDATE policy_membership
```

Each `FlippingPoliciesForHost` call runs `SELECT policy_id, passes FROM policy_membership WHERE host_id = ? AND policy_id IN (?)`. All 5 queries hit the same table for the same host before `policy_membership` is updated, so they all see identical state.

Each consumer also built intermediate maps to narrow down to its subset before calling `FlippingPoliciesForHost`, then converted the result into yet another set for filtering. This meant 3-4 temporary maps per consumer.

## After

```
SubmitDistributedQueryResults(policyResults)
  |
  CALL FlippingPoliciesForHost(all results)        <-- single DB query
  build newFailingSet, normalize newPassing
  |
  +-- processScriptsForNewlyFailingPolicies
  |     filter to failing policies with scripts
  |     CHECK newFailingSet (in-memory map lookup)
  |     queue scripts
  |
  +-- processSoftwareForNewlyFailingPolicies
  |     filter to failing policies with installers
  |     CHECK newFailingSet (in-memory map lookup)
  |     queue installs
  |
  +-- processVPPForNewlyFailingPolicies
  |     filter to failing policies with VPP apps
  |     CHECK newFailingSet (in-memory map lookup)
  |     queue VPP
  |
  +-- webhook filtering
  |     filter to webhook-enabled policies
  |     FILTER newFailing/newPassing by policy IDs (in-memory)
  |     register flipped policies in Redis
  |
  +-- RecordPolicyQueryExecutions
        USE pre-computed newPassing (skip DB query)
        reset attempt counters for newly passing
        INSERT/UPDATE policy_membership
```

The intermediate subset maps and per-consumer set conversions are removed. Each process function goes directly from "policies with associated automation" to "is this policy in newFailingSet?" in a single map lookup.

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

- [x] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`.
## Testing

- [x] Added/updated automated tests
- [x] QA'd all new/changed functionality manually

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **Performance Improvements**
  * Reduced redundant database queries during policy result submissions by computing flipping policies once per host check-in instead of multiple times.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
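The change is easiest to see in a condensed sketch. The snippet below is illustrative only: `flippingPoliciesForHost`, its signature, and the sample policy IDs are stand-ins for the real Fleet datastore call described above. The shape of the flow matches the "After" diagram: one flip computation per submission, then plain map lookups in each consumer.

```go
package main

import (
	"context"
	"fmt"
)

// flippingPoliciesForHost is a stand-in for the single datastore call the PR
// description refers to (hypothetical signature): it compares the incoming
// results against the stored policy_membership rows and returns the policy IDs
// that newly fail and newly pass.
func flippingPoliciesForHost(ctx context.Context, hostID uint, results map[uint]*bool) (newFailing, newPassing []uint, err error) {
	// ...one SELECT against policy_membership would go here...
	return []uint{7, 9}, []uint{3}, nil
}

func main() {
	ctx := context.Background()
	pass, fail := true, false
	incoming := map[uint]*bool{3: &pass, 7: &fail, 9: &fail}

	// Compute flips once per submission instead of once per consumer.
	newFailing, newPassing, err := flippingPoliciesForHost(ctx, 42, incoming)
	if err != nil {
		panic(err)
	}

	// newFailingSet is built once; each consumer replaces its own DB query with
	// an in-memory membership check.
	newFailingSet := make(map[uint]struct{}, len(newFailing))
	for _, id := range newFailing {
		newFailingSet[id] = struct{}{}
	}

	// e.g. inside processScriptsForNewlyFailingPolicies: only queue work for
	// policies that actually flipped to failing.
	for _, policyID := range []uint{7, 8} { // failing policies that have a script
		if _, flipped := newFailingSet[policyID]; flipped {
			fmt.Println("queue script for policy", policyID)
		}
	}

	// RecordPolicyQueryExecutions reuses newPassing instead of re-querying.
	fmt.Println("newly passing:", newPassing)
}
```

The map of empty structs plays the role of `newFailingSet` in the description above, and `newPassing` is handed straight to `RecordPolicyQueryExecutions` so it can skip its own `FlippingPoliciesForHost` query.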
277 lines · 8.3 KiB · Go
package async

import (
	"context"
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/fleetdm/fleet/v4/server/config"
	"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
	"github.com/fleetdm/fleet/v4/server/datastore/redis"
	"github.com/fleetdm/fleet/v4/server/fleet"
	"github.com/fleetdm/fleet/v4/server/ptr"
	redigo "github.com/gomodule/redigo/redis"
	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/trace"
)

const (
	policyPassHostIDsKey  = "policy_pass:active_host_ids"
	policyPassHostKey     = "policy_pass:{%d}"
	policyPassReportedKey = "policy_pass_reported:{%d}"
	policyPassKeysMinTTL  = 7 * 24 * time.Hour // 1 week
)

// redis list will be LTRIM'd if there are more policy IDs than this.
var maxRedisPolicyResultsPerHost = 1000

func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool, newlyPassingPolicyIDs []uint) error {
	cfg := t.taskConfigs[config.AsyncTaskPolicyMembership]
	if !cfg.Enabled {
		host.PolicyUpdatedAt = ts
		return t.datastore.RecordPolicyQueryExecutions(ctx, host, results, ts, deferred, newlyPassingPolicyIDs)
	}

	keyList := fmt.Sprintf(policyPassHostKey, host.ID)
	keyTs := fmt.Sprintf(policyPassReportedKey, host.ID)

	// set an expiration on both keys (list and ts), ensuring that a deleted host
	// (eventually) does not use any redis space. Ensure that TTL is reasonably
	// big to avoid deleting information that hasn't been collected yet - 1 week
	// or 10 * the collector interval, whichever is biggest.
	//
	// This means that it will only expire if that host hasn't reported policies
	// during that (TTL) time (each time it does report, the TTL is reset), and
	// the collector will have plenty of time to run (multiple times) to try to
	// persist all the data in mysql.
	ttl := policyPassKeysMinTTL
	if maxTTL := 10 * cfg.CollectInterval; maxTTL > ttl {
		ttl = maxTTL
	}

	// There are two versions of the script (1) and (2).
	// Script (1) is used when there are no policy results to report and
	// script (2) is used when there are policy results.

	// (1)
	// KEYS[1]: keyTs (policyPassReportedKey)
	// ARGV[1]: timestamp for "reported at"
	// ARGV[2]: ttl for the key
	scriptSrc := `
		redis.call('SET', KEYS[1], ARGV[1])
		return redis.call('EXPIRE', KEYS[1], ARGV[2])
	`
	keyCount := 1
	args := make(redigo.Args, 0, 3)
	args = args.Add(keyTs, ts.Unix(), int(ttl.Seconds()))

	if len(results) > 0 {
		// (2)
		// KEYS[1]: keyList (policyPassHostKey)
		// KEYS[2]: keyTs (policyPassReportedKey)
		// ARGV[1]: timestamp for "reported at"
		// ARGV[2]: max policy results to keep per host (list is trimmed to that size)
		// ARGV[3]: ttl for both keys
		// ARGV[4..]: policy_id=pass entries to LPUSH to the list
		keyCount = 2
		scriptSrc = `
			redis.call('LPUSH', KEYS[1], unpack(ARGV, 4))
			redis.call('LTRIM', KEYS[1], 0, ARGV[2])
			redis.call('EXPIRE', KEYS[1], ARGV[3])
			redis.call('SET', KEYS[2], ARGV[1])
			return redis.call('EXPIRE', KEYS[2], ARGV[3])
		`
		// convert results to LPUSH arguments, store as policy_id=1 for pass,
		// policy_id=-1 for fail, policy_id=0 for null result.
		args = make(redigo.Args, 0, 5+len(results))
		args = args.Add(keyList, keyTs, ts.Unix(), maxRedisPolicyResultsPerHost, int(ttl.Seconds()))
		for k, v := range results {
			pass := 0
			if v != nil {
				if *v {
					pass = 1
				} else {
					pass = -1
				}
			}
			args = args.Add(fmt.Sprintf("%d=%d", k, pass))
		}
	}

	script := redigo.NewScript(keyCount, scriptSrc)

	conn := t.pool.Get()
	defer conn.Close()
	if err := redis.BindConn(t.pool, conn, keyList, keyTs); err != nil {
		return ctxerr.Wrap(ctx, err, "bind redis connection")
	}

	if _, err := script.Do(conn, args...); err != nil {
		return ctxerr.Wrap(ctx, err, "run redis script")
	}

	// Storing the host id in the set of active host IDs for policy membership
	// outside of the redis script because in Redis Cluster mode the key may not
	// live on the same node as the host's keys. At the same time, purge any
	// entry in the set that is older than now - TTL.
	if _, err := storePurgeActiveHostID(t.pool, policyPassHostIDsKey, host.ID, ts, ts.Add(-ttl)); err != nil {
		return ctxerr.Wrap(ctx, err, "store active host id")
	}
	return nil
}

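// Illustrative note (not part of the original file): the script above stores
// each result as a "policy_id=state" string, where state is 1 for pass, -1 for
// fail, and 0 for a null result. For example, submitting
// map[uint]*bool{10: ptr.Bool(true), 11: ptr.Bool(false), 12: nil} LPUSHes the
// entries "10=1", "11=-1" and "12=0" onto the host's list; the collector below
// parses them back into fleet.PolicyMembershipResult values before writing
// them to MySQL in batches.
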
func (t *Task) collectPolicyQueryExecutions(ctx context.Context, ds fleet.Datastore, pool fleet.RedisPool, stats *collectorExecStats) error {
	// Create a root span for this async collection task if OTEL is enabled
	if t.otelEnabled {
		tracer := otel.Tracer("async")
		var span trace.Span
		ctx, span = tracer.Start(ctx, "async.collect_policy_query_executions",
			trace.WithAttributes(
				attribute.String("async.task", "policy_membership"),
			),
		)
		defer span.End()
	}

	cfg := t.taskConfigs[config.AsyncTaskPolicyMembership]

	hosts, err := loadActiveHostIDs(pool, policyPassHostIDsKey, cfg.RedisScanKeysCount)
	if err != nil {
		return ctxerr.Wrap(ctx, err, "load active host ids")
	}
	stats.Keys = len(hosts)

	// need to use a script as the RPOP command only supports a COUNT since
	// 6.2. Because we use LTRIM when inserting, we know the total number
	// of results is at most maxRedisPolicyResultsPerHost, so it is capped
	// and can be returned in one go.
	script := redigo.NewScript(1, `
		local res = redis.call('LRANGE', KEYS[1], 0, -1)
		redis.call('DEL', KEYS[1])
		return res
	`)

	getKeyTuples := func(hostID uint) (inserts []fleet.PolicyMembershipResult, err error) {
		keyList := fmt.Sprintf(policyPassHostKey, hostID)
		conn := redis.ConfigureDoer(pool, pool.Get())
		defer conn.Close()

		stats.RedisCmds++
		res, err := redigo.Strings(script.Do(conn, keyList))
		if err != nil {
			return nil, ctxerr.Wrap(ctx, err, "redis LRANGE script")
		}

		inserts = make([]fleet.PolicyMembershipResult, 0, len(res))
		stats.Items += len(res)
		for _, item := range res {
			parts := strings.Split(item, "=")
			if len(parts) != 2 {
				continue
			}

			var tup fleet.PolicyMembershipResult
			if id, _ := strconv.ParseUint(parts[0], 10, 32); id > 0 {
				tup.HostID = hostID
				tup.PolicyID = uint(id)
				switch parts[1] {
				case "1":
					tup.Passes = ptr.Bool(true)
				case "-1":
					tup.Passes = ptr.Bool(false)
				case "0":
					tup.Passes = nil
				default:
					continue
				}
				inserts = append(inserts, tup)
			}
		}
		return inserts, nil
	}

	runInsertBatch := func(batch []fleet.PolicyMembershipResult) error {
		stats.Inserts++
		return ds.AsyncBatchInsertPolicyMembership(ctx, batch)
	}

	runUpdateBatch := func(ids []uint, ts time.Time) error {
		stats.Updates++
		return ds.AsyncBatchUpdatePolicyTimestamp(ctx, ids, ts)
	}

	insertBatch := make([]fleet.PolicyMembershipResult, 0, cfg.InsertBatch)
	for _, host := range hosts {
		hid := host.HostID
		ins, err := getKeyTuples(hid)
		if err != nil {
			return err
		}
		insertBatch = append(insertBatch, ins...)

		if len(insertBatch) >= cfg.InsertBatch {
			if err := runInsertBatch(insertBatch); err != nil {
				return err
			}
			insertBatch = insertBatch[:0]
		}
	}

	// process any remaining batch that did not reach the batchSize limit in the
	// loop.
	if len(insertBatch) > 0 {
		if err := runInsertBatch(insertBatch); err != nil {
			return err
		}
	}
	if len(hosts) > 0 {
		hostIDs := make([]uint, len(hosts))
		for i, host := range hosts {
			hostIDs[i] = host.HostID
		}

		ts := t.clock.Now()
		updateBatch := make([]uint, cfg.UpdateBatch)
		for {
			n := copy(updateBatch, hostIDs)
			if n == 0 {
				break
			}
			if err := runUpdateBatch(updateBatch[:n], ts); err != nil {
				return err
			}
			hostIDs = hostIDs[n:]
		}

		// batch-remove any host ID from the active set that still has its score to
		// the initial value, so that the active set does not keep all (potentially
		// 100K+) host IDs to process at all times - only those with reported
		// results to process.
		if _, err := removeProcessedHostIDs(pool, policyPassHostIDsKey, hosts); err != nil {
			return ctxerr.Wrap(ctx, err, "remove processed host ids")
		}
	}

	return nil
}

func (t *Task) GetHostPolicyReportedAt(ctx context.Context, host *fleet.Host) time.Time {
	cfg := t.taskConfigs[config.AsyncTaskPolicyMembership]

	if cfg.Enabled {
		conn := redis.ConfigureDoer(t.pool, t.pool.Get())
		defer conn.Close()

		key := fmt.Sprintf(policyPassReportedKey, host.ID)
		epoch, err := redigo.Int64(conn.Do("GET", key))
		if err == nil {
			if reported := time.Unix(epoch, 0); reported.After(host.PolicyUpdatedAt) {
				return reported
			}
		}
	}
	return host.PolicyUpdatedAt
}
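For context, here is a hedged sketch of how a caller might use the two exported methods in this file. The method signatures match the listing above, but the import path for the `async` package, the way the `*Task` and `*fleet.Host` are obtained, and the example values (`deferred=true`, `[]uint{10}`) are assumptions made for illustration.

```go
package example

import (
	"context"
	"time"

	"github.com/fleetdm/fleet/v4/server/fleet"
	"github.com/fleetdm/fleet/v4/server/ptr"
	"github.com/fleetdm/fleet/v4/server/service/async" // assumed import path
)

// recordResults records one submission's policy results for a host and then
// reads back the freshest "policy reported at" timestamp. Constructing the
// *async.Task (datastore, Redis pool, task configs) happens outside this file.
func recordResults(ctx context.Context, task *async.Task, host *fleet.Host) error {
	results := map[uint]*bool{
		10: ptr.Bool(true),  // policy 10 now passes    -> stored as "10=1"
		11: ptr.Bool(false), // policy 11 now fails     -> stored as "11=-1"
		12: nil,             // policy 12 had no result -> stored as "12=0"
	}
	// newlyPassingPolicyIDs comes from the single FlippingPoliciesForHost call
	// described in the PR summary; deferred=true is just an example value here.
	if err := task.RecordPolicyQueryExecutions(ctx, host, results, time.Now(), true, []uint{10}); err != nil {
		return err
	}
	// Returns the Redis "reported at" timestamp when it is newer than the value
	// already persisted on the host row.
	_ = task.GetHostPolicyReportedAt(ctx, host)
	return nil
}
```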