fleet/server/service/async/async_policy.go
Victor Lyuboslavsky 8af94af14b
Removed duplicate FlippingPoliciesForHost DB calls (#42845)
**Related issue:** Resolves #42836 

This is another hot-path optimization.

## Before

When a host submits policy results via `SubmitDistributedQueryResults`,
the system needed to determine which policies "flipped" (changed from
passing to failing or vice versa). Each consumer computed this
independently:

```
SubmitDistributedQueryResults(policyResults)
  |
  +-- processScriptsForNewlyFailingPolicies
  |     filter to failing policies with scripts
  |     BUILD SUBSET of results
  |     CALL FlippingPoliciesForHost(subset)          <-- DB query #1
  |     convert result to set, filter, queue scripts
  |
  +-- processSoftwareForNewlyFailingPolicies
  |     filter to failing policies with installers
  |     BUILD SUBSET of results
  |     CALL FlippingPoliciesForHost(subset)          <-- DB query #2
  |     convert result to set, filter, queue installs
  |
  +-- processVPPForNewlyFailingPolicies
  |     filter to failing policies with VPP apps
  |     BUILD SUBSET of results
  |     CALL FlippingPoliciesForHost(subset)          <-- DB query #3
  |     convert result to set, filter, queue VPP
  |
  +-- webhook filtering
  |     filter to webhook-enabled policies
  |     CALL FlippingPoliciesForHost(subset)          <-- DB query #4
  |     register flipped policies in Redis
  |
  +-- RecordPolicyQueryExecutions
        CALL FlippingPoliciesForHost(all results)     <-- DB query #5
        reset attempt counters for newly passing
        INSERT/UPDATE policy_membership
```

Each `FlippingPoliciesForHost` call runs `SELECT policy_id, passes FROM
policy_membership WHERE host_id = ? AND policy_id IN (?)`. All 5 queries
hit the same table for the same host before `policy_membership` is
updated, so they all see identical state.

Each consumer also built intermediate maps to narrow down to its subset
before calling `FlippingPoliciesForHost`, then converted the result into
yet another set for filtering. This meant 3-4 temporary maps per
consumer.
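
For illustration, here is a minimal sketch of that per-consumer pattern. The
`PolicyDatastore` interface, the `FlippingPoliciesForHost` signature, and the
`policiesWithScripts` filter are simplified assumptions for this example, not
the literal Fleet code:

```go
package example

import "context"

// PolicyDatastore models the one datastore call every consumer made on its
// own (hypothetical, simplified signature).
type PolicyDatastore interface {
	FlippingPoliciesForHost(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing, newPassing []uint, err error)
}

// Old shape: build a subset, query the DB, convert the result to a set, filter.
func processScriptsForNewlyFailingPolicies(ctx context.Context, ds PolicyDatastore, hostID uint,
	incoming map[uint]*bool, policiesWithScripts map[uint]bool) ([]uint, error) {
	// Temporary map #1: narrow the incoming results to policies with scripts.
	subset := make(map[uint]*bool, len(incoming))
	for policyID, passes := range incoming {
		if policiesWithScripts[policyID] {
			subset[policyID] = passes
		}
	}

	// An extra round-trip to policy_membership, just for this consumer.
	newFailing, _, err := ds.FlippingPoliciesForHost(ctx, hostID, subset)
	if err != nil {
		return nil, err
	}

	// Temporary map #2: convert the slice into a set to filter against.
	failingSet := make(map[uint]struct{}, len(newFailing))
	for _, id := range newFailing {
		failingSet[id] = struct{}{}
	}

	var toQueue []uint
	for policyID := range subset {
		if _, ok := failingSet[policyID]; ok {
			toQueue = append(toQueue, policyID)
		}
	}
	return toQueue, nil
}
```

The software-installer and VPP consumers followed the same shape, each with its
own subset map, DB call, and set conversion.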

## After

```
SubmitDistributedQueryResults(policyResults)
  |
  CALL FlippingPoliciesForHost(all results)           <-- single DB query
  build newFailingSet, normalize newPassing
  |
  +-- processScriptsForNewlyFailingPolicies
  |     filter to failing policies with scripts
  |     CHECK newFailingSet (in-memory map lookup)
  |     queue scripts
  |
  +-- processSoftwareForNewlyFailingPolicies
  |     filter to failing policies with installers
  |     CHECK newFailingSet (in-memory map lookup)
  |     queue installs
  |
  +-- processVPPForNewlyFailingPolicies
  |     filter to failing policies with VPP apps
  |     CHECK newFailingSet (in-memory map lookup)
  |     queue VPP
  |
  +-- webhook filtering
  |     filter to webhook-enabled policies
  |     FILTER newFailing/newPassing by policy IDs (in-memory)
  |     register flipped policies in Redis
  |
  +-- RecordPolicyQueryExecutions
        USE pre-computed newPassing (skip DB query)
        reset attempt counters for newly passing
        INSERT/UPDATE policy_membership
```

The intermediate subset maps and per-consumer set conversions are
removed. Each process function goes directly from "policies with
associated automation" to "is this policy in newFailingSet?" in a single
map lookup.
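
A rough sketch of the consolidated shape, reusing the hypothetical
`PolicyDatastore` interface from the sketch above (again an illustration of the
approach, not the exact Fleet code):

```go
package example

import "context"

// New shape: compute the flips once, then hand the in-memory results to every
// consumer.
func submitPolicyResults(ctx context.Context, ds PolicyDatastore, hostID uint, incoming map[uint]*bool) error {
	// Single round-trip for the full result set.
	newFailing, newPassing, err := ds.FlippingPoliciesForHost(ctx, hostID, incoming)
	if err != nil {
		return err
	}

	// One shared set; each consumer does a map lookup instead of a DB query:
	//   if _, flipped := newFailingSet[policyID]; flipped { /* queue the automation */ }
	newFailingSet := make(map[uint]struct{}, len(newFailing))
	for _, id := range newFailing {
		newFailingSet[id] = struct{}{}
	}
	_ = newFailingSet

	// newPassing is passed down to the membership-recording step so it can skip
	// its own FlippingPoliciesForHost query.
	_ = newPassing
	return nil
}
```

Because `newPassing` is threaded into `RecordPolicyQueryExecutions` as
`newlyPassingPolicyIDs`, that path no longer issues its own
`FlippingPoliciesForHost` query either.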

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

- [x] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.

## Testing

- [x] Added/updated automated tests
- [x] QA'd all new/changed functionality manually


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Performance Improvements**
  * Reduced redundant database queries during policy result submissions by
    computing flipping policies once per host check-in instead of multiple
    times.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-04-06 10:11:07 -05:00

277 lines
8.3 KiB
Go


package async

import (
	"context"
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/fleetdm/fleet/v4/server/config"
	"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
	"github.com/fleetdm/fleet/v4/server/datastore/redis"
	"github.com/fleetdm/fleet/v4/server/fleet"
	"github.com/fleetdm/fleet/v4/server/ptr"
	redigo "github.com/gomodule/redigo/redis"
	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/trace"
)

const (
	policyPassHostIDsKey  = "policy_pass:active_host_ids"
	policyPassHostKey     = "policy_pass:{%d}"
	policyPassReportedKey = "policy_pass_reported:{%d}"
	policyPassKeysMinTTL  = 7 * 24 * time.Hour // 1 week
)

// redis list will be LTRIM'd if there are more policy IDs than this.
var maxRedisPolicyResultsPerHost = 1000

// RecordPolicyQueryExecutions records the policy results reported by the host.
// When async policy membership is disabled, results are written directly to the
// datastore; otherwise they are queued in Redis and persisted later by the
// collector.
func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool, newlyPassingPolicyIDs []uint) error {
	cfg := t.taskConfigs[config.AsyncTaskPolicyMembership]
	if !cfg.Enabled {
		host.PolicyUpdatedAt = ts
		return t.datastore.RecordPolicyQueryExecutions(ctx, host, results, ts, deferred, newlyPassingPolicyIDs)
	}

	keyList := fmt.Sprintf(policyPassHostKey, host.ID)
	keyTs := fmt.Sprintf(policyPassReportedKey, host.ID)

	// set an expiration on both keys (list and ts), ensuring that a deleted host
	// (eventually) does not use any redis space. Ensure that TTL is reasonably
	// big to avoid deleting information that hasn't been collected yet - 1 week
	// or 10 * the collector interval, whichever is biggest.
	//
	// This means that it will only expire if that host hasn't reported policies
	// during that (TTL) time (each time it does report, the TTL is reset), and
	// the collector will have plenty of time to run (multiple times) to try to
	// persist all the data in mysql.
	ttl := policyPassKeysMinTTL
	if maxTTL := 10 * cfg.CollectInterval; maxTTL > ttl {
		ttl = maxTTL
	}

	// There are two versions of the script (1) and (2).
	// Script (1) is used when there are no policy results to report and
	// script (2) is used when there are policy results.

	// (1)
	// KEYS[1]: keyTs (policyPassReportedKey)
	// ARGV[1]: timestamp for "reported at"
	// ARGV[2]: ttl for the key
	scriptSrc := `
    redis.call('SET', KEYS[1], ARGV[1])
    return redis.call('EXPIRE', KEYS[1], ARGV[2])
  `
	keyCount := 1
	args := make(redigo.Args, 0, 3)
	args = args.Add(keyTs, ts.Unix(), int(ttl.Seconds()))

	if len(results) > 0 {
		// (2)
		// KEYS[1]: keyList (policyPassHostKey)
		// KEYS[2]: keyTs (policyPassReportedKey)
		// ARGV[1]: timestamp for "reported at"
		// ARGV[2]: max policy results to keep per host (list is trimmed to that size)
		// ARGV[3]: ttl for both keys
		// ARGV[4..]: policy_id=pass entries to LPUSH to the list
		keyCount = 2
		scriptSrc = `
    redis.call('LPUSH', KEYS[1], unpack(ARGV, 4))
    redis.call('LTRIM', KEYS[1], 0, ARGV[2])
    redis.call('EXPIRE', KEYS[1], ARGV[3])
    redis.call('SET', KEYS[2], ARGV[1])
    return redis.call('EXPIRE', KEYS[2], ARGV[3])
  `

		// convert results to LPUSH arguments, store as policy_id=1 for pass,
		// policy_id=-1 for fail, policy_id=0 for null result.
		args = make(redigo.Args, 0, 5+len(results))
		args = args.Add(keyList, keyTs, ts.Unix(), maxRedisPolicyResultsPerHost, int(ttl.Seconds()))
		for k, v := range results {
			pass := 0
			if v != nil {
				if *v {
					pass = 1
				} else {
					pass = -1
				}
			}
			args = args.Add(fmt.Sprintf("%d=%d", k, pass))
		}
	}

	script := redigo.NewScript(keyCount, scriptSrc)
	conn := t.pool.Get()
	defer conn.Close()
	if err := redis.BindConn(t.pool, conn, keyList, keyTs); err != nil {
		return ctxerr.Wrap(ctx, err, "bind redis connection")
	}

	if _, err := script.Do(conn, args...); err != nil {
		return ctxerr.Wrap(ctx, err, "run redis script")
	}

	// Storing the host id in the set of active host IDs for policy membership
	// outside of the redis script because in Redis Cluster mode the key may not
	// live on the same node as the host's keys. At the same time, purge any
	// entry in the set that is older than now - TTL.
	if _, err := storePurgeActiveHostID(t.pool, policyPassHostIDsKey, host.ID, ts, ts.Add(-ttl)); err != nil {
		return ctxerr.Wrap(ctx, err, "store active host id")
	}
	return nil
}

// collectPolicyQueryExecutions drains the per-host Redis lists of reported
// policy results and persists them to the datastore in batches. It is invoked
// periodically by the async collector.
func (t *Task) collectPolicyQueryExecutions(ctx context.Context, ds fleet.Datastore, pool fleet.RedisPool, stats *collectorExecStats) error {
	// Create a root span for this async collection task if OTEL is enabled
	if t.otelEnabled {
		tracer := otel.Tracer("async")
		var span trace.Span
		ctx, span = tracer.Start(ctx, "async.collect_policy_query_executions",
			trace.WithAttributes(
				attribute.String("async.task", "policy_membership"),
			),
		)
		defer span.End()
	}

	cfg := t.taskConfigs[config.AsyncTaskPolicyMembership]

	hosts, err := loadActiveHostIDs(pool, policyPassHostIDsKey, cfg.RedisScanKeysCount)
	if err != nil {
		return ctxerr.Wrap(ctx, err, "load active host ids")
	}
	stats.Keys = len(hosts)

	// need to use a script as the RPOP command only supports a COUNT argument
	// since Redis 6.2. Because we use LTRIM when inserting, we know the total
	// number of results is at most maxRedisPolicyResultsPerHost, so it is capped
	// and can be returned in one go.
	script := redigo.NewScript(1, `
    local res = redis.call('LRANGE', KEYS[1], 0, -1)
    redis.call('DEL', KEYS[1])
    return res
  `)

	getKeyTuples := func(hostID uint) (inserts []fleet.PolicyMembershipResult, err error) {
		keyList := fmt.Sprintf(policyPassHostKey, hostID)
		conn := redis.ConfigureDoer(pool, pool.Get())
		defer conn.Close()

		stats.RedisCmds++
		res, err := redigo.Strings(script.Do(conn, keyList))
		if err != nil {
			return nil, ctxerr.Wrap(ctx, err, "redis LRANGE script")
		}

		inserts = make([]fleet.PolicyMembershipResult, 0, len(res))
		stats.Items += len(res)
		for _, item := range res {
			parts := strings.Split(item, "=")
			if len(parts) != 2 {
				continue
			}

			var tup fleet.PolicyMembershipResult
			if id, _ := strconv.ParseUint(parts[0], 10, 32); id > 0 {
				tup.HostID = hostID
				tup.PolicyID = uint(id)

				switch parts[1] {
				case "1":
					tup.Passes = ptr.Bool(true)
				case "-1":
					tup.Passes = ptr.Bool(false)
				case "0":
					tup.Passes = nil
				default:
					continue
				}
				inserts = append(inserts, tup)
			}
		}
		return inserts, nil
	}

	runInsertBatch := func(batch []fleet.PolicyMembershipResult) error {
		stats.Inserts++
		return ds.AsyncBatchInsertPolicyMembership(ctx, batch)
	}

	runUpdateBatch := func(ids []uint, ts time.Time) error {
		stats.Updates++
		return ds.AsyncBatchUpdatePolicyTimestamp(ctx, ids, ts)
	}

	insertBatch := make([]fleet.PolicyMembershipResult, 0, cfg.InsertBatch)
	for _, host := range hosts {
		hid := host.HostID
		ins, err := getKeyTuples(hid)
		if err != nil {
			return err
		}
		insertBatch = append(insertBatch, ins...)

		if len(insertBatch) >= cfg.InsertBatch {
			if err := runInsertBatch(insertBatch); err != nil {
				return err
			}
			insertBatch = insertBatch[:0]
		}
	}

	// process any remaining batch that did not reach the batchSize limit in the
	// loop.
	if len(insertBatch) > 0 {
		if err := runInsertBatch(insertBatch); err != nil {
			return err
		}
	}

	if len(hosts) > 0 {
		hostIDs := make([]uint, len(hosts))
		for i, host := range hosts {
			hostIDs[i] = host.HostID
		}

		ts := t.clock.Now()
		updateBatch := make([]uint, cfg.UpdateBatch)
		for {
			n := copy(updateBatch, hostIDs)
			if n == 0 {
				break
			}
			if err := runUpdateBatch(updateBatch[:n], ts); err != nil {
				return err
			}
			hostIDs = hostIDs[n:]
		}

		// batch-remove any host ID from the active set that still has its score set
		// to the initial value, so that the active set does not keep all (potentially
		// 100K+) host IDs to process at all times - only those with reported results
		// to process.
		if _, err := removeProcessedHostIDs(pool, policyPassHostIDsKey, hosts); err != nil {
			return ctxerr.Wrap(ctx, err, "remove processed host ids")
		}
	}

	return nil
}

// GetHostPolicyReportedAt returns the latest time the host reported policy
// results: the Redis-stored timestamp when async policy membership is enabled
// and it is more recent, otherwise the host's PolicyUpdatedAt.
func (t *Task) GetHostPolicyReportedAt(ctx context.Context, host *fleet.Host) time.Time {
	cfg := t.taskConfigs[config.AsyncTaskPolicyMembership]

	if cfg.Enabled {
		conn := redis.ConfigureDoer(t.pool, t.pool.Get())
		defer conn.Close()

		key := fmt.Sprintf(policyPassReportedKey, host.ID)
		epoch, err := redigo.Int64(conn.Do("GET", key))
		if err == nil {
			if reported := time.Unix(epoch, 0); reported.After(host.PolicyUpdatedAt) {
				return reported
			}
		}
	}
	return host.PolicyUpdatedAt
}