Removed duplicate FlippingPoliciesForHost DB calls (#42845)

<!-- Add the related story/sub-task/bug number, like Resolves #123, or
remove if NA -->
**Related issue:** Resolves #42836 

This is another hot path optimization.

## Before

When a host submits policy results via `SubmitDistributedQueryResults`,
the system needed to determine which policies "flipped" (changed from
passing to failing or vice versa). Each consumer computed this
independently:

```
SubmitDistributedQueryResults(policyResults)
  |
  +-- processScriptsForNewlyFailingPolicies
  |     filter to failing policies with scripts
  |     BUILD SUBSET of results
  |     CALL FlippingPoliciesForHost(subset)          <-- DB query #1
  |     convert result to set, filter, queue scripts
  |
  +-- processSoftwareForNewlyFailingPolicies
  |     filter to failing policies with installers
  |     BUILD SUBSET of results
  |     CALL FlippingPoliciesForHost(subset)          <-- DB query #2
  |     convert result to set, filter, queue installs
  |
  +-- processVPPForNewlyFailingPolicies
  |     filter to failing policies with VPP apps
  |     BUILD SUBSET of results
  |     CALL FlippingPoliciesForHost(subset)          <-- DB query #3
  |     convert result to set, filter, queue VPP
  |
  +-- webhook filtering
  |     filter to webhook-enabled policies
  |     CALL FlippingPoliciesForHost(subset)          <-- DB query #4
  |     register flipped policies in Redis
  |
  +-- RecordPolicyQueryExecutions
        CALL FlippingPoliciesForHost(all results)     <-- DB query #5
        reset attempt counters for newly passing
        INSERT/UPDATE policy_membership
```

Each `FlippingPoliciesForHost` call runs `SELECT policy_id, passes FROM
policy_membership WHERE host_id = ? AND policy_id IN (?)`. All 5 queries
hit the same table for the same host before `policy_membership` is
updated, so they all see identical state.

Each consumer also built intermediate maps to narrow down to its subset
before calling `FlippingPoliciesForHost`, then converted the result into
yet another set for filtering. This meant 3-4 temporary maps per
consumer.

## After

```
SubmitDistributedQueryResults(policyResults)
  |
  CALL FlippingPoliciesForHost(all results)           <-- single DB query
  build newFailingSet, normalize newPassing
  |
  +-- processScriptsForNewlyFailingPolicies
  |     filter to failing policies with scripts
  |     CHECK newFailingSet (in-memory map lookup)
  |     queue scripts
  |
  +-- processSoftwareForNewlyFailingPolicies
  |     filter to failing policies with installers
  |     CHECK newFailingSet (in-memory map lookup)
  |     queue installs
  |
  +-- processVPPForNewlyFailingPolicies
  |     filter to failing policies with VPP apps
  |     CHECK newFailingSet (in-memory map lookup)
  |     queue VPP
  |
  +-- webhook filtering
  |     filter to webhook-enabled policies
  |     FILTER newFailing/newPassing by policy IDs (in-memory)
  |     register flipped policies in Redis
  |
  +-- RecordPolicyQueryExecutions
        USE pre-computed newPassing (skip DB query)
        reset attempt counters for newly passing
        INSERT/UPDATE policy_membership
```

The intermediate subset maps and per-consumer set conversions are
removed. Each process function goes directly from "policies with
associated automation" to "is this policy in newFailingSet?" in a single
map lookup.

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

- [x] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.

## Testing

- [x] Added/updated automated tests
- [x] QA'd all new/changed functionality manually


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Performance Improvements**
* Reduced redundant database queries during policy result submissions by
computing flipping policies once per host check-in instead of multiple
times.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
Victor Lyuboslavsky 2026-04-06 10:11:07 -05:00 committed by GitHub
parent a7e4557066
commit 8af94af14b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 286 additions and 276 deletions

View file

@ -0,0 +1 @@
- Reduced redundant database queries during policy result submission by computing flipping policies once per host check-in instead of multiple times.

View file

@ -306,7 +306,7 @@ func testConditionalAccessBypassDeviceWithBlockingPolicy(t *testing.T, ds *Datas
require.NoError(t, err)
// Record a failing result for this policy on the host
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: ptr.Bool(false)}, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: new(false)}, time.Now(), false, nil)
require.NoError(t, err)
// Bypass should fail because the host has a failing CA-enabled policy
@ -365,7 +365,7 @@ func testConditionalAccessBypassAllowedWithNonCAFailingCriticalPolicy(t *testing
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{
caPolicy.ID: ptr.Bool(true), // passing
nonCAPolicy.ID: ptr.Bool(false), // failing
}, time.Now(), false)
}, time.Now(), false, nil)
require.NoError(t, err)
// Bypass must succeed: the only failing policy is not CA-enabled
@ -415,7 +415,7 @@ func testConditionalAccessBypassAllowedWithCAEnabledNonCriticalPolicy(t *testing
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{
nonCriticalCAPolicy.ID: ptr.Bool(false), // failing
}, time.Now(), false)
}, time.Now(), false, nil)
require.NoError(t, err)
// Bypass must succeed: policy is CA-enabled but not critical

View file

@ -4101,8 +4101,8 @@ func testHostsListByPolicy(t *testing.T, ds *Datastore) {
require.Len(t, hosts, 0)
// Make one host pass the policy and another not pass
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{1: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{1: ptr.Bool(false)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{1: new(true)}, time.Now(), false, nil))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{1: new(false)}, time.Now(), false, nil))
hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{PolicyIDFilter: &p.ID, PolicyResponseFilter: ptr.Bool(true)}, 1)
require.Len(t, hosts, 1)
@ -5105,18 +5105,18 @@ func testHostsListFailingPolicies(t *testing.T, ds *Datastore) {
assert.Zero(t, *h2.HostIssues.CriticalVulnerabilitiesCount)
assert.Zero(t, h2.HostIssues.TotalIssuesCount)
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: new(true)}, time.Now(), false, nil))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(false), p2.ID: ptr.Bool(false)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(false), p2.ID: new(false)}, time.Now(), false, nil))
checkHostIssues(t, ds, hosts, filter, h2.ID, 2)
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(true), p2.ID: ptr.Bool(false)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(true), p2.ID: new(false)}, time.Now(), false, nil))
checkHostIssues(t, ds, hosts, filter, h2.ID, 1)
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(true), p2.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(true), p2.ID: new(true)}, time.Now(), false, nil))
checkHostIssues(t, ds, hosts, filter, h2.ID, 0)
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: new(false)}, time.Now(), false, nil))
checkHostIssues(t, ds, hosts, filter, h1.ID, 1)
checkHostIssuesWithOpts(t, ds, filter, h1.ID, fleet.HostListOptions{DisableIssues: true}, 0)
@ -5183,8 +5183,8 @@ func testHostsReadsLessRows(t *testing.T, ds *Datastore) {
})
require.NoError(t, err)
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: new(true)}, time.Now(), false, nil))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(false)}, time.Now(), false, nil))
prevRead := getReads(t, ds)
h1WithExtras, err := ds.Host(context.Background(), h1.ID)
@ -8969,7 +8969,7 @@ func testHostsDeleteHosts(t *testing.T, ds *Datastore) {
_, err = ds.writer(context.Background()).Exec(`INSERT INTO query_results (host_id, query_id, last_fetched, data) VALUES (?, ?, ?, ?)`, host.ID, policy.ID, time.Now(), `{"foo": "bar"}`)
require.NoError(t, err)
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{policy.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{policy.ID: new(true)}, time.Now(), false, nil))
// Update host_mdm.
err = ds.SetOrUpdateMDMData(context.Background(), host.ID, false, true, "foo.mdm.example.com", false, "", "", false)
require.NoError(t, err)
@ -9727,7 +9727,7 @@ func testFailingPoliciesCount(t *testing.T, ds *Datastore) {
for _, tc := range testCases {
if len(tc.policyEx) != 0 {
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, tc.host, tc.policyEx, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, tc.host, tc.policyEx, time.Now(), false, nil))
}
actual, err := ds.FailingPoliciesCount(ctx, tc.host)
require.NoError(t, err)
@ -9769,7 +9769,7 @@ func testHostsRecordNoPolicies(t *testing.T, ds *Datastore) {
assert.Zero(t, h2.HostIssues.TotalIssuesCount)
policyUpdatedAt := initialTime.Add(1 * time.Hour)
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, nil, policyUpdatedAt, false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, nil, policyUpdatedAt, false, nil))
hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{}, 2)
require.Len(t, hosts, 2)
@ -9916,7 +9916,7 @@ func testHostOrder(t *testing.T, ds *Datastore) {
}
require.NoError(
t, ds.RecordPolicyQueryExecutions(
context.Background(), createdHosts[i], results, time.Now(), false,
context.Background(), createdHosts[i], results, time.Now(), false, nil,
),
)
}
@ -11791,8 +11791,8 @@ func testHostHealth(t *testing.T, ds *Datastore) {
failingPolicy, err := ds.NewGlobalPolicy(context.Background(), &u.ID, fleet.PolicyPayload{QueryID: &q.ID})
require.NoError(t, err)
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h, map[uint]*bool{passingPolicy.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h, map[uint]*bool{failingPolicy.ID: ptr.Bool(false)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h, map[uint]*bool{passingPolicy.ID: new(true)}, time.Now(), false, nil))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h, map[uint]*bool{failingPolicy.ID: new(false)}, time.Now(), false, nil))
// set up vulnerable software
software := []fleet.Software{
@ -12269,7 +12269,7 @@ func testUpdateHostIssues(t *testing.T, ds *Datastore) {
require.NoError(
// RecordPolicyQueryExecutions should call UpdateHostIssuesFailingPolicies, so we don't have to
t, ds.RecordPolicyQueryExecutions(
context.Background(), hosts[i], results, time.Now(), false,
context.Background(), hosts[i], results, time.Now(), false, nil,
),
)
}

View file

@ -1604,9 +1604,9 @@ func testListHostsInLabelIssues(t *testing.T, ds *Datastore) {
assert.Zero(t, *h2.HostIssues.CriticalVulnerabilitiesCount)
assert.Zero(t, h2.HostIssues.TotalIssuesCount)
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: new(true)}, time.Now(), false, nil))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(false), p2.ID: ptr.Bool(false)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(false), p2.ID: new(false)}, time.Now(), false, nil))
checkLabelHostIssues(t, ds, l1.ID, filter, h2.ID, fleet.HostListOptions{}, 2, 0)
// Add a critical vulnerability
@ -1676,13 +1676,13 @@ func testListHostsInLabelIssues(t *testing.T, ds *Datastore) {
assert.NoError(t, ds.UpdateHostIssuesVulnerabilities(ctx))
checkLabelHostIssues(t, ds, l1.ID, filter, hosts[6].ID, fleet.HostListOptions{}, 0, 4)
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(true), p2.ID: ptr.Bool(false)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(true), p2.ID: new(false)}, time.Now(), false, nil))
checkLabelHostIssues(t, ds, l1.ID, filter, h2.ID, fleet.HostListOptions{}, 1, 1)
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(true), p2.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(true), p2.ID: new(true)}, time.Now(), false, nil))
checkLabelHostIssues(t, ds, l1.ID, filter, h2.ID, fleet.HostListOptions{}, 0, 1)
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: new(false)}, time.Now(), false, nil))
checkLabelHostIssues(t, ds, l1.ID, filter, h1.ID, fleet.HostListOptions{}, 1, 1)
checkLabelHostIssues(t, ds, l1.ID, filter, h1.ID, fleet.HostListOptions{DisableIssues: true}, 0, 0)

View file

@ -608,12 +608,20 @@ func filterNotExecuted(results map[uint]*bool) map[uint]bool {
return filtered
}
func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool) error {
func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool, newlyPassingPolicyIDs []uint) error {
// Identify policies that flipped failing -> passing for this host using current incoming results.
// We compute this before updating policy_membership so we compare against the previous state.
_, newPassing, err := ds.FlippingPoliciesForHost(ctx, host.ID, results)
if err != nil {
return err
// When newlyPassingPolicyIDs is non-nil, the caller has already computed flipping policies
// (e.g. SubmitDistributedQueryResults computes it once for all consumers) so we reuse that result.
var newPassing []uint
if newlyPassingPolicyIDs != nil {
newPassing = newlyPassingPolicyIDs
} else {
var err error
_, newPassing, err = ds.FlippingPoliciesForHost(ctx, host.ID, results)
if err != nil {
return err
}
}
if len(newPassing) > 0 {
slices.Sort(newPassing)
@ -645,7 +653,7 @@ func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *flee
// semantically equivalent, even though here it processes a single host and
// in async mode it processes a batch of hosts).
err = ds.withTx(ctx, func(tx sqlx.ExtContext) error {
err := ds.withTx(ctx, func(tx sqlx.ExtContext) error {
if len(results) > 0 {
query := fmt.Sprintf(
`INSERT INTO policy_membership (updated_at, policy_id, host_id, passes)

View file

@ -460,14 +460,14 @@ func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) {
require.NotNil(t, p2.AuthorID)
assert.Equal(t, user1.ID, *p2.AuthorID)
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: new(true)}, time.Now(), deferred, nil))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: new(true)}, time.Now(), deferred, nil))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: nil}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: nil}, time.Now(), deferred, nil))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: new(false)}, time.Now(), deferred, nil))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: new(true)}, time.Now(), deferred, nil))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p2.ID: nil}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p2.ID: nil}, time.Now(), deferred, nil))
require.NoError(t, ds.UpdateHostPolicyCounts(ctx))
@ -483,8 +483,8 @@ func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) {
assert.Equal(t, uint(0), policies[1].PassingHostCount)
assert.Equal(t, uint(0), policies[1].FailingHostCount)
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p2.ID: ptr.Bool(false)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: new(false)}, time.Now(), deferred, nil))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p2.ID: new(false)}, time.Now(), deferred, nil))
require.NoError(t, ds.UpdateHostPolicyCounts(ctx))
@ -500,6 +500,30 @@ func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) {
assert.Equal(t, uint(0), policies[1].PassingHostCount)
assert.Equal(t, uint(1), policies[1].FailingHostCount)
// Test with pre-computed newlyPassingPolicyIDs (non-nil) to exercise the path where
// RecordPolicyQueryExecutions skips calling FlippingPoliciesForHost internally.
// host1 currently has p.ID=failing, so flipping to passing with pre-computed IDs.
require.NoError(t, ds.RecordPolicyQueryExecutions(
ctx, host1, map[uint]*bool{p.ID: new(true)}, time.Now(), deferred, []uint{p.ID},
))
// Also test with an empty (but non-nil) slice, which means "already computed, no newly passing".
require.NoError(t, ds.RecordPolicyQueryExecutions(
ctx, host2, map[uint]*bool{p2.ID: new(true)}, time.Now(), deferred, []uint{},
))
require.NoError(t, ds.UpdateHostPolicyCounts(ctx))
policies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
require.NoError(t, err)
require.Len(t, policies, 2)
// host1: p now passing again (was failing), host2: p still passing
assert.Equal(t, p.ID, policies[0].ID)
assert.Equal(t, uint(2), policies[0].PassingHostCount)
assert.Equal(t, uint(0), policies[0].FailingHostCount)
// host2: p2 now passing (was failing)
assert.Equal(t, p2.ID, policies[1].ID)
assert.Equal(t, uint(1), policies[1].PassingHostCount)
assert.Equal(t, uint(0), policies[1].FailingHostCount)
policy, err := ds.Policy(ctx, policies[0].ID)
require.NoError(t, err)
assert.Equal(t, policies[0], policy)
@ -553,9 +577,9 @@ func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) {
require.NoError(t, err)
// create some policy results
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host3, map[uint]*bool{t1pol.ID: ptr.Bool(true), p.ID: ptr.Bool(true), p2.ID: ptr.Bool(false)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host4, map[uint]*bool{t2pol.ID: ptr.Bool(false), t2pol2.ID: ptr.Bool(true), p.ID: ptr.Bool(false)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host5, map[uint]*bool{t2pol.ID: ptr.Bool(true), t2pol2.ID: ptr.Bool(true), p2.ID: ptr.Bool(true)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host3, map[uint]*bool{t1pol.ID: new(true), p.ID: new(true), p2.ID: new(false)}, time.Now(), deferred, nil))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host4, map[uint]*bool{t2pol.ID: new(false), t2pol2.ID: new(true), p.ID: new(false)}, time.Now(), deferred, nil))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host5, map[uint]*bool{t2pol.ID: new(true), t2pol2.ID: new(true), p2.ID: new(true)}, time.Now(), deferred, nil))
require.NoError(t, ds.UpdateHostPolicyCounts(ctx))
@ -1077,7 +1101,7 @@ func testListMergedTeamPolicies(t *testing.T, ds *Datastore) {
&fleet.Host{OsqueryHostID: ptr.String("host1"), NodeKey: ptr.String(fmt.Sprint("host1", 1)), TeamID: nil})
require.NoError(t, err)
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{gpol.ID: ptr.Bool(true)}, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{gpol.ID: new(true)}, time.Now(), false, nil)
require.NoError(t, err)
err = ds.UpdateHostPolicyCounts(context.Background())
@ -1100,7 +1124,7 @@ func testListMergedTeamPolicies(t *testing.T, ds *Datastore) {
err = ds.AddHostsToTeam(ctx, fleet.NewAddHostsToTeamParams(&team1.ID, []uint{host.ID}))
require.NoError(t, err)
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{team1policy.ID: ptr.Bool(true)}, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{team1policy.ID: new(true)}, time.Now(), false, nil)
require.NoError(t, err)
err = ds.UpdateHostPolicyCounts(context.Background())
@ -1447,7 +1471,7 @@ func testPolicyQueriesForHost(t *testing.T, ds *Datastore) {
assert.Equal(t, q.Query, queries[fmt.Sprint(q.ID)])
// Team policy ran with failing result.
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{tp.ID: ptr.Bool(false), gp.ID: nil}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{tp.ID: new(false), gp.ID: nil}, time.Now(), false, nil))
policies, err := ds.ListPoliciesForHost(context.Background(), host1)
require.NoError(t, err)
@ -1490,7 +1514,7 @@ func testPolicyQueriesForHost(t *testing.T, ds *Datastore) {
assert.Equal(t, "", policies[0].Response)
// Global policy ran with passing result.
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{gp.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{gp.ID: new(true)}, time.Now(), false, nil))
policies, err = ds.ListPoliciesForHost(context.Background(), host2)
require.NoError(t, err)
@ -1511,7 +1535,7 @@ func testPolicyQueriesForHost(t *testing.T, ds *Datastore) {
require.NoError(t, err)
require.NoError(t,
ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{uint(id): nil}, //nolint:gosec // dismiss G115
time.Now(), false))
time.Now(), false, nil))
policies, err = ds.ListPoliciesForHost(context.Background(), host2)
require.NoError(t, err)
@ -1556,7 +1580,7 @@ func testPoliciesByID(t *testing.T, ds *Datastore) {
err = ds.SavePolicy(context.Background(), policy2, false, false)
require.NoError(t, err)
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{policy1.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{policy1.ID: new(true)}, time.Now(), false, nil))
require.NoError(t, ds.UpdateHostPolicyCounts(context.Background()))
policiesByID, err := ds.PoliciesByID(context.Background(), []uint{1, 2})
@ -1634,10 +1658,10 @@ func testTeamPolicyTransfer(t *testing.T, ds *Datastore) {
})
require.NoError(t, err)
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{team1Policy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{team1Policy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{team1Policy.ID: new(false), globalPolicy.ID: new(true)}, time.Now(), false, nil))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{team1Policy.ID: new(true), globalPolicy.ID: new(true)}, time.Now(), false, nil))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: new(false), globalPolicy.ID: new(true)}, time.Now(), false, nil))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: new(true), globalPolicy.ID: new(true)}, time.Now(), false, nil))
require.NoError(t, ds.UpdateHostPolicyCounts(ctx))
@ -1691,7 +1715,7 @@ func testTeamPolicyTransfer(t *testing.T, ds *Datastore) {
checkPassingCount(0, 0, 1, 1)
// team policies are removed if the host is re-enrolled without a team
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: new(true), globalPolicy.ID: new(true)}, time.Now(), false, nil))
checkPassingCount(1, 0, 2, 2)
// all host policies are removed when a host is re-enrolled
@ -2151,7 +2175,7 @@ func testApplyPolicySpecWithQueryPlatformChanges(t *testing.T, ds *Datastore) {
for _, pol := range polsByName {
res[pol.ID] = ptr.Bool(false)
}
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false, nil)
require.NoError(t, err)
}
for _, h := range globalHosts {
@ -2159,7 +2183,7 @@ func testApplyPolicySpecWithQueryPlatformChanges(t *testing.T, ds *Datastore) {
for _, pol := range globalPolsByName {
res[pol.ID] = ptr.Bool(false)
}
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false, nil)
require.NoError(t, err)
}
err = ds.UpdateHostPolicyCounts(ctx)
@ -2529,17 +2553,17 @@ func testCachedPolicyCountDeletesOnPolicyChange(t *testing.T, ds *Datastore) {
// teamHost and globalHost fail all policies
require.NoError(
t, ds.RecordPolicyQueryExecutions(
ctx, teamHost, map[uint]*bool{globalPolicy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(false)}, time.Now(), false,
ctx, teamHost, map[uint]*bool{globalPolicy.ID: new(false), globalPolicy.ID: new(false)}, time.Now(), false, nil,
),
)
require.NoError(
t, ds.RecordPolicyQueryExecutions(
ctx, teamHost, map[uint]*bool{teamPolicy.ID: ptr.Bool(false), teamPolicy.ID: ptr.Bool(false)}, time.Now(), false,
ctx, teamHost, map[uint]*bool{teamPolicy.ID: new(false), teamPolicy.ID: new(false)}, time.Now(), false, nil,
),
)
require.NoError(
t, ds.RecordPolicyQueryExecutions(
ctx, globalHost, map[uint]*bool{globalPolicy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(false)}, time.Now(), false,
ctx, globalHost, map[uint]*bool{globalPolicy.ID: new(false), globalPolicy.ID: new(false)}, time.Now(), false, nil,
),
)
@ -2723,7 +2747,7 @@ func testFlippingPoliciesForHost(t *testing.T, ds *Datastore) {
require.Empty(t, newPassing) // because this would be the first run.
// Record the above executions.
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false, nil)
require.NoError(t, err)
// incoming policy 1 with passing result: no => yes
@ -2738,7 +2762,7 @@ func testFlippingPoliciesForHost(t *testing.T, ds *Datastore) {
require.Equal(t, []uint{p1.ID}, newPassing)
// Record the above executions.
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false, nil)
require.NoError(t, err)
// incoming policy 1 with passing result: yes => yes
@ -2753,7 +2777,7 @@ func testFlippingPoliciesForHost(t *testing.T, ds *Datastore) {
require.Empty(t, newPassing)
// Record the above executions.
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false, nil)
require.NoError(t, err)
// incoming policy 1 failed to execute: yes => no
@ -2777,7 +2801,7 @@ func testFlippingPoliciesForHost(t *testing.T, ds *Datastore) {
require.Empty(t, newPassing)
// Record the above executions.
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false, nil)
require.NoError(t, err)
// incoming pfailed again failed to execute: --- -> ---
@ -2798,7 +2822,7 @@ func testFlippingPoliciesForHost(t *testing.T, ds *Datastore) {
require.Empty(t, newPassing)
// Record the above executions.
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false, nil)
require.NoError(t, err)
// incoming policy 4 with first new failing result: --- => no
@ -2932,7 +2956,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) {
// also record a result for linux policy
res[polsByName["t2"].ID] = ptr.Bool(true)
}
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false, nil)
require.NoError(t, err)
}
for i, h := range globalHosts {
@ -2943,7 +2967,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) {
// also record a result for linux policy
res[polsByName["g2"].ID] = ptr.Bool(true)
}
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false, nil)
require.NoError(t, err)
}
@ -3066,9 +3090,9 @@ func testPolicyViolationDays(t *testing.T, ds *Datastore) {
require.NoError(t, ds.InitializePolicyViolationDays(ctx)) // sets starting violation count to zero
// initialize policy statuses: 1 failling, 2 passing
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), hosts[0], map[uint]*bool{pol.ID: ptr.Bool(false)}, then, false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), hosts[1], map[uint]*bool{pol.ID: ptr.Bool(true)}, then, false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), hosts[2], map[uint]*bool{pol.ID: ptr.Bool(true)}, then, false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), hosts[0], map[uint]*bool{pol.ID: new(false)}, then, false, nil))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), hosts[1], map[uint]*bool{pol.ID: new(true)}, then, false, nil))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), hosts[2], map[uint]*bool{pol.ID: new(true)}, then, false, nil))
// setup db for test: starting counts zero, more than 24h since last updated, one outstanding violation
require.NoError(t, setStatsTimestampDB(time.Now().Add(-25*time.Hour)))
@ -3094,7 +3118,7 @@ func testPolicyViolationDays(t *testing.T, ds *Datastore) {
// leave counts at zero for next test
// setup for test: starting count zero, more than 24h since last updated, add second outstanding violation
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: new(false)}, time.Now(), false, nil))
require.NoError(t, setStatsTimestampDB(time.Now().Add(-25*time.Hour)))
require.NoError(t, ds.IncrementPolicyViolationDays(ctx))
actual, possible, err = amountPolicyViolationDaysDB(ctx, ds.reader(ctx))
@ -3106,7 +3130,7 @@ func testPolicyViolationDays(t *testing.T, ds *Datastore) {
// leave counts at 2 actual and 3 possible for next test
// setup for test: starting counts at 2 actual and 3 possible, more than 24h since last updated, resolve one outstaning violation
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: new(true)}, time.Now(), false, nil))
require.NoError(t, setStatsTimestampDB(time.Now().Add(-25*time.Hour)))
require.NoError(t, ds.IncrementPolicyViolationDays(ctx))
actual, possible, err = amountPolicyViolationDaysDB(ctx, ds.reader(ctx))
@ -3196,7 +3220,7 @@ func testPolicyCleanupPolicyMembership(t *testing.T, ds *Datastore) {
polsByName["p2"].ID: ptr.Bool(true),
polsByName["p3"].ID: ptr.Bool(true),
}
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false, nil)
require.NoError(t, err)
}
require.NoError(t, ds.writer(ctx).Get(&count, "select COUNT(*) from host_issues WHERE total_issues_count > 0"))
@ -3315,7 +3339,7 @@ func testDeleteAllPolicyMemberships(t *testing.T, ds *Datastore) {
host,
map[uint]*bool{globalPolicy.ID: ptr.Bool(false)},
time.Now(),
false,
false, nil,
)
require.NoError(t, err)
@ -3378,9 +3402,9 @@ func testOutdatedAutomationBatch(t *testing.T, ds *Datastore) {
pol2, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: "policy2"})
require.NoError(t, err)
err = ds.RecordPolicyQueryExecutions(ctx, h1, map[uint]*bool{pol1.ID: ptr.Bool(false), pol2.ID: ptr.Bool(true)}, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, h1, map[uint]*bool{pol1.ID: new(false), pol2.ID: new(true)}, time.Now(), false, nil)
require.NoError(t, err)
err = ds.RecordPolicyQueryExecutions(ctx, h2, map[uint]*bool{pol1.ID: ptr.Bool(false), pol2.ID: ptr.Bool(false)}, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, h2, map[uint]*bool{pol1.ID: new(false), pol2.ID: new(false)}, time.Now(), false, nil)
require.NoError(t, err)
batch, err := ds.OutdatedAutomationBatch(ctx)
@ -3615,7 +3639,7 @@ func testUpdatePolicyHostCounts(t *testing.T, ds *Datastore) {
res := map[uint]*bool{
policy.ID: ptr.Bool(true),
}
err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false, nil)
require.NoError(t, err)
}
@ -3671,7 +3695,7 @@ func testUpdatePolicyHostCounts(t *testing.T, ds *Datastore) {
res := map[uint]*bool{
policy.ID: result,
}
err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false, nil)
require.NoError(t, err)
}
@ -3720,7 +3744,7 @@ func testUpdatePolicyHostCounts(t *testing.T, ds *Datastore) {
policy.ID: ptr.Bool(false),
policy2.ID: ptr.Bool(true),
}
err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false, nil)
require.NoError(t, err)
}
for _, h := range teamHosts {
@ -3728,7 +3752,7 @@ func testUpdatePolicyHostCounts(t *testing.T, ds *Datastore) {
policy.ID: ptr.Bool(false),
policy2.ID: ptr.Bool(true),
}
err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false, nil)
require.NoError(t, err)
}
@ -4101,25 +4125,25 @@ func testGetTeamHostsPolicyMemberships(t *testing.T, ds *Datastore) {
err = ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{
team1Policy1.ID: ptr.Bool(true),
team1Policy2.ID: ptr.Bool(false),
}, time.Now(), false)
}, time.Now(), false, nil)
require.NoError(t, err)
err = ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{
team2Policy1.ID: ptr.Bool(false),
team2Policy2.ID: ptr.Bool(true),
}, time.Now(), false)
}, time.Now(), false, nil)
require.NoError(t, err)
err = ds.RecordPolicyQueryExecutions(ctx, host3, map[uint]*bool{
team2Policy1.ID: ptr.Bool(true),
team2Policy2.ID: ptr.Bool(true),
}, time.Now(), false)
}, time.Now(), false, nil)
require.NoError(t, err)
err = ds.RecordPolicyQueryExecutions(ctx, host5, map[uint]*bool{
team1Policy1.ID: ptr.Bool(false),
team1Policy2.ID: ptr.Bool(false),
}, time.Now(), false)
}, time.Now(), false, nil)
require.NoError(t, err)
team1Policies, err := ds.GetCalendarPolicies(ctx, team1.ID)
@ -4187,7 +4211,7 @@ func testGetTeamHostsPolicyMemberships(t *testing.T, ds *Datastore) {
err = ds.RecordPolicyQueryExecutions(ctx, host4, map[uint]*bool{
team1Policy1.ID: ptr.Bool(false),
team1Policy2.ID: ptr.Bool(false),
}, time.Now(), false)
}, time.Now(), false, nil)
require.NoError(t, err)
hostsTeam1, err = ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team1.ID, []uint{team1Policies[0].ID}, nil)
@ -4271,7 +4295,7 @@ func testGetTeamHostsPolicyMemberships(t *testing.T, ds *Datastore) {
ctx, host2, map[uint]*bool{
team2Policy1.ID: nil,
team2Policy2.ID: nil,
}, time.Now(), false,
}, time.Now(), false, nil,
)
require.NoError(t, err)
@ -4301,7 +4325,7 @@ func testGetTeamHostsPolicyMemberships(t *testing.T, ds *Datastore) {
err = ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{
team2Policy1.ID: ptr.Bool(true),
team2Policy2.ID: ptr.Bool(true),
}, time.Now(), false)
}, time.Now(), false, nil)
require.NoError(t, err)
hostsTeam2, err = ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team2.ID, []uint{team2Policies[0].ID, team2Policies[1].ID}, nil)
@ -4396,7 +4420,7 @@ func testGetTeamHostsPolicyMembershipsEmailPriority(t *testing.T, ds *Datastore)
})
require.NoError(t, err)
// Make the host fail the calendar policy so it always appears in results.
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{calendarPolicy.ID: ptr.Bool(false)}, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{calendarPolicy.ID: new(false)}, time.Now(), false, nil)
require.NoError(t, err)
return h
}
@ -5158,7 +5182,7 @@ func testApplyPolicySpecWithInstallers(t *testing.T, ds *Datastore) {
err = ds.RecordPolicyQueryExecutions(ctx, host1Team1, map[uint]*bool{
policy1Team1.ID: ptr.Bool(false),
vppPolicy1Team1.ID: ptr.Bool(false),
}, time.Now(), false)
}, time.Now(), false, nil)
require.NoError(t, err)
err = ds.UpdateHostPolicyCounts(ctx)
require.NoError(t, err)
@ -5417,7 +5441,7 @@ func testApplyPolicySpecWithInstallers(t *testing.T, ds *Datastore) {
err = ds.RecordPolicyQueryExecutions(ctx, host1Team1, map[uint]*bool{
policy1Team1.ID: ptr.Bool(false),
vppPolicy1Team1.ID: ptr.Bool(false),
}, time.Now(), false)
}, time.Now(), false, nil)
require.NoError(t, err)
err = ds.UpdateHostPolicyCounts(ctx)
require.NoError(t, err)
@ -5604,7 +5628,7 @@ func testTeamPoliciesNoTeam(t *testing.T, ds *Datastore) {
globalPolicy2.ID: ptr.Bool(false),
policy0NoTeam.ID: ptr.Bool(true),
policy3NoTeam.ID: ptr.Bool(false),
}, time.Now(), false)
}, time.Now(), false, nil)
require.NoError(t, err)
// Results for host1Team1
@ -5612,7 +5636,7 @@ func testTeamPoliciesNoTeam(t *testing.T, ds *Datastore) {
globalPolicy1.ID: ptr.Bool(true),
globalPolicy2.ID: nil, // failed to execute, e.g. typo on SQL.
policy1Team1.ID: ptr.Bool(true),
}, time.Now(), false)
}, time.Now(), false, nil)
require.NoError(t, err)
// Results for host2Team1
@ -5620,7 +5644,7 @@ func testTeamPoliciesNoTeam(t *testing.T, ds *Datastore) {
globalPolicy1.ID: ptr.Bool(false),
globalPolicy2.ID: ptr.Bool(true),
policy1Team1.ID: ptr.Bool(false),
}, time.Now(), false)
}, time.Now(), false, nil)
require.NoError(t, err)
// Results for host3Team2
@ -5628,7 +5652,7 @@ func testTeamPoliciesNoTeam(t *testing.T, ds *Datastore) {
globalPolicy1.ID: ptr.Bool(true),
policy2Team2.ID: ptr.Bool(true),
policy4Team2.ID: ptr.Bool(false),
}, time.Now(), false)
}, time.Now(), false, nil)
require.NoError(t, err)
// Results for host5NoTeam
@ -5637,7 +5661,7 @@ func testTeamPoliciesNoTeam(t *testing.T, ds *Datastore) {
globalPolicy2.ID: ptr.Bool(false),
policy0NoTeam.ID: ptr.Bool(false),
policy3NoTeam.ID: ptr.Bool(false),
}, time.Now(), false)
}, time.Now(), false, nil)
require.NoError(t, err)
err = ds.UpdateHostPolicyCounts(ctx)
@ -6155,7 +6179,7 @@ func testClearAutoInstallPolicyStatusForHost(t *testing.T, ds *Datastore) {
policy1.ID: ptr.Bool(true),
policy2.ID: ptr.Bool(false), // software isn't installed on host, so Fleet should install it
policy3.ID: ptr.Bool(false), // software isn't installed on host, so Fleet should install it
}, time.Now(), false)
}, time.Now(), false, nil)
require.NoError(t, err)
hostPolicies, err := ds.ListPoliciesForHost(ctx, host)
@ -6378,7 +6402,7 @@ func testPolicyLabelMembershipCleanup(t *testing.T, ds *Datastore) {
// Record policy results for all hosts
for _, h := range []*fleet.Host{hostNoLabels, hostLabel1, hostLabel2, hostLabelBoth} {
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: ptr.Bool(true)}, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: new(true)}, time.Now(), false, nil)
require.NoError(t, err)
}
@ -6408,7 +6432,7 @@ func testPolicyLabelMembershipCleanup(t *testing.T, ds *Datastore) {
// Re-record membership for all hosts to test exclude labels
for _, h := range []*fleet.Host{hostNoLabels, hostLabel1, hostLabel2, hostLabelBoth} {
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: ptr.Bool(true)}, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: new(true)}, time.Now(), false, nil)
require.NoError(t, err)
}
wantHostsByPol[policy.Name] = []uint{hostNoLabels.ID, hostLabel1.ID, hostLabel2.ID, hostLabelBoth.ID}
@ -6426,7 +6450,7 @@ func testPolicyLabelMembershipCleanup(t *testing.T, ds *Datastore) {
// Test ApplyPolicySpecs with label changes
// First, re-record membership for all hosts
for _, h := range []*fleet.Host{hostNoLabels, hostLabel1, hostLabel2, hostLabelBoth} {
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: ptr.Bool(true)}, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: new(true)}, time.Now(), false, nil)
require.NoError(t, err)
}
wantHostsByPol[policy.Name] = []uint{hostNoLabels.ID, hostLabel1.ID, hostLabel2.ID, hostLabelBoth.ID}
@ -6464,7 +6488,7 @@ func testPolicyLabelMembershipCleanup(t *testing.T, ds *Datastore) {
// Record membership for all hosts with label1
for _, h := range []*fleet.Host{hostLabel1, hostLabelBoth, hostWinLabel1, hostMacLabel1} {
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy2.ID: ptr.Bool(true)}, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy2.ID: new(true)}, time.Now(), false, nil)
require.NoError(t, err)
}
@ -6551,8 +6575,8 @@ func testDeletePolicyWithSoftwareActivatesNextActivity(t *testing.T, ds *Datasto
require.NoError(t, err)
// record a failing policy for both hosts, would enqueue the install
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostNoTm, map[uint]*bool{policyNoTm.ID: ptr.Bool(false)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostTm, map[uint]*bool{policyTm.ID: ptr.Bool(false)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostNoTm, map[uint]*bool{policyNoTm.ID: new(false)}, time.Now(), false, nil))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostTm, map[uint]*bool{policyTm.ID: new(false)}, time.Now(), false, nil))
// simulate the work of "processSoftwareForNewlyFailingPolicies"
installUUIDNoTm, err := ds.InsertSoftwareInstallRequest(ctx, hostNoTm.ID, installerIDNoTm,
@ -6645,8 +6669,8 @@ func testDeletePolicyWithScriptActivatesNextActivity(t *testing.T, ds *Datastore
// record a failing policy for both hosts, would enqueue the associated
// scripts
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostNoTm, map[uint]*bool{policyNoTm.ID: ptr.Bool(false)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostTm, map[uint]*bool{policyTm.ID: ptr.Bool(false)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostNoTm, map[uint]*bool{policyNoTm.ID: new(false)}, time.Now(), false, nil))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostTm, map[uint]*bool{policyTm.ID: new(false)}, time.Now(), false, nil))
// simulate the work of "processScriptsForNewlyFailingPolicies"
hsrPolicyNoTm, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
@ -6713,7 +6737,7 @@ func testSimultaneousSavePolicy(t *testing.T, ds *Datastore) {
for _, policy := range policies {
host1Results[policy.ID] = ptr.Bool(true)
}
err = ds.RecordPolicyQueryExecutions(ctx, host1, host1Results, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, host1, host1Results, time.Now(), false, nil)
require.NoError(t, err)
// Run simultaneous
@ -6754,7 +6778,7 @@ func testIsPolicyFailing(t *testing.T, ds *Datastore) {
// Exists with passes = NULL
// failing
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: nil}, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: nil}, time.Now(), false, nil)
require.NoError(t, err)
isFailing, err = ds.IsPolicyFailing(ctx, policy.ID, host.ID)
@ -6763,7 +6787,7 @@ func testIsPolicyFailing(t *testing.T, ds *Datastore) {
// exists with passes = false
// failing
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: ptr.Bool(false)}, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: new(false)}, time.Now(), false, nil)
require.NoError(t, err)
isFailing, err = ds.IsPolicyFailing(ctx, policy.ID, host.ID)
@ -6772,7 +6796,7 @@ func testIsPolicyFailing(t *testing.T, ds *Datastore) {
// exists with passes = true
// Not failing
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: ptr.Bool(true)}, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: new(true)}, time.Now(), false, nil)
require.NoError(t, err)
isFailing, err = ds.IsPolicyFailing(ctx, policy.ID, host.ID)
@ -6821,7 +6845,7 @@ func testResetAttemptsOnFailingToPassingSync(t *testing.T, ds *Datastore) {
require.NoError(t, err)
// p1 will be failing
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{p1.ID: ptr.Bool(false)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{p1.ID: new(false)}, time.Now(), false, nil))
// Create rows with attempt_number > 0 and attempt_number IS NULL (pending)
// p1 - completed attempt
@ -6852,7 +6876,7 @@ func testResetAttemptsOnFailingToPassingSync(t *testing.T, ds *Datastore) {
require.NoError(t, err)
// p1 is now passing
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{p1.ID: ptr.Bool(true), p2.ID: ptr.Bool(true)}, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{p1.ID: new(true), p2.ID: new(true)}, time.Now(), false, nil)
require.NoError(t, err)
// p1 rows should be reset to 0 (both completed and pending)
@ -6903,7 +6927,7 @@ func testResetAttemptsOnFailingToPassingAsync(t *testing.T, ds *Datastore) {
require.NoError(t, err)
// p1 is failing
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{p1.ID: ptr.Bool(false)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{p1.ID: new(false)}, time.Now(), false, nil))
// Create rows with attempt_number > 0 and attempt_number IS NULL (pending)
// p1 - completed attempt
@ -7139,7 +7163,7 @@ func testBatchedPolicyMembershipCleanup(t *testing.T, ds *Datastore) {
// Record failing results for all hosts so they all have policy_membership rows and host_issues entries.
for _, h := range hosts {
err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false)
err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{pol.ID: new(false)}, time.Now(), false, nil)
require.NoError(t, err)
}
@ -7228,7 +7252,7 @@ func testBatchedPolicyMembershipCleanupOnPolicyUpdate(t *testing.T, ds *Datastor
// Record results for all hosts (simulating results arriving before platform filter applied).
allHosts := append([]*fleet.Host{winHost}, linuxHosts...)
for _, h := range allHosts {
err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false)
err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{pol.ID: new(false)}, time.Now(), false, nil)
require.NoError(t, err)
}
@ -7293,7 +7317,7 @@ func testBatchedPolicyMembershipCleanupOnPolicyUpdate(t *testing.T, ds *Datastor
// Record policy results for all label-test hosts so policy_membership is populated.
labelHosts := append([]*fleet.Host{lblHost}, nonLblHosts...)
for _, h := range labelHosts {
err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{lblPol.ID: ptr.Bool(false)}, time.Now(), false)
err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{lblPol.ID: new(false)}, time.Now(), false, nil)
require.NoError(t, err)
}
@ -7356,7 +7380,7 @@ func testApplyPolicySpecsNeedsFullMembershipCleanupFlag(t *testing.T, ds *Datast
hosts[i] = h
}
for _, h := range hosts {
err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false)
err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{pol.ID: new(false)}, time.Now(), false, nil)
require.NoError(t, err)
}
@ -7416,7 +7440,7 @@ func testCleanupPolicyMembershipCrashRecovery(t *testing.T, ds *Datastore) {
recordResults := func(t *testing.T, hosts []*fleet.Host, polID uint) {
t.Helper()
for _, h := range hosts {
err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{polID: ptr.Bool(false)}, time.Now(), false)
err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{polID: new(false)}, time.Now(), false, nil)
require.NoError(t, err)
}
}

View file

@ -1092,7 +1092,10 @@ type Datastore interface {
// RecordPolicyQueryExecutions records the execution results of the policies for the given host.
// Even if `results` is empty, the host's `policy_updated_at` will be updated.
RecordPolicyQueryExecutions(ctx context.Context, host *Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool) error
// If newlyPassingPolicyIDs is non-nil, it contains the IDs of policies that flipped from failing to passing
// and is used directly instead of calling FlippingPoliciesForHost internally. This allows callers that have
// already computed flipping policies to avoid a redundant database query.
RecordPolicyQueryExecutions(ctx context.Context, host *Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool, newlyPassingPolicyIDs []uint) error
// RecordLabelQueryExecutions saves the results of label queries. The results map is a map of label id -> whether or
// not the label matches. The time parameter is the timestamp to save with the query execution.

View file

@ -781,7 +781,7 @@ type UpdateHostRefetchCriticalQueriesUntilFunc func(ctx context.Context, hostID
type FlippingPoliciesForHostFunc func(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing []uint, newPassing []uint, err error)
type RecordPolicyQueryExecutionsFunc func(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool) error
type RecordPolicyQueryExecutionsFunc func(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool, newlyPassingPolicyIDs []uint) error
type RecordLabelQueryExecutionsFunc func(ctx context.Context, host *fleet.Host, results map[uint]*bool, t time.Time, deferredSaveHost bool) error
@ -7260,11 +7260,11 @@ func (s *DataStore) FlippingPoliciesForHost(ctx context.Context, hostID uint, in
return s.FlippingPoliciesForHostFunc(ctx, hostID, incomingResults)
}
func (s *DataStore) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool) error {
func (s *DataStore) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool, newlyPassingPolicyIDs []uint) error {
s.mu.Lock()
s.RecordPolicyQueryExecutionsFuncInvoked = true
s.mu.Unlock()
return s.RecordPolicyQueryExecutionsFunc(ctx, host, results, updated, deferredSaveHost)
return s.RecordPolicyQueryExecutionsFunc(ctx, host, results, updated, deferredSaveHost, newlyPassingPolicyIDs)
}
func (s *DataStore) RecordLabelQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, t time.Time, deferredSaveHost bool) error {

View file

@ -28,11 +28,11 @@ const (
// redis list will be LTRIM'd if there are more policy IDs than this.
var maxRedisPolicyResultsPerHost = 1000
func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool) error {
func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool, newlyPassingPolicyIDs []uint) error {
cfg := t.taskConfigs[config.AsyncTaskPolicyMembership]
if !cfg.Enabled {
host.PolicyUpdatedAt = ts
return t.datastore.RecordPolicyQueryExecutions(ctx, host, results, ts, deferred)
return t.datastore.RecordPolicyQueryExecutions(ctx, host, results, ts, deferred, newlyPassingPolicyIDs)
}
keyList := fmt.Sprintf(policyPassHostKey, host.ID)

View file

@ -321,7 +321,7 @@ func testRecordPolicyQueryExecutionsSync(t *testing.T, ds *mock.Store, pool flee
policyReportedAt := task.GetHostPolicyReportedAt(ctx, host)
require.True(t, policyReportedAt.Equal(lastYear))
err := task.RecordPolicyQueryExecutions(ctx, host, results, now, false)
err := task.RecordPolicyQueryExecutions(ctx, host, results, now, false, nil)
require.NoError(t, err)
require.True(t, ds.RecordPolicyQueryExecutionsFuncInvoked)
ds.RecordPolicyQueryExecutionsFuncInvoked = false
@ -373,7 +373,7 @@ func testRecordPolicyQueryExecutionsAsync(t *testing.T, ds *mock.Store, pool fle
policyReportedAt := task.GetHostPolicyReportedAt(ctx, host)
require.True(t, policyReportedAt.Equal(lastYear))
err := task.RecordPolicyQueryExecutions(ctx, host, results, now, false)
err := task.RecordPolicyQueryExecutions(ctx, host, results, now, false, nil)
require.NoError(t, err)
require.False(t, ds.RecordPolicyQueryExecutionsFuncInvoked)
@ -435,7 +435,7 @@ func testRecordPolicyQueryExecutionsNoPoliciesSync(t *testing.T, ds *mock.Store,
policyReportedAt := task.GetHostPolicyReportedAt(ctx, host)
require.True(t, policyReportedAt.Equal(lastYear))
err := task.RecordPolicyQueryExecutions(ctx, host, emptyResults, now, false)
err := task.RecordPolicyQueryExecutions(ctx, host, emptyResults, now, false, nil)
require.NoError(t, err)
require.True(t, ds.RecordPolicyQueryExecutionsFuncInvoked)
ds.RecordPolicyQueryExecutionsFuncInvoked = false
@ -485,7 +485,7 @@ func testRecordPolicyQueryExecutionsNoPoliciesAsync(t *testing.T, ds *mock.Store
policyReportedAt := task.GetHostPolicyReportedAt(ctx, host)
require.True(t, policyReportedAt.Equal(lastYear))
err := task.RecordPolicyQueryExecutions(ctx, host, emptyResults, now, false)
err := task.RecordPolicyQueryExecutions(ctx, host, emptyResults, now, false, nil)
require.NoError(t, err)
require.False(t, ds.RecordPolicyQueryExecutionsFuncInvoked)

View file

@ -88,7 +88,7 @@ func TestRecord(t *testing.T) {
ds.AsyncBatchUpdateLabelTimestampFunc = func(ctx context.Context, ids []uint, ts time.Time) error {
return nil
}
ds.RecordPolicyQueryExecutionsFunc = func(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool) error {
ds.RecordPolicyQueryExecutionsFunc = func(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool, newlyPassingPolicyIDs []uint) error {
return nil
}
ds.AsyncBatchInsertPolicyMembershipFunc = func(ctx context.Context, batch []fleet.PolicyMembershipResult) error {

View file

@ -1223,8 +1223,8 @@ func (s *integrationTestSuite) TestGlobalPolicies() {
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
require.Len(t, listHostsResp.Hosts, 0)
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: new(true)}, time.Now(), false, nil))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false, nil))
listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID)
listHostsResp = listHostsResponse{}
@ -1844,7 +1844,7 @@ func (s *integrationTestSuite) TestListHosts() {
require.NoError(
t,
s.ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{globalPolicy0.ID: ptr.Bool(false)}, time.Now(), false),
s.ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{globalPolicy0.ID: new(false)}, time.Now(), false, nil),
)
resp = listHostsResponse{}
@ -2109,7 +2109,7 @@ func (s *integrationTestSuite) TestListHosts() {
for _, host := range hosts {
// All hosts pass the globalPolicy1
err := s.ds.RecordPolicyQueryExecutions(
context.Background(), host, map[uint]*bool{globalPolicy1.ID: ptr.Bool(true)}, time.Now(), false,
context.Background(), host, map[uint]*bool{globalPolicy1.ID: new(true)}, time.Now(), false, nil,
)
require.NoError(t, err)
}
@ -2927,8 +2927,8 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() {
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
require.Len(t, listHostsResp.Hosts, 0)
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: new(true)}, time.Now(), false, nil))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false, nil))
listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID)
listHostsResp = listHostsResponse{}
@ -2978,12 +2978,12 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() {
// Record query executions
require.NoError(
t, s.ds.RecordPolicyQueryExecutions(
context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now(), false,
context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: new(true)}, time.Now(), false, nil,
),
)
require.NoError(
t, s.ds.RecordPolicyQueryExecutions(
context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false,
context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false, nil,
),
)
// Update policy stats
@ -3169,8 +3169,8 @@ func (s *integrationTestSuite) TestTeamPoliciesProprietary() {
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
require.Len(t, listHostsResp.Hosts, 0)
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: new(true)}, time.Now(), false, nil))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false, nil))
listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?team_id=%d&policy_id=%d&policy_response=passing", team1.ID, policiesResponse.Policies[0].ID)
listHostsResp = listHostsResponse{}
@ -3384,6 +3384,7 @@ func (s *integrationTestSuite) TestHostDetailsPolicies() {
map[uint]*bool{gpResp.Policy.ID: ptr.Bool(true)},
time.Now(),
false,
nil,
)
require.NoError(t, err)
@ -5417,7 +5418,7 @@ func (s *integrationTestSuite) TestListHostsByLabel() {
require.NotNil(t, gpResp.Policy)
require.NoError(
t,
s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{gpResp.Policy.ID: ptr.Bool(false)}, time.Now(), false),
s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{gpResp.Policy.ID: new(false)}, time.Now(), false, nil),
)
// Add MDM info
@ -9311,7 +9312,7 @@ func (s *integrationTestSuite) TestReenrollHostCleansPolicies() {
// create a policy and make the host fail it
pol, err := s.ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: t.Name(), Query: "SELECT 1", Platform: host.FleetPlatform()})
require.NoError(t, err)
err = s.ds.RecordPolicyQueryExecutions(ctx, &fleet.Host{ID: host.ID}, map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false)
err = s.ds.RecordPolicyQueryExecutions(ctx, &fleet.Host{ID: host.ID}, map[uint]*bool{pol.ID: new(false)}, time.Now(), false, nil)
require.NoError(t, err)
// refetch the host details
@ -10052,7 +10053,7 @@ func (s *integrationTestSuite) TestHostsReportDownload() {
// create a policy and make host[1] fail that policy
pol, err := s.ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: t.Name(), Query: "SELECT 1"})
require.NoError(t, err)
err = s.ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false)
err = s.ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: new(false)}, time.Now(), false, nil)
require.NoError(t, err)
// create some device mappings for host[2]
@ -12474,20 +12475,20 @@ func (s *integrationTestSuite) TestHostsReportWithPolicyResults() {
for i, host := range hosts {
// All hosts pass the globalPolicy0
err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy0.ID: ptr.Bool(true)}, time.Now(), false)
err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy0.ID: new(true)}, time.Now(), false, nil)
require.NoError(t, err)
if i%2 == 0 {
// Half of the hosts pass the globalPolicy1 and fail the globalPolicy2
err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy1.ID: ptr.Bool(true)}, time.Now(), false)
err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy1.ID: new(true)}, time.Now(), false, nil)
require.NoError(t, err)
err = s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy2.ID: ptr.Bool(false)}, time.Now(), false)
err = s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy2.ID: new(false)}, time.Now(), false, nil)
require.NoError(t, err)
} else {
// Half of the hosts pass the globalPolicy2 and fail the globalPolicy1
err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy1.ID: ptr.Bool(false)}, time.Now(), false)
err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy1.ID: new(false)}, time.Now(), false, nil)
require.NoError(t, err)
err = s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy2.ID: ptr.Bool(true)}, time.Now(), false)
err = s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy2.ID: new(true)}, time.Now(), false, nil)
require.NoError(t, err)
}
}
@ -13593,8 +13594,8 @@ func (s *integrationTestSuite) TestHostHealth() {
})
require.NoError(t, err)
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{failingPolicy.ID: ptr.Bool(false)}, time.Now(), false))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{passingPolicy.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{failingPolicy.ID: new(false)}, time.Now(), false, nil))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{passingPolicy.ID: new(true)}, time.Now(), false, nil))
require.NoError(t, s.ds.SetOrUpdateHostDisksEncryption(context.Background(), host.ID, true))

View file

@ -3017,7 +3017,7 @@ func (s *integrationEnterpriseTestSuite) TestNoTeamFailingPolicyWebhookTrigger()
noTeamPol1.ID: ptr.Bool(false), // Fails and is in webhook config
noTeamPol2.ID: ptr.Bool(false), // Fails and is in webhook config
noTeamPol3.ID: ptr.Bool(false), // Fails but NOT in webhook config
}, time.Now(), false)
}, time.Now(), false, nil)
require.NoError(t, err)
// Initially, OutdatedAutomationBatch should be empty (policies haven't been triggered for automation yet)
@ -4055,7 +4055,7 @@ func (s *integrationEnterpriseTestSuite) TestListDevicePolicies() {
// add a policy execution
require.NoError(t, s.ds.RecordPolicyQueryExecutions(ctx, host,
map[uint]*bool{gpResp.Policy.ID: ptr.Bool(false)}, time.Now(), false))
map[uint]*bool{gpResp.Policy.ID: new(false)}, time.Now(), false, nil))
// add a policy to team
oldToken := s.token
@ -5303,7 +5303,7 @@ func (s *integrationEnterpriseTestSuite) TestListHosts() {
require.NoError(
t, s.ds.RecordPolicyQueryExecutions(
ctx, host1,
map[uint]*bool{gpResp.Policy.ID: ptr.Bool(false)}, time.Now(), false,
map[uint]*bool{gpResp.Policy.ID: new(false)}, time.Now(), false, nil,
),
)
@ -5568,10 +5568,10 @@ func (s *integrationEnterpriseTestSuite) TestHostHealth() {
})
require.NoError(t, err)
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{failingGlobalPolicy.ID: ptr.Bool(false)}, time.Now(), false))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{passingGlobalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{failingTeamPolicy.ID: ptr.Bool(false)}, time.Now(), false))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{passingTeamPolicy.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{failingGlobalPolicy.ID: new(false)}, time.Now(), false, nil))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{passingGlobalPolicy.ID: new(true)}, time.Now(), false, nil))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{failingTeamPolicy.ID: new(false)}, time.Now(), false, nil))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{passingTeamPolicy.ID: new(true)}, time.Now(), false, nil))
hh := getHostHealthResponse{}
s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/hosts/%d/health", host.ID), nil, http.StatusOK, &hh)
@ -6288,7 +6288,7 @@ func (s *integrationEnterpriseTestSuite) TestResetAutomation() {
createPol1.Policy.ID: ptr.Bool(false),
createPol2.Policy.ID: ptr.Bool(false),
createPol3.Policy.ID: ptr.Bool(false), // This policy is not activated for automation in config.
}, time.Now(), false)
}, time.Now(), false, nil)
require.NoError(s.T(), err)
pfs, err := s.ds.OutdatedAutomationBatch(ctx)
@ -7417,7 +7417,7 @@ func (s *integrationEnterpriseTestSuite) TestDesktopEndpointWithInvalidPolicy()
Critical: false,
})
require.NoError(t, err)
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{policy.ID: nil}, time.Now(), false))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{policy.ID: nil}, time.Now(), false, nil))
// Any 'invalid' policies should be ignored.
desktopRes := fleetDesktopResponse{}
@ -25216,7 +25216,7 @@ FqU+KJOed6qlzj7qy+u5l6CQeajLGdjUxFlFyw==
require.NoError(t, err)
// Record a failing result for this policy on the host
err = s.ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: ptr.Bool(false)}, time.Now(), false)
err = s.ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: new(false)}, time.Now(), false, nil)
require.NoError(t, err)
// Bypass should fail with 400 Bad Request
@ -25261,7 +25261,7 @@ FqU+KJOed6qlzj7qy+u5l6CQeajLGdjUxFlFyw==
err = s.ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{
caPolicy.ID: ptr.Bool(true), // passing
nonCAPolicy.ID: ptr.Bool(false), // failing
}, time.Now(), false)
}, time.Now(), false, nil)
require.NoError(t, err)
// Bypass must succeed: the only failing policy is not CA-enabled

View file

@ -924,7 +924,7 @@ func TestSoftwareInstallReplicaLag(t *testing.T) {
opts.RunReplication("software_installers", "software_titles")
// Mark policy as failing for the host
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: ptr.Bool(false)}, time.Now(), false)
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: new(false)}, time.Now(), false, nil)
require.NoError(t, err)
opts.RunReplication("policy_membership")

View file

@ -1209,11 +1209,27 @@ func (svc *Service) SubmitDistributedQueryResults(
}
if len(policyResults) > 0 {
// Compute flipping policies once for all consumers. This replaces up to 5 individual calls to
// FlippingPoliciesForHost with a single database query.
newFailing, newPassing, err := svc.ds.FlippingPoliciesForHost(ctx, host.ID, policyResults)
if err != nil {
logging.WithErr(ctx, err)
}
// Ensure newPassing is non-nil so RecordPolicyQueryExecutions can distinguish "pre-computed with zero results"
// from "not pre-computed" (nil means compute it yourself).
if newPassing == nil {
newPassing = []uint{}
}
newFailingSet := make(map[uint]struct{}, len(newFailing))
for _, id := range newFailing {
newFailingSet[id] = struct{}{}
}
if err := processCalendarPolicies(ctx, svc.ds, ac, host, policyResults, svc.logger); err != nil {
logging.WithErr(ctx, err)
}
if err := svc.processScriptsForNewlyFailingPolicies(ctx, host.ID, host.TeamID, host.Platform, host.OrbitNodeKey, host.ScriptsEnabled, policyResults); err != nil {
if err := svc.processScriptsForNewlyFailingPolicies(ctx, host.ID, host.TeamID, host.Platform, host.OrbitNodeKey, host.ScriptsEnabled, policyResults, newFailingSet); err != nil {
logging.WithErr(ctx, err)
}
@ -1226,18 +1242,18 @@ func (svc *Service) SubmitDistributedQueryResults(
if host.Platform == "darwin" && svc.EnterpriseOverrides != nil {
// NOTE: if the installers for the policies here are not scoped to the host via labels, we update the policy status here to stop it from showing up as "failed" in the
// host details.
if err := svc.processVPPForNewlyFailingPolicies(ctx, host.ID, host.TeamID, host.Platform, policyResults); err != nil {
if err := svc.processVPPForNewlyFailingPolicies(ctx, host.ID, host.TeamID, host.Platform, policyResults, newFailingSet); err != nil {
logging.WithErr(ctx, err)
}
}
// NOTE: if the installers for the policies here are not scoped to the host via labels, we update the policy status here to stop it from showing up as "failed" in the
// host details.
if err := svc.processSoftwareForNewlyFailingPolicies(ctx, host.ID, host.TeamID, host.Platform, host.OrbitNodeKey, policyResults); err != nil {
if err := svc.processSoftwareForNewlyFailingPolicies(ctx, host.ID, host.TeamID, host.Platform, host.OrbitNodeKey, policyResults, newFailingSet); err != nil {
logging.WithErr(ctx, err)
}
// filter policy results for webhooks
// Filter policy results for webhooks using pre-computed flipping sets.
var policyIDs []uint
if globalPolicyAutomationsEnabled(ac.WebhookSettings, ac.Integrations) {
policyIDs = append(policyIDs, ac.WebhookSettings.FailingPoliciesWebhook.PolicyIDs...)
@ -1256,12 +1272,13 @@ func (svc *Service) SubmitDistributedQueryResults(
filteredResults := filterPolicyResults(policyResults, policyIDs)
if len(filteredResults) > 0 {
if failingPolicies, passingPolicies, err := svc.ds.FlippingPoliciesForHost(ctx, host.ID, filteredResults); err != nil {
logging.WithErr(ctx, err)
} else {
// Filter the pre-computed flipping results to only webhook-enabled policies.
webhookFailing := filterByPolicyIDs(newFailing, filteredResults)
webhookPassing := filterByPolicyIDs(newPassing, filteredResults)
if len(webhookFailing) > 0 || len(webhookPassing) > 0 {
// Register the flipped policies on a goroutine to not block the hosts on redis requests.
go func() {
if err := svc.registerFlippedPolicies(ctx, host.ID, host.Hostname, host.DisplayName(), failingPolicies, passingPolicies); err != nil {
if err := svc.registerFlippedPolicies(ctx, host.ID, host.Hostname, host.DisplayName(), webhookFailing, webhookPassing); err != nil {
logging.WithErr(ctx, err)
}
}()
@ -1275,12 +1292,12 @@ func (svc *Service) SubmitDistributedQueryResults(
// maybe we should impose restrictions between async collection interval
// and policy update interval?
if err := svc.task.RecordPolicyQueryExecutions(ctx, host, policyResults, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost); err != nil {
if err := svc.task.RecordPolicyQueryExecutions(ctx, host, policyResults, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost, newPassing); err != nil {
logging.WithErr(ctx, err)
}
} else if hostWithoutPolicies {
// RecordPolicyQueryExecutions called with results=nil will still update the host's policy_updated_at column.
if err := svc.task.RecordPolicyQueryExecutions(ctx, host, nil, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost); err != nil {
if err := svc.task.RecordPolicyQueryExecutions(ctx, host, nil, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost, []uint{}); err != nil {
logging.WithErr(ctx, err)
}
}
@ -1945,6 +1962,18 @@ func filterPolicyResults(incoming map[uint]*bool, webhookPolicies []uint) map[ui
return filtered
}
// filterByPolicyIDs returns only the policy IDs from ids that are present in allowedResults and have a non-nil result
// (i.e., the policy actually executed). This matches the behavior of FlippingPoliciesForHost which ignores nil results.
func filterByPolicyIDs(ids []uint, allowedResults map[uint]*bool) []uint {
var filtered []uint
for _, id := range ids {
if val, ok := allowedResults[id]; ok && val != nil {
filtered = append(filtered, id)
}
}
return filtered
}
func (svc *Service) registerFlippedPolicies(ctx context.Context, hostID uint, hostname, displayName string, newFailing, newPassing []uint) error {
host := fleet.PolicySetHost{
ID: hostID,
@ -1971,6 +2000,7 @@ func (svc *Service) processSoftwareForNewlyFailingPolicies(
hostPlatform string,
hostOrbitNodeKey *string,
incomingPolicyResults map[uint]*bool,
newFailingSet map[uint]struct{},
) error {
if hostOrbitNodeKey == nil || *hostOrbitNodeKey == "" {
// We do not want to queue software installations on vanilla osquery hosts.
@ -1986,15 +2016,13 @@ func (svc *Service) processSoftwareForNewlyFailingPolicies(
// Filter out results that are not failures (we are only interested on failing policies,
// we don't care about passing policies or policies that failed to execute).
incomingFailingPolicies := make(map[uint]*bool)
var incomingFailingPoliciesIDs []uint
for policyID, policyResult := range incomingPolicyResults {
if policyResult != nil && !*policyResult {
incomingFailingPolicies[policyID] = policyResult
incomingFailingPoliciesIDs = append(incomingFailingPoliciesIDs, policyID)
}
}
if len(incomingFailingPolicies) == 0 {
if len(incomingFailingPoliciesIDs) == 0 {
return nil
}
@ -2007,44 +2035,16 @@ func (svc *Service) processSoftwareForNewlyFailingPolicies(
return nil
}
// Filter out results of policies that are not associated to installers.
policiesWithInstallersMap := make(map[uint]fleet.PolicySoftwareInstallerData)
for _, policyWithInstaller := range policiesWithInstaller {
policiesWithInstallersMap[policyWithInstaller.ID] = policyWithInstaller
}
policyResultsOfPoliciesWithInstallers := make(map[uint]*bool)
for policyID, passes := range incomingFailingPolicies {
if _, ok := policiesWithInstallersMap[policyID]; !ok {
continue
}
policyResultsOfPoliciesWithInstallers[policyID] = passes
}
if len(policyResultsOfPoliciesWithInstallers) == 0 {
return nil
}
// Get the policies associated with installers that are flipping from passing to failing on this host.
policyIDsOfNewlyFailingPoliciesWithInstallers, _, err := svc.ds.FlippingPoliciesForHost(
ctx, hostID, policyResultsOfPoliciesWithInstallers,
)
if err != nil {
return ctxerr.Wrap(ctx, err, "failed to get flipping policies for host")
}
if len(policyIDsOfNewlyFailingPoliciesWithInstallers) == 0 {
return nil
}
policyIDsOfNewlyFailingPoliciesWithInstallersSet := make(map[uint]struct{})
for _, policyID := range policyIDsOfNewlyFailingPoliciesWithInstallers {
policyIDsOfNewlyFailingPoliciesWithInstallersSet[policyID] = struct{}{}
}
// Finally filter out policies with installers that are not newly failing.
// Filter to policies with installers that are newly failing, using the pre-computed set.
var failingPoliciesWithInstaller []fleet.PolicySoftwareInstallerData
for _, policyWithInstaller := range policiesWithInstaller {
if _, ok := policyIDsOfNewlyFailingPoliciesWithInstallersSet[policyWithInstaller.ID]; ok {
if _, ok := newFailingSet[policyWithInstaller.ID]; ok {
failingPoliciesWithInstaller = append(failingPoliciesWithInstaller, policyWithInstaller)
}
}
if len(failingPoliciesWithInstaller) == 0 {
return nil
}
for _, failingPolicyWithInstaller := range failingPoliciesWithInstaller {
policyID := failingPolicyWithInstaller.ID
@ -2119,6 +2119,7 @@ func (svc *Service) processVPPForNewlyFailingPolicies(
hostTeamID *uint,
hostPlatform string,
incomingPolicyResults map[uint]*bool,
newFailingSet map[uint]struct{},
) error {
var policyTeamID uint
if hostTeamID == nil {
@ -2129,15 +2130,13 @@ func (svc *Service) processVPPForNewlyFailingPolicies(
// Filter out results that are not failures (we are only interested on failing policies,
// we don't care about passing policies or policies that failed to execute).
incomingFailingPolicies := make(map[uint]*bool)
var incomingFailingPoliciesIDs []uint
for policyID, policyResult := range incomingPolicyResults {
if policyResult != nil && !*policyResult {
incomingFailingPolicies[policyID] = policyResult
incomingFailingPoliciesIDs = append(incomingFailingPoliciesIDs, policyID)
}
}
if len(incomingFailingPolicies) == 0 {
if len(incomingFailingPoliciesIDs) == 0 {
return nil
}
@ -2150,45 +2149,13 @@ func (svc *Service) processVPPForNewlyFailingPolicies(
return nil
}
// Filter out results of policies that are not associated to VPP apps.
policiesWithVPPMap := make(map[uint]fleet.PolicyVPPData)
for _, policyWithVPP := range policiesWithVPP {
policiesWithVPPMap[policyWithVPP.ID] = policyWithVPP
}
policyResultsOfPoliciesWithVPP := make(map[uint]*bool)
for policyID, passes := range incomingFailingPolicies {
if _, ok := policiesWithVPPMap[policyID]; !ok {
continue
}
policyResultsOfPoliciesWithVPP[policyID] = passes
}
if len(policyResultsOfPoliciesWithVPP) == 0 {
return nil
}
// Get the policies associated with VPP apps that are flipping from passing to failing on this host.
policyIDsOfNewlyFailingPoliciesWithVPP, _, err := svc.ds.FlippingPoliciesForHost(
ctx, hostID, policyResultsOfPoliciesWithVPP,
)
if err != nil {
return ctxerr.Wrap(ctx, err, "failed to get flipping policies for host")
}
if len(policyIDsOfNewlyFailingPoliciesWithVPP) == 0 {
return nil
}
policyIDsOfNewlyFailingPoliciesWithVPPSet := make(map[uint]struct{})
for _, policyID := range policyIDsOfNewlyFailingPoliciesWithVPP {
policyIDsOfNewlyFailingPoliciesWithVPPSet[policyID] = struct{}{}
}
// Finally filter out policies with VPP apps that are not newly failing.
// Filter to policies with VPP apps that are newly failing, using the pre-computed set.
var failingPoliciesWithVPP []fleet.PolicyVPPData
for _, policyWithVPP := range policiesWithVPP {
if _, ok := policyIDsOfNewlyFailingPoliciesWithVPPSet[policyWithVPP.ID]; ok {
if _, ok := newFailingSet[policyWithVPP.ID]; ok {
failingPoliciesWithVPP = append(failingPoliciesWithVPP, policyWithVPP)
}
}
if len(failingPoliciesWithVPP) == 0 {
return nil
}
@ -2269,6 +2236,7 @@ func (svc *Service) processScriptsForNewlyFailingPolicies(
hostOrbitNodeKey *string,
hostScriptsEnabled *bool,
incomingPolicyResults map[uint]*bool,
newFailingSet map[uint]struct{},
) error {
if hostOrbitNodeKey == nil || *hostOrbitNodeKey == "" {
return nil // vanilla osquery hosts can't run scripts
@ -2297,15 +2265,13 @@ func (svc *Service) processScriptsForNewlyFailingPolicies(
// Filter out results that are not failures (we are only interested on failing policies,
// we don't care about passing policies or policies that failed to execute).
incomingFailingPolicies := make(map[uint]*bool)
var incomingFailingPoliciesIDs []uint
for policyID, policyResult := range incomingPolicyResults {
if policyResult != nil && !*policyResult {
incomingFailingPolicies[policyID] = policyResult
incomingFailingPoliciesIDs = append(incomingFailingPoliciesIDs, policyID)
}
}
if len(incomingFailingPolicies) == 0 {
if len(incomingFailingPoliciesIDs) == 0 {
return nil
}
@ -2318,44 +2284,16 @@ func (svc *Service) processScriptsForNewlyFailingPolicies(
return nil
}
// Filter out results of policies that are not associated to scripts.
policiesWithScriptsMap := make(map[uint]fleet.PolicyScriptData)
for _, policyWithScript := range policiesWithScript {
policiesWithScriptsMap[policyWithScript.ID] = policyWithScript
}
policyResultsOfPoliciesWithScripts := make(map[uint]*bool)
for policyID, passes := range incomingFailingPolicies {
if _, ok := policiesWithScriptsMap[policyID]; !ok {
continue
}
policyResultsOfPoliciesWithScripts[policyID] = passes
}
if len(policyResultsOfPoliciesWithScripts) == 0 {
return nil
}
// Get the policies associated with scripts that are flipping from passing to failing on this host.
policyIDsOfNewlyFailingPoliciesWithScripts, _, err := svc.ds.FlippingPoliciesForHost(
ctx, hostID, policyResultsOfPoliciesWithScripts,
)
if err != nil {
return ctxerr.Wrap(ctx, err, "failed to get flipping policies for host")
}
if len(policyIDsOfNewlyFailingPoliciesWithScripts) == 0 {
return nil
}
policyIDsOfNewlyFailingPoliciesWithScriptsSet := make(map[uint]struct{})
for _, policyID := range policyIDsOfNewlyFailingPoliciesWithScripts {
policyIDsOfNewlyFailingPoliciesWithScriptsSet[policyID] = struct{}{}
}
// Finally filter out policies with scripts that are not newly failing.
// Filter to policies with scripts that are newly failing, using the pre-computed set.
var failingPoliciesWithScript []fleet.PolicyScriptData
for _, policyWithScript := range policiesWithScript {
if _, ok := policyIDsOfNewlyFailingPoliciesWithScriptsSet[policyWithScript.ID]; ok {
if _, ok := newFailingSet[policyWithScript.ID]; ok {
failingPoliciesWithScript = append(failingPoliciesWithScript, policyWithScript)
}
}
if len(failingPoliciesWithScript) == 0 {
return nil
}
for _, failingPolicyWithScript := range failingPoliciesWithScript {
policyID := failingPolicyWithScript.ID

View file

@ -3350,7 +3350,7 @@ func TestPolicyQueries(t *testing.T) {
}
recordedResults := make(map[uint]*bool)
ds.RecordPolicyQueryExecutionsFunc = func(ctx context.Context, gotHost *fleet.Host, results map[uint]*bool, updated time.Time,
deferred bool,
deferred bool, newlyPassingPolicyIDs []uint,
) error {
recordedResults = results
host = gotHost
@ -3661,7 +3661,7 @@ func TestPolicyWebhooks(t *testing.T) {
}
recordedResults := make(map[uint]*bool)
ds.RecordPolicyQueryExecutionsFunc = func(ctx context.Context, gotHost *fleet.Host, results map[uint]*bool, updated time.Time,
deferred bool,
deferred bool, newlyPassingPolicyIDs []uint,
) error {
recordedResults = results
host = gotHost
@ -3696,12 +3696,30 @@ func TestPolicyWebhooks(t *testing.T) {
checkPolicyResults(queries)
// Track FlippingPoliciesForHost calls to verify the deduplication optimization: it should be called exactly once
// per SubmitDistributedQueryResults with ALL policy results, not multiple times with subsets.
var flippingCallCount int
var flippingIncomingResults map[uint]*bool
ds.FlippingPoliciesForHostFunc = func(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing []uint, newPassing []uint,
err error,
) {
flippingCallCount++
flippingIncomingResults = incomingResults
return []uint{3}, nil, nil
}
// Track that pre-computed newlyPassingPolicyIDs is forwarded to RecordPolicyQueryExecutions.
var recordedNewlyPassing []uint
ds.RecordPolicyQueryExecutionsFunc = func(ctx context.Context, gotHost *fleet.Host, results map[uint]*bool, updated time.Time,
deferred bool, newlyPassingPolicyIDs []uint,
) error {
recordedResults = results
recordedNewlyPassing = newlyPassingPolicyIDs
host = gotHost
return nil
}
flippingCallCount = 0
// Record a query execution.
err = svc.SubmitDistributedQueryResults(
ctx,
@ -3726,6 +3744,14 @@ func TestPolicyWebhooks(t *testing.T) {
require.NotNil(t, recordedResults[3])
require.False(t, *recordedResults[3])
// Verify FlippingPoliciesForHost was called exactly once with all 3 policy results.
require.Equal(t, 1, flippingCallCount, "FlippingPoliciesForHost should be called exactly once per SubmitDistributedQueryResults")
require.Len(t, flippingIncomingResults, 3, "FlippingPoliciesForHost should receive all policy results")
// Verify pre-computed newlyPassingPolicyIDs was forwarded (FlippingPoliciesForHost returned nil for newPassing,
// but the caller normalizes it to a non-nil empty slice).
require.NotNil(t, recordedNewlyPassing, "pre-computed newlyPassingPolicyIDs should be forwarded to RecordPolicyQueryExecutions")
require.Empty(t, recordedNewlyPassing)
cmpSets := func(expSets map[uint][]fleet.PolicySetHost) error {
actualSets, err := failingPolicySet.ListSets()
if err != nil {
@ -3805,9 +3831,12 @@ func TestPolicyWebhooks(t *testing.T) {
ds.FlippingPoliciesForHostFunc = func(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing []uint, newPassing []uint,
err error,
) {
flippingCallCount++
flippingIncomingResults = incomingResults
return []uint{1}, []uint{3}, nil
}
flippingCallCount = 0
// Record another query execution.
err = svc.SubmitDistributedQueryResults(
ctx,
@ -3832,6 +3861,12 @@ func TestPolicyWebhooks(t *testing.T) {
require.NotNil(t, recordedResults[3])
require.True(t, *recordedResults[3])
// Verify single call and correct forwarding when there are actual flips.
require.Equal(t, 1, flippingCallCount, "FlippingPoliciesForHost should be called exactly once")
require.Len(t, flippingIncomingResults, 3)
require.NotNil(t, recordedNewlyPassing)
require.ElementsMatch(t, []uint{3}, recordedNewlyPassing, "newlyPassingPolicyIDs should be forwarded from FlippingPoliciesForHost")
assert.Eventually(t, func() bool {
err = cmpSets(map[uint][]fleet.PolicySetHost{
1: {{