diff --git a/changes/42836-deduplicate-flipping-policies-queries b/changes/42836-deduplicate-flipping-policies-queries new file mode 100644 index 0000000000..95f9d838c2 --- /dev/null +++ b/changes/42836-deduplicate-flipping-policies-queries @@ -0,0 +1 @@ +- Reduced redundant database queries during policy result submission by computing flipping policies once per host check-in instead of multiple times. diff --git a/server/datastore/mysql/conditional_access_bypass_test.go b/server/datastore/mysql/conditional_access_bypass_test.go index 47289a4389..779db2b50e 100644 --- a/server/datastore/mysql/conditional_access_bypass_test.go +++ b/server/datastore/mysql/conditional_access_bypass_test.go @@ -306,7 +306,7 @@ func testConditionalAccessBypassDeviceWithBlockingPolicy(t *testing.T, ds *Datas require.NoError(t, err) // Record a failing result for this policy on the host - err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: ptr.Bool(false)}, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: new(false)}, time.Now(), false, nil) require.NoError(t, err) // Bypass should fail because the host has a failing CA-enabled policy @@ -365,7 +365,7 @@ func testConditionalAccessBypassAllowedWithNonCAFailingCriticalPolicy(t *testing err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{ caPolicy.ID: ptr.Bool(true), // passing nonCAPolicy.ID: ptr.Bool(false), // failing - }, time.Now(), false) + }, time.Now(), false, nil) require.NoError(t, err) // Bypass must succeed: the only failing policy is not CA-enabled @@ -415,7 +415,7 @@ func testConditionalAccessBypassAllowedWithCAEnabledNonCriticalPolicy(t *testing err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{ nonCriticalCAPolicy.ID: ptr.Bool(false), // failing - }, time.Now(), false) + }, time.Now(), false, nil) require.NoError(t, err) // Bypass must succeed: policy is CA-enabled but not critical diff --git a/server/datastore/mysql/hosts_test.go b/server/datastore/mysql/hosts_test.go index 4ee35c3de3..1ca27a69b2 100644 --- a/server/datastore/mysql/hosts_test.go +++ b/server/datastore/mysql/hosts_test.go @@ -4101,8 +4101,8 @@ func testHostsListByPolicy(t *testing.T, ds *Datastore) { require.Len(t, hosts, 0) // Make one host pass the policy and another not pass - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{1: ptr.Bool(true)}, time.Now(), false)) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{1: ptr.Bool(false)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{1: new(true)}, time.Now(), false, nil)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{1: new(false)}, time.Now(), false, nil)) hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{PolicyIDFilter: &p.ID, PolicyResponseFilter: ptr.Bool(true)}, 1) require.Len(t, hosts, 1) @@ -5105,18 +5105,18 @@ func testHostsListFailingPolicies(t *testing.T, ds *Datastore) { assert.Zero(t, *h2.HostIssues.CriticalVulnerabilitiesCount) assert.Zero(t, h2.HostIssues.TotalIssuesCount) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: new(true)}, time.Now(), false, nil)) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(false), p2.ID: ptr.Bool(false)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(false), p2.ID: new(false)}, time.Now(), false, nil)) checkHostIssues(t, ds, hosts, filter, h2.ID, 2) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(true), p2.ID: ptr.Bool(false)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(true), p2.ID: new(false)}, time.Now(), false, nil)) checkHostIssues(t, ds, hosts, filter, h2.ID, 1) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(true), p2.ID: ptr.Bool(true)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(true), p2.ID: new(true)}, time.Now(), false, nil)) checkHostIssues(t, ds, hosts, filter, h2.ID, 0) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: new(false)}, time.Now(), false, nil)) checkHostIssues(t, ds, hosts, filter, h1.ID, 1) checkHostIssuesWithOpts(t, ds, filter, h1.ID, fleet.HostListOptions{DisableIssues: true}, 0) @@ -5183,8 +5183,8 @@ func testHostsReadsLessRows(t *testing.T, ds *Datastore) { }) require.NoError(t, err) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), false)) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: new(true)}, time.Now(), false, nil)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(false)}, time.Now(), false, nil)) prevRead := getReads(t, ds) h1WithExtras, err := ds.Host(context.Background(), h1.ID) @@ -8969,7 +8969,7 @@ func testHostsDeleteHosts(t *testing.T, ds *Datastore) { _, err = ds.writer(context.Background()).Exec(`INSERT INTO query_results (host_id, query_id, last_fetched, data) VALUES (?, ?, ?, ?)`, host.ID, policy.ID, time.Now(), `{"foo": "bar"}`) require.NoError(t, err) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{policy.ID: ptr.Bool(true)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{policy.ID: new(true)}, time.Now(), false, nil)) // Update host_mdm. err = ds.SetOrUpdateMDMData(context.Background(), host.ID, false, true, "foo.mdm.example.com", false, "", "", false) require.NoError(t, err) @@ -9727,7 +9727,7 @@ func testFailingPoliciesCount(t *testing.T, ds *Datastore) { for _, tc := range testCases { if len(tc.policyEx) != 0 { - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, tc.host, tc.policyEx, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, tc.host, tc.policyEx, time.Now(), false, nil)) } actual, err := ds.FailingPoliciesCount(ctx, tc.host) require.NoError(t, err) @@ -9769,7 +9769,7 @@ func testHostsRecordNoPolicies(t *testing.T, ds *Datastore) { assert.Zero(t, h2.HostIssues.TotalIssuesCount) policyUpdatedAt := initialTime.Add(1 * time.Hour) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, nil, policyUpdatedAt, false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, nil, policyUpdatedAt, false, nil)) hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{}, 2) require.Len(t, hosts, 2) @@ -9916,7 +9916,7 @@ func testHostOrder(t *testing.T, ds *Datastore) { } require.NoError( t, ds.RecordPolicyQueryExecutions( - context.Background(), createdHosts[i], results, time.Now(), false, + context.Background(), createdHosts[i], results, time.Now(), false, nil, ), ) } @@ -11791,8 +11791,8 @@ func testHostHealth(t *testing.T, ds *Datastore) { failingPolicy, err := ds.NewGlobalPolicy(context.Background(), &u.ID, fleet.PolicyPayload{QueryID: &q.ID}) require.NoError(t, err) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h, map[uint]*bool{passingPolicy.ID: ptr.Bool(true)}, time.Now(), false)) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h, map[uint]*bool{failingPolicy.ID: ptr.Bool(false)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h, map[uint]*bool{passingPolicy.ID: new(true)}, time.Now(), false, nil)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h, map[uint]*bool{failingPolicy.ID: new(false)}, time.Now(), false, nil)) // set up vulnerable software software := []fleet.Software{ @@ -12269,7 +12269,7 @@ func testUpdateHostIssues(t *testing.T, ds *Datastore) { require.NoError( // RecordPolicyQueryExecutions should call UpdateHostIssuesFailingPolicies, so we don't have to t, ds.RecordPolicyQueryExecutions( - context.Background(), hosts[i], results, time.Now(), false, + context.Background(), hosts[i], results, time.Now(), false, nil, ), ) } diff --git a/server/datastore/mysql/labels_test.go b/server/datastore/mysql/labels_test.go index 4880a7f9c2..172b34eab4 100644 --- a/server/datastore/mysql/labels_test.go +++ b/server/datastore/mysql/labels_test.go @@ -1604,9 +1604,9 @@ func testListHostsInLabelIssues(t *testing.T, ds *Datastore) { assert.Zero(t, *h2.HostIssues.CriticalVulnerabilitiesCount) assert.Zero(t, h2.HostIssues.TotalIssuesCount) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: new(true)}, time.Now(), false, nil)) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(false), p2.ID: ptr.Bool(false)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(false), p2.ID: new(false)}, time.Now(), false, nil)) checkLabelHostIssues(t, ds, l1.ID, filter, h2.ID, fleet.HostListOptions{}, 2, 0) // Add a critical vulnerability @@ -1676,13 +1676,13 @@ func testListHostsInLabelIssues(t *testing.T, ds *Datastore) { assert.NoError(t, ds.UpdateHostIssuesVulnerabilities(ctx)) checkLabelHostIssues(t, ds, l1.ID, filter, hosts[6].ID, fleet.HostListOptions{}, 0, 4) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(true), p2.ID: ptr.Bool(false)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(true), p2.ID: new(false)}, time.Now(), false, nil)) checkLabelHostIssues(t, ds, l1.ID, filter, h2.ID, fleet.HostListOptions{}, 1, 1) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(true), p2.ID: ptr.Bool(true)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(true), p2.ID: new(true)}, time.Now(), false, nil)) checkLabelHostIssues(t, ds, l1.ID, filter, h2.ID, fleet.HostListOptions{}, 0, 1) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: new(false)}, time.Now(), false, nil)) checkLabelHostIssues(t, ds, l1.ID, filter, h1.ID, fleet.HostListOptions{}, 1, 1) checkLabelHostIssues(t, ds, l1.ID, filter, h1.ID, fleet.HostListOptions{DisableIssues: true}, 0, 0) diff --git a/server/datastore/mysql/policies.go b/server/datastore/mysql/policies.go index 3d7b12cace..12d99ef9a2 100644 --- a/server/datastore/mysql/policies.go +++ b/server/datastore/mysql/policies.go @@ -608,12 +608,20 @@ func filterNotExecuted(results map[uint]*bool) map[uint]bool { return filtered } -func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool) error { +func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool, newlyPassingPolicyIDs []uint) error { // Identify policies that flipped failing -> passing for this host using current incoming results. // We compute this before updating policy_membership so we compare against the previous state. - _, newPassing, err := ds.FlippingPoliciesForHost(ctx, host.ID, results) - if err != nil { - return err + // When newlyPassingPolicyIDs is non-nil, the caller has already computed flipping policies + // (e.g. SubmitDistributedQueryResults computes it once for all consumers) so we reuse that result. + var newPassing []uint + if newlyPassingPolicyIDs != nil { + newPassing = newlyPassingPolicyIDs + } else { + var err error + _, newPassing, err = ds.FlippingPoliciesForHost(ctx, host.ID, results) + if err != nil { + return err + } } if len(newPassing) > 0 { slices.Sort(newPassing) @@ -645,7 +653,7 @@ func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *flee // semantically equivalent, even though here it processes a single host and // in async mode it processes a batch of hosts). - err = ds.withTx(ctx, func(tx sqlx.ExtContext) error { + err := ds.withTx(ctx, func(tx sqlx.ExtContext) error { if len(results) > 0 { query := fmt.Sprintf( `INSERT INTO policy_membership (updated_at, policy_id, host_id, passes) diff --git a/server/datastore/mysql/policies_test.go b/server/datastore/mysql/policies_test.go index 0c7bd38260..89e921bfab 100644 --- a/server/datastore/mysql/policies_test.go +++ b/server/datastore/mysql/policies_test.go @@ -460,14 +460,14 @@ func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) { require.NotNil(t, p2.AuthorID) assert.Equal(t, user1.ID, *p2.AuthorID) - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred)) - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: new(true)}, time.Now(), deferred, nil)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: new(true)}, time.Now(), deferred, nil)) - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: nil}, time.Now(), deferred)) - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), deferred)) - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: nil}, time.Now(), deferred, nil)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: new(false)}, time.Now(), deferred, nil)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: new(true)}, time.Now(), deferred, nil)) - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p2.ID: nil}, time.Now(), deferred)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p2.ID: nil}, time.Now(), deferred, nil)) require.NoError(t, ds.UpdateHostPolicyCounts(ctx)) @@ -483,8 +483,8 @@ func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) { assert.Equal(t, uint(0), policies[1].PassingHostCount) assert.Equal(t, uint(0), policies[1].FailingHostCount) - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), deferred)) - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p2.ID: ptr.Bool(false)}, time.Now(), deferred)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: new(false)}, time.Now(), deferred, nil)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p2.ID: new(false)}, time.Now(), deferred, nil)) require.NoError(t, ds.UpdateHostPolicyCounts(ctx)) @@ -500,6 +500,30 @@ func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) { assert.Equal(t, uint(0), policies[1].PassingHostCount) assert.Equal(t, uint(1), policies[1].FailingHostCount) + // Test with pre-computed newlyPassingPolicyIDs (non-nil) to exercise the path where + // RecordPolicyQueryExecutions skips calling FlippingPoliciesForHost internally. + // host1 currently has p.ID=failing, so flipping to passing with pre-computed IDs. + require.NoError(t, ds.RecordPolicyQueryExecutions( + ctx, host1, map[uint]*bool{p.ID: new(true)}, time.Now(), deferred, []uint{p.ID}, + )) + // Also test with an empty (but non-nil) slice, which means "already computed, no newly passing". + require.NoError(t, ds.RecordPolicyQueryExecutions( + ctx, host2, map[uint]*bool{p2.ID: new(true)}, time.Now(), deferred, []uint{}, + )) + + require.NoError(t, ds.UpdateHostPolicyCounts(ctx)) + policies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{}) + require.NoError(t, err) + require.Len(t, policies, 2) + // host1: p now passing again (was failing), host2: p still passing + assert.Equal(t, p.ID, policies[0].ID) + assert.Equal(t, uint(2), policies[0].PassingHostCount) + assert.Equal(t, uint(0), policies[0].FailingHostCount) + // host2: p2 now passing (was failing) + assert.Equal(t, p2.ID, policies[1].ID) + assert.Equal(t, uint(1), policies[1].PassingHostCount) + assert.Equal(t, uint(0), policies[1].FailingHostCount) + policy, err := ds.Policy(ctx, policies[0].ID) require.NoError(t, err) assert.Equal(t, policies[0], policy) @@ -553,9 +577,9 @@ func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) { require.NoError(t, err) // create some policy results - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host3, map[uint]*bool{t1pol.ID: ptr.Bool(true), p.ID: ptr.Bool(true), p2.ID: ptr.Bool(false)}, time.Now(), deferred)) - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host4, map[uint]*bool{t2pol.ID: ptr.Bool(false), t2pol2.ID: ptr.Bool(true), p.ID: ptr.Bool(false)}, time.Now(), deferred)) - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host5, map[uint]*bool{t2pol.ID: ptr.Bool(true), t2pol2.ID: ptr.Bool(true), p2.ID: ptr.Bool(true)}, time.Now(), deferred)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host3, map[uint]*bool{t1pol.ID: new(true), p.ID: new(true), p2.ID: new(false)}, time.Now(), deferred, nil)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host4, map[uint]*bool{t2pol.ID: new(false), t2pol2.ID: new(true), p.ID: new(false)}, time.Now(), deferred, nil)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host5, map[uint]*bool{t2pol.ID: new(true), t2pol2.ID: new(true), p2.ID: new(true)}, time.Now(), deferred, nil)) require.NoError(t, ds.UpdateHostPolicyCounts(ctx)) @@ -1077,7 +1101,7 @@ func testListMergedTeamPolicies(t *testing.T, ds *Datastore) { &fleet.Host{OsqueryHostID: ptr.String("host1"), NodeKey: ptr.String(fmt.Sprint("host1", 1)), TeamID: nil}) require.NoError(t, err) - err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{gpol.ID: ptr.Bool(true)}, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{gpol.ID: new(true)}, time.Now(), false, nil) require.NoError(t, err) err = ds.UpdateHostPolicyCounts(context.Background()) @@ -1100,7 +1124,7 @@ func testListMergedTeamPolicies(t *testing.T, ds *Datastore) { err = ds.AddHostsToTeam(ctx, fleet.NewAddHostsToTeamParams(&team1.ID, []uint{host.ID})) require.NoError(t, err) - err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{team1policy.ID: ptr.Bool(true)}, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{team1policy.ID: new(true)}, time.Now(), false, nil) require.NoError(t, err) err = ds.UpdateHostPolicyCounts(context.Background()) @@ -1447,7 +1471,7 @@ func testPolicyQueriesForHost(t *testing.T, ds *Datastore) { assert.Equal(t, q.Query, queries[fmt.Sprint(q.ID)]) // Team policy ran with failing result. - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{tp.ID: ptr.Bool(false), gp.ID: nil}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{tp.ID: new(false), gp.ID: nil}, time.Now(), false, nil)) policies, err := ds.ListPoliciesForHost(context.Background(), host1) require.NoError(t, err) @@ -1490,7 +1514,7 @@ func testPolicyQueriesForHost(t *testing.T, ds *Datastore) { assert.Equal(t, "", policies[0].Response) // Global policy ran with passing result. - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{gp.ID: ptr.Bool(true)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{gp.ID: new(true)}, time.Now(), false, nil)) policies, err = ds.ListPoliciesForHost(context.Background(), host2) require.NoError(t, err) @@ -1511,7 +1535,7 @@ func testPolicyQueriesForHost(t *testing.T, ds *Datastore) { require.NoError(t, err) require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{uint(id): nil}, //nolint:gosec // dismiss G115 - time.Now(), false)) + time.Now(), false, nil)) policies, err = ds.ListPoliciesForHost(context.Background(), host2) require.NoError(t, err) @@ -1556,7 +1580,7 @@ func testPoliciesByID(t *testing.T, ds *Datastore) { err = ds.SavePolicy(context.Background(), policy2, false, false) require.NoError(t, err) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{policy1.ID: ptr.Bool(true)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{policy1.ID: new(true)}, time.Now(), false, nil)) require.NoError(t, ds.UpdateHostPolicyCounts(context.Background())) policiesByID, err := ds.PoliciesByID(context.Background(), []uint{1, 2}) @@ -1634,10 +1658,10 @@ func testTeamPolicyTransfer(t *testing.T, ds *Datastore) { }) require.NoError(t, err) - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{team1Policy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false)) - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{team1Policy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false)) - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false)) - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{team1Policy.ID: new(false), globalPolicy.ID: new(true)}, time.Now(), false, nil)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{team1Policy.ID: new(true), globalPolicy.ID: new(true)}, time.Now(), false, nil)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: new(false), globalPolicy.ID: new(true)}, time.Now(), false, nil)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: new(true), globalPolicy.ID: new(true)}, time.Now(), false, nil)) require.NoError(t, ds.UpdateHostPolicyCounts(ctx)) @@ -1691,7 +1715,7 @@ func testTeamPolicyTransfer(t *testing.T, ds *Datastore) { checkPassingCount(0, 0, 1, 1) // team policies are removed if the host is re-enrolled without a team - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: new(true), globalPolicy.ID: new(true)}, time.Now(), false, nil)) checkPassingCount(1, 0, 2, 2) // all host policies are removed when a host is re-enrolled @@ -2151,7 +2175,7 @@ func testApplyPolicySpecWithQueryPlatformChanges(t *testing.T, ds *Datastore) { for _, pol := range polsByName { res[pol.ID] = ptr.Bool(false) } - err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false, nil) require.NoError(t, err) } for _, h := range globalHosts { @@ -2159,7 +2183,7 @@ func testApplyPolicySpecWithQueryPlatformChanges(t *testing.T, ds *Datastore) { for _, pol := range globalPolsByName { res[pol.ID] = ptr.Bool(false) } - err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false, nil) require.NoError(t, err) } err = ds.UpdateHostPolicyCounts(ctx) @@ -2529,17 +2553,17 @@ func testCachedPolicyCountDeletesOnPolicyChange(t *testing.T, ds *Datastore) { // teamHost and globalHost fail all policies require.NoError( t, ds.RecordPolicyQueryExecutions( - ctx, teamHost, map[uint]*bool{globalPolicy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(false)}, time.Now(), false, + ctx, teamHost, map[uint]*bool{globalPolicy.ID: new(false), globalPolicy.ID: new(false)}, time.Now(), false, nil, ), ) require.NoError( t, ds.RecordPolicyQueryExecutions( - ctx, teamHost, map[uint]*bool{teamPolicy.ID: ptr.Bool(false), teamPolicy.ID: ptr.Bool(false)}, time.Now(), false, + ctx, teamHost, map[uint]*bool{teamPolicy.ID: new(false), teamPolicy.ID: new(false)}, time.Now(), false, nil, ), ) require.NoError( t, ds.RecordPolicyQueryExecutions( - ctx, globalHost, map[uint]*bool{globalPolicy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(false)}, time.Now(), false, + ctx, globalHost, map[uint]*bool{globalPolicy.ID: new(false), globalPolicy.ID: new(false)}, time.Now(), false, nil, ), ) @@ -2723,7 +2747,7 @@ func testFlippingPoliciesForHost(t *testing.T, ds *Datastore) { require.Empty(t, newPassing) // because this would be the first run. // Record the above executions. - err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false, nil) require.NoError(t, err) // incoming policy 1 with passing result: no => yes @@ -2738,7 +2762,7 @@ func testFlippingPoliciesForHost(t *testing.T, ds *Datastore) { require.Equal(t, []uint{p1.ID}, newPassing) // Record the above executions. - err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false, nil) require.NoError(t, err) // incoming policy 1 with passing result: yes => yes @@ -2753,7 +2777,7 @@ func testFlippingPoliciesForHost(t *testing.T, ds *Datastore) { require.Empty(t, newPassing) // Record the above executions. - err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false, nil) require.NoError(t, err) // incoming policy 1 failed to execute: yes => no @@ -2777,7 +2801,7 @@ func testFlippingPoliciesForHost(t *testing.T, ds *Datastore) { require.Empty(t, newPassing) // Record the above executions. - err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false, nil) require.NoError(t, err) // incoming pfailed again failed to execute: --- -> --- @@ -2798,7 +2822,7 @@ func testFlippingPoliciesForHost(t *testing.T, ds *Datastore) { require.Empty(t, newPassing) // Record the above executions. - err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false, nil) require.NoError(t, err) // incoming policy 4 with first new failing result: --- => no @@ -2932,7 +2956,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) { // also record a result for linux policy res[polsByName["t2"].ID] = ptr.Bool(true) } - err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false, nil) require.NoError(t, err) } for i, h := range globalHosts { @@ -2943,7 +2967,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) { // also record a result for linux policy res[polsByName["g2"].ID] = ptr.Bool(true) } - err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false, nil) require.NoError(t, err) } @@ -3066,9 +3090,9 @@ func testPolicyViolationDays(t *testing.T, ds *Datastore) { require.NoError(t, ds.InitializePolicyViolationDays(ctx)) // sets starting violation count to zero // initialize policy statuses: 1 failling, 2 passing - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), hosts[0], map[uint]*bool{pol.ID: ptr.Bool(false)}, then, false)) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), hosts[1], map[uint]*bool{pol.ID: ptr.Bool(true)}, then, false)) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), hosts[2], map[uint]*bool{pol.ID: ptr.Bool(true)}, then, false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), hosts[0], map[uint]*bool{pol.ID: new(false)}, then, false, nil)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), hosts[1], map[uint]*bool{pol.ID: new(true)}, then, false, nil)) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), hosts[2], map[uint]*bool{pol.ID: new(true)}, then, false, nil)) // setup db for test: starting counts zero, more than 24h since last updated, one outstanding violation require.NoError(t, setStatsTimestampDB(time.Now().Add(-25*time.Hour))) @@ -3094,7 +3118,7 @@ func testPolicyViolationDays(t *testing.T, ds *Datastore) { // leave counts at zero for next test // setup for test: starting count zero, more than 24h since last updated, add second outstanding violation - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: new(false)}, time.Now(), false, nil)) require.NoError(t, setStatsTimestampDB(time.Now().Add(-25*time.Hour))) require.NoError(t, ds.IncrementPolicyViolationDays(ctx)) actual, possible, err = amountPolicyViolationDaysDB(ctx, ds.reader(ctx)) @@ -3106,7 +3130,7 @@ func testPolicyViolationDays(t *testing.T, ds *Datastore) { // leave counts at 2 actual and 3 possible for next test // setup for test: starting counts at 2 actual and 3 possible, more than 24h since last updated, resolve one outstaning violation - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: ptr.Bool(true)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: new(true)}, time.Now(), false, nil)) require.NoError(t, setStatsTimestampDB(time.Now().Add(-25*time.Hour))) require.NoError(t, ds.IncrementPolicyViolationDays(ctx)) actual, possible, err = amountPolicyViolationDaysDB(ctx, ds.reader(ctx)) @@ -3196,7 +3220,7 @@ func testPolicyCleanupPolicyMembership(t *testing.T, ds *Datastore) { polsByName["p2"].ID: ptr.Bool(true), polsByName["p3"].ID: ptr.Bool(true), } - err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false, nil) require.NoError(t, err) } require.NoError(t, ds.writer(ctx).Get(&count, "select COUNT(*) from host_issues WHERE total_issues_count > 0")) @@ -3315,7 +3339,7 @@ func testDeleteAllPolicyMemberships(t *testing.T, ds *Datastore) { host, map[uint]*bool{globalPolicy.ID: ptr.Bool(false)}, time.Now(), - false, + false, nil, ) require.NoError(t, err) @@ -3378,9 +3402,9 @@ func testOutdatedAutomationBatch(t *testing.T, ds *Datastore) { pol2, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: "policy2"}) require.NoError(t, err) - err = ds.RecordPolicyQueryExecutions(ctx, h1, map[uint]*bool{pol1.ID: ptr.Bool(false), pol2.ID: ptr.Bool(true)}, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, h1, map[uint]*bool{pol1.ID: new(false), pol2.ID: new(true)}, time.Now(), false, nil) require.NoError(t, err) - err = ds.RecordPolicyQueryExecutions(ctx, h2, map[uint]*bool{pol1.ID: ptr.Bool(false), pol2.ID: ptr.Bool(false)}, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, h2, map[uint]*bool{pol1.ID: new(false), pol2.ID: new(false)}, time.Now(), false, nil) require.NoError(t, err) batch, err := ds.OutdatedAutomationBatch(ctx) @@ -3615,7 +3639,7 @@ func testUpdatePolicyHostCounts(t *testing.T, ds *Datastore) { res := map[uint]*bool{ policy.ID: ptr.Bool(true), } - err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false, nil) require.NoError(t, err) } @@ -3671,7 +3695,7 @@ func testUpdatePolicyHostCounts(t *testing.T, ds *Datastore) { res := map[uint]*bool{ policy.ID: result, } - err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false, nil) require.NoError(t, err) } @@ -3720,7 +3744,7 @@ func testUpdatePolicyHostCounts(t *testing.T, ds *Datastore) { policy.ID: ptr.Bool(false), policy2.ID: ptr.Bool(true), } - err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false, nil) require.NoError(t, err) } for _, h := range teamHosts { @@ -3728,7 +3752,7 @@ func testUpdatePolicyHostCounts(t *testing.T, ds *Datastore) { policy.ID: ptr.Bool(false), policy2.ID: ptr.Bool(true), } - err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false, nil) require.NoError(t, err) } @@ -4101,25 +4125,25 @@ func testGetTeamHostsPolicyMemberships(t *testing.T, ds *Datastore) { err = ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{ team1Policy1.ID: ptr.Bool(true), team1Policy2.ID: ptr.Bool(false), - }, time.Now(), false) + }, time.Now(), false, nil) require.NoError(t, err) err = ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{ team2Policy1.ID: ptr.Bool(false), team2Policy2.ID: ptr.Bool(true), - }, time.Now(), false) + }, time.Now(), false, nil) require.NoError(t, err) err = ds.RecordPolicyQueryExecutions(ctx, host3, map[uint]*bool{ team2Policy1.ID: ptr.Bool(true), team2Policy2.ID: ptr.Bool(true), - }, time.Now(), false) + }, time.Now(), false, nil) require.NoError(t, err) err = ds.RecordPolicyQueryExecutions(ctx, host5, map[uint]*bool{ team1Policy1.ID: ptr.Bool(false), team1Policy2.ID: ptr.Bool(false), - }, time.Now(), false) + }, time.Now(), false, nil) require.NoError(t, err) team1Policies, err := ds.GetCalendarPolicies(ctx, team1.ID) @@ -4187,7 +4211,7 @@ func testGetTeamHostsPolicyMemberships(t *testing.T, ds *Datastore) { err = ds.RecordPolicyQueryExecutions(ctx, host4, map[uint]*bool{ team1Policy1.ID: ptr.Bool(false), team1Policy2.ID: ptr.Bool(false), - }, time.Now(), false) + }, time.Now(), false, nil) require.NoError(t, err) hostsTeam1, err = ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team1.ID, []uint{team1Policies[0].ID}, nil) @@ -4271,7 +4295,7 @@ func testGetTeamHostsPolicyMemberships(t *testing.T, ds *Datastore) { ctx, host2, map[uint]*bool{ team2Policy1.ID: nil, team2Policy2.ID: nil, - }, time.Now(), false, + }, time.Now(), false, nil, ) require.NoError(t, err) @@ -4301,7 +4325,7 @@ func testGetTeamHostsPolicyMemberships(t *testing.T, ds *Datastore) { err = ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{ team2Policy1.ID: ptr.Bool(true), team2Policy2.ID: ptr.Bool(true), - }, time.Now(), false) + }, time.Now(), false, nil) require.NoError(t, err) hostsTeam2, err = ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team2.ID, []uint{team2Policies[0].ID, team2Policies[1].ID}, nil) @@ -4396,7 +4420,7 @@ func testGetTeamHostsPolicyMembershipsEmailPriority(t *testing.T, ds *Datastore) }) require.NoError(t, err) // Make the host fail the calendar policy so it always appears in results. - err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{calendarPolicy.ID: ptr.Bool(false)}, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{calendarPolicy.ID: new(false)}, time.Now(), false, nil) require.NoError(t, err) return h } @@ -5158,7 +5182,7 @@ func testApplyPolicySpecWithInstallers(t *testing.T, ds *Datastore) { err = ds.RecordPolicyQueryExecutions(ctx, host1Team1, map[uint]*bool{ policy1Team1.ID: ptr.Bool(false), vppPolicy1Team1.ID: ptr.Bool(false), - }, time.Now(), false) + }, time.Now(), false, nil) require.NoError(t, err) err = ds.UpdateHostPolicyCounts(ctx) require.NoError(t, err) @@ -5417,7 +5441,7 @@ func testApplyPolicySpecWithInstallers(t *testing.T, ds *Datastore) { err = ds.RecordPolicyQueryExecutions(ctx, host1Team1, map[uint]*bool{ policy1Team1.ID: ptr.Bool(false), vppPolicy1Team1.ID: ptr.Bool(false), - }, time.Now(), false) + }, time.Now(), false, nil) require.NoError(t, err) err = ds.UpdateHostPolicyCounts(ctx) require.NoError(t, err) @@ -5604,7 +5628,7 @@ func testTeamPoliciesNoTeam(t *testing.T, ds *Datastore) { globalPolicy2.ID: ptr.Bool(false), policy0NoTeam.ID: ptr.Bool(true), policy3NoTeam.ID: ptr.Bool(false), - }, time.Now(), false) + }, time.Now(), false, nil) require.NoError(t, err) // Results for host1Team1 @@ -5612,7 +5636,7 @@ func testTeamPoliciesNoTeam(t *testing.T, ds *Datastore) { globalPolicy1.ID: ptr.Bool(true), globalPolicy2.ID: nil, // failed to execute, e.g. typo on SQL. policy1Team1.ID: ptr.Bool(true), - }, time.Now(), false) + }, time.Now(), false, nil) require.NoError(t, err) // Results for host2Team1 @@ -5620,7 +5644,7 @@ func testTeamPoliciesNoTeam(t *testing.T, ds *Datastore) { globalPolicy1.ID: ptr.Bool(false), globalPolicy2.ID: ptr.Bool(true), policy1Team1.ID: ptr.Bool(false), - }, time.Now(), false) + }, time.Now(), false, nil) require.NoError(t, err) // Results for host3Team2 @@ -5628,7 +5652,7 @@ func testTeamPoliciesNoTeam(t *testing.T, ds *Datastore) { globalPolicy1.ID: ptr.Bool(true), policy2Team2.ID: ptr.Bool(true), policy4Team2.ID: ptr.Bool(false), - }, time.Now(), false) + }, time.Now(), false, nil) require.NoError(t, err) // Results for host5NoTeam @@ -5637,7 +5661,7 @@ func testTeamPoliciesNoTeam(t *testing.T, ds *Datastore) { globalPolicy2.ID: ptr.Bool(false), policy0NoTeam.ID: ptr.Bool(false), policy3NoTeam.ID: ptr.Bool(false), - }, time.Now(), false) + }, time.Now(), false, nil) require.NoError(t, err) err = ds.UpdateHostPolicyCounts(ctx) @@ -6155,7 +6179,7 @@ func testClearAutoInstallPolicyStatusForHost(t *testing.T, ds *Datastore) { policy1.ID: ptr.Bool(true), policy2.ID: ptr.Bool(false), // software isn't installed on host, so Fleet should install it policy3.ID: ptr.Bool(false), // software isn't installed on host, so Fleet should install it - }, time.Now(), false) + }, time.Now(), false, nil) require.NoError(t, err) hostPolicies, err := ds.ListPoliciesForHost(ctx, host) @@ -6378,7 +6402,7 @@ func testPolicyLabelMembershipCleanup(t *testing.T, ds *Datastore) { // Record policy results for all hosts for _, h := range []*fleet.Host{hostNoLabels, hostLabel1, hostLabel2, hostLabelBoth} { - err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: ptr.Bool(true)}, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: new(true)}, time.Now(), false, nil) require.NoError(t, err) } @@ -6408,7 +6432,7 @@ func testPolicyLabelMembershipCleanup(t *testing.T, ds *Datastore) { // Re-record membership for all hosts to test exclude labels for _, h := range []*fleet.Host{hostNoLabels, hostLabel1, hostLabel2, hostLabelBoth} { - err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: ptr.Bool(true)}, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: new(true)}, time.Now(), false, nil) require.NoError(t, err) } wantHostsByPol[policy.Name] = []uint{hostNoLabels.ID, hostLabel1.ID, hostLabel2.ID, hostLabelBoth.ID} @@ -6426,7 +6450,7 @@ func testPolicyLabelMembershipCleanup(t *testing.T, ds *Datastore) { // Test ApplyPolicySpecs with label changes // First, re-record membership for all hosts for _, h := range []*fleet.Host{hostNoLabels, hostLabel1, hostLabel2, hostLabelBoth} { - err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: ptr.Bool(true)}, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: new(true)}, time.Now(), false, nil) require.NoError(t, err) } wantHostsByPol[policy.Name] = []uint{hostNoLabels.ID, hostLabel1.ID, hostLabel2.ID, hostLabelBoth.ID} @@ -6464,7 +6488,7 @@ func testPolicyLabelMembershipCleanup(t *testing.T, ds *Datastore) { // Record membership for all hosts with label1 for _, h := range []*fleet.Host{hostLabel1, hostLabelBoth, hostWinLabel1, hostMacLabel1} { - err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy2.ID: ptr.Bool(true)}, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy2.ID: new(true)}, time.Now(), false, nil) require.NoError(t, err) } @@ -6551,8 +6575,8 @@ func testDeletePolicyWithSoftwareActivatesNextActivity(t *testing.T, ds *Datasto require.NoError(t, err) // record a failing policy for both hosts, would enqueue the install - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostNoTm, map[uint]*bool{policyNoTm.ID: ptr.Bool(false)}, time.Now(), false)) - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostTm, map[uint]*bool{policyTm.ID: ptr.Bool(false)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostNoTm, map[uint]*bool{policyNoTm.ID: new(false)}, time.Now(), false, nil)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostTm, map[uint]*bool{policyTm.ID: new(false)}, time.Now(), false, nil)) // simulate the work of "processSoftwareForNewlyFailingPolicies" installUUIDNoTm, err := ds.InsertSoftwareInstallRequest(ctx, hostNoTm.ID, installerIDNoTm, @@ -6645,8 +6669,8 @@ func testDeletePolicyWithScriptActivatesNextActivity(t *testing.T, ds *Datastore // record a failing policy for both hosts, would enqueue the associated // scripts - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostNoTm, map[uint]*bool{policyNoTm.ID: ptr.Bool(false)}, time.Now(), false)) - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostTm, map[uint]*bool{policyTm.ID: ptr.Bool(false)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostNoTm, map[uint]*bool{policyNoTm.ID: new(false)}, time.Now(), false, nil)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostTm, map[uint]*bool{policyTm.ID: new(false)}, time.Now(), false, nil)) // simulate the work of "processScriptsForNewlyFailingPolicies" hsrPolicyNoTm, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{ @@ -6713,7 +6737,7 @@ func testSimultaneousSavePolicy(t *testing.T, ds *Datastore) { for _, policy := range policies { host1Results[policy.ID] = ptr.Bool(true) } - err = ds.RecordPolicyQueryExecutions(ctx, host1, host1Results, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, host1, host1Results, time.Now(), false, nil) require.NoError(t, err) // Run simultaneous @@ -6754,7 +6778,7 @@ func testIsPolicyFailing(t *testing.T, ds *Datastore) { // Exists with passes = NULL // failing - err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: nil}, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: nil}, time.Now(), false, nil) require.NoError(t, err) isFailing, err = ds.IsPolicyFailing(ctx, policy.ID, host.ID) @@ -6763,7 +6787,7 @@ func testIsPolicyFailing(t *testing.T, ds *Datastore) { // exists with passes = false // failing - err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: ptr.Bool(false)}, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: new(false)}, time.Now(), false, nil) require.NoError(t, err) isFailing, err = ds.IsPolicyFailing(ctx, policy.ID, host.ID) @@ -6772,7 +6796,7 @@ func testIsPolicyFailing(t *testing.T, ds *Datastore) { // exists with passes = true // Not failing - err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: ptr.Bool(true)}, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: new(true)}, time.Now(), false, nil) require.NoError(t, err) isFailing, err = ds.IsPolicyFailing(ctx, policy.ID, host.ID) @@ -6821,7 +6845,7 @@ func testResetAttemptsOnFailingToPassingSync(t *testing.T, ds *Datastore) { require.NoError(t, err) // p1 will be failing - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{p1.ID: ptr.Bool(false)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{p1.ID: new(false)}, time.Now(), false, nil)) // Create rows with attempt_number > 0 and attempt_number IS NULL (pending) // p1 - completed attempt @@ -6852,7 +6876,7 @@ func testResetAttemptsOnFailingToPassingSync(t *testing.T, ds *Datastore) { require.NoError(t, err) // p1 is now passing - err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{p1.ID: ptr.Bool(true), p2.ID: ptr.Bool(true)}, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{p1.ID: new(true), p2.ID: new(true)}, time.Now(), false, nil) require.NoError(t, err) // p1 rows should be reset to 0 (both completed and pending) @@ -6903,7 +6927,7 @@ func testResetAttemptsOnFailingToPassingAsync(t *testing.T, ds *Datastore) { require.NoError(t, err) // p1 is failing - require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{p1.ID: ptr.Bool(false)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{p1.ID: new(false)}, time.Now(), false, nil)) // Create rows with attempt_number > 0 and attempt_number IS NULL (pending) // p1 - completed attempt @@ -7139,7 +7163,7 @@ func testBatchedPolicyMembershipCleanup(t *testing.T, ds *Datastore) { // Record failing results for all hosts so they all have policy_membership rows and host_issues entries. for _, h := range hosts { - err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false) + err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{pol.ID: new(false)}, time.Now(), false, nil) require.NoError(t, err) } @@ -7228,7 +7252,7 @@ func testBatchedPolicyMembershipCleanupOnPolicyUpdate(t *testing.T, ds *Datastor // Record results for all hosts (simulating results arriving before platform filter applied). allHosts := append([]*fleet.Host{winHost}, linuxHosts...) for _, h := range allHosts { - err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false) + err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{pol.ID: new(false)}, time.Now(), false, nil) require.NoError(t, err) } @@ -7293,7 +7317,7 @@ func testBatchedPolicyMembershipCleanupOnPolicyUpdate(t *testing.T, ds *Datastor // Record policy results for all label-test hosts so policy_membership is populated. labelHosts := append([]*fleet.Host{lblHost}, nonLblHosts...) for _, h := range labelHosts { - err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{lblPol.ID: ptr.Bool(false)}, time.Now(), false) + err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{lblPol.ID: new(false)}, time.Now(), false, nil) require.NoError(t, err) } @@ -7356,7 +7380,7 @@ func testApplyPolicySpecsNeedsFullMembershipCleanupFlag(t *testing.T, ds *Datast hosts[i] = h } for _, h := range hosts { - err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false) + err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{pol.ID: new(false)}, time.Now(), false, nil) require.NoError(t, err) } @@ -7416,7 +7440,7 @@ func testCleanupPolicyMembershipCrashRecovery(t *testing.T, ds *Datastore) { recordResults := func(t *testing.T, hosts []*fleet.Host, polID uint) { t.Helper() for _, h := range hosts { - err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{polID: ptr.Bool(false)}, time.Now(), false) + err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{polID: new(false)}, time.Now(), false, nil) require.NoError(t, err) } } diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index 436971c401..967fbc86fe 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -1092,7 +1092,10 @@ type Datastore interface { // RecordPolicyQueryExecutions records the execution results of the policies for the given host. // Even if `results` is empty, the host's `policy_updated_at` will be updated. - RecordPolicyQueryExecutions(ctx context.Context, host *Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool) error + // If newlyPassingPolicyIDs is non-nil, it contains the IDs of policies that flipped from failing to passing + // and is used directly instead of calling FlippingPoliciesForHost internally. This allows callers that have + // already computed flipping policies to avoid a redundant database query. + RecordPolicyQueryExecutions(ctx context.Context, host *Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool, newlyPassingPolicyIDs []uint) error // RecordLabelQueryExecutions saves the results of label queries. The results map is a map of label id -> whether or // not the label matches. The time parameter is the timestamp to save with the query execution. diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index c4d7513df9..9ce18ab2a2 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -781,7 +781,7 @@ type UpdateHostRefetchCriticalQueriesUntilFunc func(ctx context.Context, hostID type FlippingPoliciesForHostFunc func(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing []uint, newPassing []uint, err error) -type RecordPolicyQueryExecutionsFunc func(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool) error +type RecordPolicyQueryExecutionsFunc func(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool, newlyPassingPolicyIDs []uint) error type RecordLabelQueryExecutionsFunc func(ctx context.Context, host *fleet.Host, results map[uint]*bool, t time.Time, deferredSaveHost bool) error @@ -7260,11 +7260,11 @@ func (s *DataStore) FlippingPoliciesForHost(ctx context.Context, hostID uint, in return s.FlippingPoliciesForHostFunc(ctx, hostID, incomingResults) } -func (s *DataStore) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool) error { +func (s *DataStore) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool, newlyPassingPolicyIDs []uint) error { s.mu.Lock() s.RecordPolicyQueryExecutionsFuncInvoked = true s.mu.Unlock() - return s.RecordPolicyQueryExecutionsFunc(ctx, host, results, updated, deferredSaveHost) + return s.RecordPolicyQueryExecutionsFunc(ctx, host, results, updated, deferredSaveHost, newlyPassingPolicyIDs) } func (s *DataStore) RecordLabelQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, t time.Time, deferredSaveHost bool) error { diff --git a/server/service/async/async_policy.go b/server/service/async/async_policy.go index bb0f9c0589..6b839ae980 100644 --- a/server/service/async/async_policy.go +++ b/server/service/async/async_policy.go @@ -28,11 +28,11 @@ const ( // redis list will be LTRIM'd if there are more policy IDs than this. var maxRedisPolicyResultsPerHost = 1000 -func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool) error { +func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool, newlyPassingPolicyIDs []uint) error { cfg := t.taskConfigs[config.AsyncTaskPolicyMembership] if !cfg.Enabled { host.PolicyUpdatedAt = ts - return t.datastore.RecordPolicyQueryExecutions(ctx, host, results, ts, deferred) + return t.datastore.RecordPolicyQueryExecutions(ctx, host, results, ts, deferred, newlyPassingPolicyIDs) } keyList := fmt.Sprintf(policyPassHostKey, host.ID) diff --git a/server/service/async/async_policy_test.go b/server/service/async/async_policy_test.go index d055f6da6e..353acd154b 100644 --- a/server/service/async/async_policy_test.go +++ b/server/service/async/async_policy_test.go @@ -321,7 +321,7 @@ func testRecordPolicyQueryExecutionsSync(t *testing.T, ds *mock.Store, pool flee policyReportedAt := task.GetHostPolicyReportedAt(ctx, host) require.True(t, policyReportedAt.Equal(lastYear)) - err := task.RecordPolicyQueryExecutions(ctx, host, results, now, false) + err := task.RecordPolicyQueryExecutions(ctx, host, results, now, false, nil) require.NoError(t, err) require.True(t, ds.RecordPolicyQueryExecutionsFuncInvoked) ds.RecordPolicyQueryExecutionsFuncInvoked = false @@ -373,7 +373,7 @@ func testRecordPolicyQueryExecutionsAsync(t *testing.T, ds *mock.Store, pool fle policyReportedAt := task.GetHostPolicyReportedAt(ctx, host) require.True(t, policyReportedAt.Equal(lastYear)) - err := task.RecordPolicyQueryExecutions(ctx, host, results, now, false) + err := task.RecordPolicyQueryExecutions(ctx, host, results, now, false, nil) require.NoError(t, err) require.False(t, ds.RecordPolicyQueryExecutionsFuncInvoked) @@ -435,7 +435,7 @@ func testRecordPolicyQueryExecutionsNoPoliciesSync(t *testing.T, ds *mock.Store, policyReportedAt := task.GetHostPolicyReportedAt(ctx, host) require.True(t, policyReportedAt.Equal(lastYear)) - err := task.RecordPolicyQueryExecutions(ctx, host, emptyResults, now, false) + err := task.RecordPolicyQueryExecutions(ctx, host, emptyResults, now, false, nil) require.NoError(t, err) require.True(t, ds.RecordPolicyQueryExecutionsFuncInvoked) ds.RecordPolicyQueryExecutionsFuncInvoked = false @@ -485,7 +485,7 @@ func testRecordPolicyQueryExecutionsNoPoliciesAsync(t *testing.T, ds *mock.Store policyReportedAt := task.GetHostPolicyReportedAt(ctx, host) require.True(t, policyReportedAt.Equal(lastYear)) - err := task.RecordPolicyQueryExecutions(ctx, host, emptyResults, now, false) + err := task.RecordPolicyQueryExecutions(ctx, host, emptyResults, now, false, nil) require.NoError(t, err) require.False(t, ds.RecordPolicyQueryExecutionsFuncInvoked) diff --git a/server/service/async/async_test.go b/server/service/async/async_test.go index 78259f347c..7908ed6174 100644 --- a/server/service/async/async_test.go +++ b/server/service/async/async_test.go @@ -88,7 +88,7 @@ func TestRecord(t *testing.T) { ds.AsyncBatchUpdateLabelTimestampFunc = func(ctx context.Context, ids []uint, ts time.Time) error { return nil } - ds.RecordPolicyQueryExecutionsFunc = func(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool) error { + ds.RecordPolicyQueryExecutionsFunc = func(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool, newlyPassingPolicyIDs []uint) error { return nil } ds.AsyncBatchInsertPolicyMembershipFunc = func(ctx context.Context, batch []fleet.PolicyMembershipResult) error { diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index 1c7c1f90aa..d8d2102326 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -1223,8 +1223,8 @@ func (s *integrationTestSuite) TestGlobalPolicies() { s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp) require.Len(t, listHostsResp.Hosts, 0) - require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now(), false)) - require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false)) + require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: new(true)}, time.Now(), false, nil)) + require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false, nil)) listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID) listHostsResp = listHostsResponse{} @@ -1844,7 +1844,7 @@ func (s *integrationTestSuite) TestListHosts() { require.NoError( t, - s.ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{globalPolicy0.ID: ptr.Bool(false)}, time.Now(), false), + s.ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{globalPolicy0.ID: new(false)}, time.Now(), false, nil), ) resp = listHostsResponse{} @@ -2109,7 +2109,7 @@ func (s *integrationTestSuite) TestListHosts() { for _, host := range hosts { // All hosts pass the globalPolicy1 err := s.ds.RecordPolicyQueryExecutions( - context.Background(), host, map[uint]*bool{globalPolicy1.ID: ptr.Bool(true)}, time.Now(), false, + context.Background(), host, map[uint]*bool{globalPolicy1.ID: new(true)}, time.Now(), false, nil, ) require.NoError(t, err) } @@ -2927,8 +2927,8 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() { s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp) require.Len(t, listHostsResp.Hosts, 0) - require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now(), false)) - require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false)) + require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: new(true)}, time.Now(), false, nil)) + require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false, nil)) listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID) listHostsResp = listHostsResponse{} @@ -2978,12 +2978,12 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() { // Record query executions require.NoError( t, s.ds.RecordPolicyQueryExecutions( - context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now(), false, + context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: new(true)}, time.Now(), false, nil, ), ) require.NoError( t, s.ds.RecordPolicyQueryExecutions( - context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false, + context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false, nil, ), ) // Update policy stats @@ -3169,8 +3169,8 @@ func (s *integrationTestSuite) TestTeamPoliciesProprietary() { s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp) require.Len(t, listHostsResp.Hosts, 0) - require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now(), false)) - require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false)) + require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: new(true)}, time.Now(), false, nil)) + require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false, nil)) listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?team_id=%d&policy_id=%d&policy_response=passing", team1.ID, policiesResponse.Policies[0].ID) listHostsResp = listHostsResponse{} @@ -3384,6 +3384,7 @@ func (s *integrationTestSuite) TestHostDetailsPolicies() { map[uint]*bool{gpResp.Policy.ID: ptr.Bool(true)}, time.Now(), false, + nil, ) require.NoError(t, err) @@ -5417,7 +5418,7 @@ func (s *integrationTestSuite) TestListHostsByLabel() { require.NotNil(t, gpResp.Policy) require.NoError( t, - s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{gpResp.Policy.ID: ptr.Bool(false)}, time.Now(), false), + s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{gpResp.Policy.ID: new(false)}, time.Now(), false, nil), ) // Add MDM info @@ -9311,7 +9312,7 @@ func (s *integrationTestSuite) TestReenrollHostCleansPolicies() { // create a policy and make the host fail it pol, err := s.ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: t.Name(), Query: "SELECT 1", Platform: host.FleetPlatform()}) require.NoError(t, err) - err = s.ds.RecordPolicyQueryExecutions(ctx, &fleet.Host{ID: host.ID}, map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false) + err = s.ds.RecordPolicyQueryExecutions(ctx, &fleet.Host{ID: host.ID}, map[uint]*bool{pol.ID: new(false)}, time.Now(), false, nil) require.NoError(t, err) // refetch the host details @@ -10052,7 +10053,7 @@ func (s *integrationTestSuite) TestHostsReportDownload() { // create a policy and make host[1] fail that policy pol, err := s.ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: t.Name(), Query: "SELECT 1"}) require.NoError(t, err) - err = s.ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false) + err = s.ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: new(false)}, time.Now(), false, nil) require.NoError(t, err) // create some device mappings for host[2] @@ -12474,20 +12475,20 @@ func (s *integrationTestSuite) TestHostsReportWithPolicyResults() { for i, host := range hosts { // All hosts pass the globalPolicy0 - err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy0.ID: ptr.Bool(true)}, time.Now(), false) + err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy0.ID: new(true)}, time.Now(), false, nil) require.NoError(t, err) if i%2 == 0 { // Half of the hosts pass the globalPolicy1 and fail the globalPolicy2 - err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy1.ID: ptr.Bool(true)}, time.Now(), false) + err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy1.ID: new(true)}, time.Now(), false, nil) require.NoError(t, err) - err = s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy2.ID: ptr.Bool(false)}, time.Now(), false) + err = s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy2.ID: new(false)}, time.Now(), false, nil) require.NoError(t, err) } else { // Half of the hosts pass the globalPolicy2 and fail the globalPolicy1 - err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy1.ID: ptr.Bool(false)}, time.Now(), false) + err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy1.ID: new(false)}, time.Now(), false, nil) require.NoError(t, err) - err = s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy2.ID: ptr.Bool(true)}, time.Now(), false) + err = s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy2.ID: new(true)}, time.Now(), false, nil) require.NoError(t, err) } } @@ -13593,8 +13594,8 @@ func (s *integrationTestSuite) TestHostHealth() { }) require.NoError(t, err) - require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{failingPolicy.ID: ptr.Bool(false)}, time.Now(), false)) - require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{passingPolicy.ID: ptr.Bool(true)}, time.Now(), false)) + require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{failingPolicy.ID: new(false)}, time.Now(), false, nil)) + require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{passingPolicy.ID: new(true)}, time.Now(), false, nil)) require.NoError(t, s.ds.SetOrUpdateHostDisksEncryption(context.Background(), host.ID, true)) diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index 975a41e0a8..344d6b0cad 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -3017,7 +3017,7 @@ func (s *integrationEnterpriseTestSuite) TestNoTeamFailingPolicyWebhookTrigger() noTeamPol1.ID: ptr.Bool(false), // Fails and is in webhook config noTeamPol2.ID: ptr.Bool(false), // Fails and is in webhook config noTeamPol3.ID: ptr.Bool(false), // Fails but NOT in webhook config - }, time.Now(), false) + }, time.Now(), false, nil) require.NoError(t, err) // Initially, OutdatedAutomationBatch should be empty (policies haven't been triggered for automation yet) @@ -4055,7 +4055,7 @@ func (s *integrationEnterpriseTestSuite) TestListDevicePolicies() { // add a policy execution require.NoError(t, s.ds.RecordPolicyQueryExecutions(ctx, host, - map[uint]*bool{gpResp.Policy.ID: ptr.Bool(false)}, time.Now(), false)) + map[uint]*bool{gpResp.Policy.ID: new(false)}, time.Now(), false, nil)) // add a policy to team oldToken := s.token @@ -5303,7 +5303,7 @@ func (s *integrationEnterpriseTestSuite) TestListHosts() { require.NoError( t, s.ds.RecordPolicyQueryExecutions( ctx, host1, - map[uint]*bool{gpResp.Policy.ID: ptr.Bool(false)}, time.Now(), false, + map[uint]*bool{gpResp.Policy.ID: new(false)}, time.Now(), false, nil, ), ) @@ -5568,10 +5568,10 @@ func (s *integrationEnterpriseTestSuite) TestHostHealth() { }) require.NoError(t, err) - require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{failingGlobalPolicy.ID: ptr.Bool(false)}, time.Now(), false)) - require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{passingGlobalPolicy.ID: ptr.Bool(true)}, time.Now(), false)) - require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{failingTeamPolicy.ID: ptr.Bool(false)}, time.Now(), false)) - require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{passingTeamPolicy.ID: ptr.Bool(true)}, time.Now(), false)) + require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{failingGlobalPolicy.ID: new(false)}, time.Now(), false, nil)) + require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{passingGlobalPolicy.ID: new(true)}, time.Now(), false, nil)) + require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{failingTeamPolicy.ID: new(false)}, time.Now(), false, nil)) + require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{passingTeamPolicy.ID: new(true)}, time.Now(), false, nil)) hh := getHostHealthResponse{} s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/hosts/%d/health", host.ID), nil, http.StatusOK, &hh) @@ -6288,7 +6288,7 @@ func (s *integrationEnterpriseTestSuite) TestResetAutomation() { createPol1.Policy.ID: ptr.Bool(false), createPol2.Policy.ID: ptr.Bool(false), createPol3.Policy.ID: ptr.Bool(false), // This policy is not activated for automation in config. - }, time.Now(), false) + }, time.Now(), false, nil) require.NoError(s.T(), err) pfs, err := s.ds.OutdatedAutomationBatch(ctx) @@ -7417,7 +7417,7 @@ func (s *integrationEnterpriseTestSuite) TestDesktopEndpointWithInvalidPolicy() Critical: false, }) require.NoError(t, err) - require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{policy.ID: nil}, time.Now(), false)) + require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{policy.ID: nil}, time.Now(), false, nil)) // Any 'invalid' policies should be ignored. desktopRes := fleetDesktopResponse{} @@ -25216,7 +25216,7 @@ FqU+KJOed6qlzj7qy+u5l6CQeajLGdjUxFlFyw== require.NoError(t, err) // Record a failing result for this policy on the host - err = s.ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: ptr.Bool(false)}, time.Now(), false) + err = s.ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: new(false)}, time.Now(), false, nil) require.NoError(t, err) // Bypass should fail with 400 Bad Request @@ -25261,7 +25261,7 @@ FqU+KJOed6qlzj7qy+u5l6CQeajLGdjUxFlFyw== err = s.ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{ caPolicy.ID: ptr.Bool(true), // passing nonCAPolicy.ID: ptr.Bool(false), // failing - }, time.Now(), false) + }, time.Now(), false, nil) require.NoError(t, err) // Bypass must succeed: the only failing policy is not CA-enabled diff --git a/server/service/orbit_test.go b/server/service/orbit_test.go index 15b5379b21..5cc46e8f5a 100644 --- a/server/service/orbit_test.go +++ b/server/service/orbit_test.go @@ -924,7 +924,7 @@ func TestSoftwareInstallReplicaLag(t *testing.T) { opts.RunReplication("software_installers", "software_titles") // Mark policy as failing for the host - err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: ptr.Bool(false)}, time.Now(), false) + err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: new(false)}, time.Now(), false, nil) require.NoError(t, err) opts.RunReplication("policy_membership") diff --git a/server/service/osquery.go b/server/service/osquery.go index 263890ec8f..72ad4c3a44 100644 --- a/server/service/osquery.go +++ b/server/service/osquery.go @@ -1209,11 +1209,27 @@ func (svc *Service) SubmitDistributedQueryResults( } if len(policyResults) > 0 { + // Compute flipping policies once for all consumers. This replaces up to 5 individual calls to + // FlippingPoliciesForHost with a single database query. + newFailing, newPassing, err := svc.ds.FlippingPoliciesForHost(ctx, host.ID, policyResults) + if err != nil { + logging.WithErr(ctx, err) + } + // Ensure newPassing is non-nil so RecordPolicyQueryExecutions can distinguish "pre-computed with zero results" + // from "not pre-computed" (nil means compute it yourself). + if newPassing == nil { + newPassing = []uint{} + } + newFailingSet := make(map[uint]struct{}, len(newFailing)) + for _, id := range newFailing { + newFailingSet[id] = struct{}{} + } + if err := processCalendarPolicies(ctx, svc.ds, ac, host, policyResults, svc.logger); err != nil { logging.WithErr(ctx, err) } - if err := svc.processScriptsForNewlyFailingPolicies(ctx, host.ID, host.TeamID, host.Platform, host.OrbitNodeKey, host.ScriptsEnabled, policyResults); err != nil { + if err := svc.processScriptsForNewlyFailingPolicies(ctx, host.ID, host.TeamID, host.Platform, host.OrbitNodeKey, host.ScriptsEnabled, policyResults, newFailingSet); err != nil { logging.WithErr(ctx, err) } @@ -1226,18 +1242,18 @@ func (svc *Service) SubmitDistributedQueryResults( if host.Platform == "darwin" && svc.EnterpriseOverrides != nil { // NOTE: if the installers for the policies here are not scoped to the host via labels, we update the policy status here to stop it from showing up as "failed" in the // host details. - if err := svc.processVPPForNewlyFailingPolicies(ctx, host.ID, host.TeamID, host.Platform, policyResults); err != nil { + if err := svc.processVPPForNewlyFailingPolicies(ctx, host.ID, host.TeamID, host.Platform, policyResults, newFailingSet); err != nil { logging.WithErr(ctx, err) } } // NOTE: if the installers for the policies here are not scoped to the host via labels, we update the policy status here to stop it from showing up as "failed" in the // host details. - if err := svc.processSoftwareForNewlyFailingPolicies(ctx, host.ID, host.TeamID, host.Platform, host.OrbitNodeKey, policyResults); err != nil { + if err := svc.processSoftwareForNewlyFailingPolicies(ctx, host.ID, host.TeamID, host.Platform, host.OrbitNodeKey, policyResults, newFailingSet); err != nil { logging.WithErr(ctx, err) } - // filter policy results for webhooks + // Filter policy results for webhooks using pre-computed flipping sets. var policyIDs []uint if globalPolicyAutomationsEnabled(ac.WebhookSettings, ac.Integrations) { policyIDs = append(policyIDs, ac.WebhookSettings.FailingPoliciesWebhook.PolicyIDs...) @@ -1256,12 +1272,13 @@ func (svc *Service) SubmitDistributedQueryResults( filteredResults := filterPolicyResults(policyResults, policyIDs) if len(filteredResults) > 0 { - if failingPolicies, passingPolicies, err := svc.ds.FlippingPoliciesForHost(ctx, host.ID, filteredResults); err != nil { - logging.WithErr(ctx, err) - } else { + // Filter the pre-computed flipping results to only webhook-enabled policies. + webhookFailing := filterByPolicyIDs(newFailing, filteredResults) + webhookPassing := filterByPolicyIDs(newPassing, filteredResults) + if len(webhookFailing) > 0 || len(webhookPassing) > 0 { // Register the flipped policies on a goroutine to not block the hosts on redis requests. go func() { - if err := svc.registerFlippedPolicies(ctx, host.ID, host.Hostname, host.DisplayName(), failingPolicies, passingPolicies); err != nil { + if err := svc.registerFlippedPolicies(ctx, host.ID, host.Hostname, host.DisplayName(), webhookFailing, webhookPassing); err != nil { logging.WithErr(ctx, err) } }() @@ -1275,12 +1292,12 @@ func (svc *Service) SubmitDistributedQueryResults( // maybe we should impose restrictions between async collection interval // and policy update interval? - if err := svc.task.RecordPolicyQueryExecutions(ctx, host, policyResults, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost); err != nil { + if err := svc.task.RecordPolicyQueryExecutions(ctx, host, policyResults, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost, newPassing); err != nil { logging.WithErr(ctx, err) } } else if hostWithoutPolicies { // RecordPolicyQueryExecutions called with results=nil will still update the host's policy_updated_at column. - if err := svc.task.RecordPolicyQueryExecutions(ctx, host, nil, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost); err != nil { + if err := svc.task.RecordPolicyQueryExecutions(ctx, host, nil, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost, []uint{}); err != nil { logging.WithErr(ctx, err) } } @@ -1945,6 +1962,18 @@ func filterPolicyResults(incoming map[uint]*bool, webhookPolicies []uint) map[ui return filtered } +// filterByPolicyIDs returns only the policy IDs from ids that are present in allowedResults and have a non-nil result +// (i.e., the policy actually executed). This matches the behavior of FlippingPoliciesForHost which ignores nil results. +func filterByPolicyIDs(ids []uint, allowedResults map[uint]*bool) []uint { + var filtered []uint + for _, id := range ids { + if val, ok := allowedResults[id]; ok && val != nil { + filtered = append(filtered, id) + } + } + return filtered +} + func (svc *Service) registerFlippedPolicies(ctx context.Context, hostID uint, hostname, displayName string, newFailing, newPassing []uint) error { host := fleet.PolicySetHost{ ID: hostID, @@ -1971,6 +2000,7 @@ func (svc *Service) processSoftwareForNewlyFailingPolicies( hostPlatform string, hostOrbitNodeKey *string, incomingPolicyResults map[uint]*bool, + newFailingSet map[uint]struct{}, ) error { if hostOrbitNodeKey == nil || *hostOrbitNodeKey == "" { // We do not want to queue software installations on vanilla osquery hosts. @@ -1986,15 +2016,13 @@ func (svc *Service) processSoftwareForNewlyFailingPolicies( // Filter out results that are not failures (we are only interested on failing policies, // we don't care about passing policies or policies that failed to execute). - incomingFailingPolicies := make(map[uint]*bool) var incomingFailingPoliciesIDs []uint for policyID, policyResult := range incomingPolicyResults { if policyResult != nil && !*policyResult { - incomingFailingPolicies[policyID] = policyResult incomingFailingPoliciesIDs = append(incomingFailingPoliciesIDs, policyID) } } - if len(incomingFailingPolicies) == 0 { + if len(incomingFailingPoliciesIDs) == 0 { return nil } @@ -2007,44 +2035,16 @@ func (svc *Service) processSoftwareForNewlyFailingPolicies( return nil } - // Filter out results of policies that are not associated to installers. - policiesWithInstallersMap := make(map[uint]fleet.PolicySoftwareInstallerData) - for _, policyWithInstaller := range policiesWithInstaller { - policiesWithInstallersMap[policyWithInstaller.ID] = policyWithInstaller - } - policyResultsOfPoliciesWithInstallers := make(map[uint]*bool) - for policyID, passes := range incomingFailingPolicies { - if _, ok := policiesWithInstallersMap[policyID]; !ok { - continue - } - policyResultsOfPoliciesWithInstallers[policyID] = passes - } - if len(policyResultsOfPoliciesWithInstallers) == 0 { - return nil - } - - // Get the policies associated with installers that are flipping from passing to failing on this host. - policyIDsOfNewlyFailingPoliciesWithInstallers, _, err := svc.ds.FlippingPoliciesForHost( - ctx, hostID, policyResultsOfPoliciesWithInstallers, - ) - if err != nil { - return ctxerr.Wrap(ctx, err, "failed to get flipping policies for host") - } - if len(policyIDsOfNewlyFailingPoliciesWithInstallers) == 0 { - return nil - } - policyIDsOfNewlyFailingPoliciesWithInstallersSet := make(map[uint]struct{}) - for _, policyID := range policyIDsOfNewlyFailingPoliciesWithInstallers { - policyIDsOfNewlyFailingPoliciesWithInstallersSet[policyID] = struct{}{} - } - - // Finally filter out policies with installers that are not newly failing. + // Filter to policies with installers that are newly failing, using the pre-computed set. var failingPoliciesWithInstaller []fleet.PolicySoftwareInstallerData for _, policyWithInstaller := range policiesWithInstaller { - if _, ok := policyIDsOfNewlyFailingPoliciesWithInstallersSet[policyWithInstaller.ID]; ok { + if _, ok := newFailingSet[policyWithInstaller.ID]; ok { failingPoliciesWithInstaller = append(failingPoliciesWithInstaller, policyWithInstaller) } } + if len(failingPoliciesWithInstaller) == 0 { + return nil + } for _, failingPolicyWithInstaller := range failingPoliciesWithInstaller { policyID := failingPolicyWithInstaller.ID @@ -2119,6 +2119,7 @@ func (svc *Service) processVPPForNewlyFailingPolicies( hostTeamID *uint, hostPlatform string, incomingPolicyResults map[uint]*bool, + newFailingSet map[uint]struct{}, ) error { var policyTeamID uint if hostTeamID == nil { @@ -2129,15 +2130,13 @@ func (svc *Service) processVPPForNewlyFailingPolicies( // Filter out results that are not failures (we are only interested on failing policies, // we don't care about passing policies or policies that failed to execute). - incomingFailingPolicies := make(map[uint]*bool) var incomingFailingPoliciesIDs []uint for policyID, policyResult := range incomingPolicyResults { if policyResult != nil && !*policyResult { - incomingFailingPolicies[policyID] = policyResult incomingFailingPoliciesIDs = append(incomingFailingPoliciesIDs, policyID) } } - if len(incomingFailingPolicies) == 0 { + if len(incomingFailingPoliciesIDs) == 0 { return nil } @@ -2150,45 +2149,13 @@ func (svc *Service) processVPPForNewlyFailingPolicies( return nil } - // Filter out results of policies that are not associated to VPP apps. - policiesWithVPPMap := make(map[uint]fleet.PolicyVPPData) - for _, policyWithVPP := range policiesWithVPP { - policiesWithVPPMap[policyWithVPP.ID] = policyWithVPP - } - policyResultsOfPoliciesWithVPP := make(map[uint]*bool) - for policyID, passes := range incomingFailingPolicies { - if _, ok := policiesWithVPPMap[policyID]; !ok { - continue - } - policyResultsOfPoliciesWithVPP[policyID] = passes - } - if len(policyResultsOfPoliciesWithVPP) == 0 { - return nil - } - - // Get the policies associated with VPP apps that are flipping from passing to failing on this host. - policyIDsOfNewlyFailingPoliciesWithVPP, _, err := svc.ds.FlippingPoliciesForHost( - ctx, hostID, policyResultsOfPoliciesWithVPP, - ) - if err != nil { - return ctxerr.Wrap(ctx, err, "failed to get flipping policies for host") - } - if len(policyIDsOfNewlyFailingPoliciesWithVPP) == 0 { - return nil - } - policyIDsOfNewlyFailingPoliciesWithVPPSet := make(map[uint]struct{}) - for _, policyID := range policyIDsOfNewlyFailingPoliciesWithVPP { - policyIDsOfNewlyFailingPoliciesWithVPPSet[policyID] = struct{}{} - } - - // Finally filter out policies with VPP apps that are not newly failing. + // Filter to policies with VPP apps that are newly failing, using the pre-computed set. var failingPoliciesWithVPP []fleet.PolicyVPPData for _, policyWithVPP := range policiesWithVPP { - if _, ok := policyIDsOfNewlyFailingPoliciesWithVPPSet[policyWithVPP.ID]; ok { + if _, ok := newFailingSet[policyWithVPP.ID]; ok { failingPoliciesWithVPP = append(failingPoliciesWithVPP, policyWithVPP) } } - if len(failingPoliciesWithVPP) == 0 { return nil } @@ -2269,6 +2236,7 @@ func (svc *Service) processScriptsForNewlyFailingPolicies( hostOrbitNodeKey *string, hostScriptsEnabled *bool, incomingPolicyResults map[uint]*bool, + newFailingSet map[uint]struct{}, ) error { if hostOrbitNodeKey == nil || *hostOrbitNodeKey == "" { return nil // vanilla osquery hosts can't run scripts @@ -2297,15 +2265,13 @@ func (svc *Service) processScriptsForNewlyFailingPolicies( // Filter out results that are not failures (we are only interested on failing policies, // we don't care about passing policies or policies that failed to execute). - incomingFailingPolicies := make(map[uint]*bool) var incomingFailingPoliciesIDs []uint for policyID, policyResult := range incomingPolicyResults { if policyResult != nil && !*policyResult { - incomingFailingPolicies[policyID] = policyResult incomingFailingPoliciesIDs = append(incomingFailingPoliciesIDs, policyID) } } - if len(incomingFailingPolicies) == 0 { + if len(incomingFailingPoliciesIDs) == 0 { return nil } @@ -2318,44 +2284,16 @@ func (svc *Service) processScriptsForNewlyFailingPolicies( return nil } - // Filter out results of policies that are not associated to scripts. - policiesWithScriptsMap := make(map[uint]fleet.PolicyScriptData) - for _, policyWithScript := range policiesWithScript { - policiesWithScriptsMap[policyWithScript.ID] = policyWithScript - } - policyResultsOfPoliciesWithScripts := make(map[uint]*bool) - for policyID, passes := range incomingFailingPolicies { - if _, ok := policiesWithScriptsMap[policyID]; !ok { - continue - } - policyResultsOfPoliciesWithScripts[policyID] = passes - } - if len(policyResultsOfPoliciesWithScripts) == 0 { - return nil - } - - // Get the policies associated with scripts that are flipping from passing to failing on this host. - policyIDsOfNewlyFailingPoliciesWithScripts, _, err := svc.ds.FlippingPoliciesForHost( - ctx, hostID, policyResultsOfPoliciesWithScripts, - ) - if err != nil { - return ctxerr.Wrap(ctx, err, "failed to get flipping policies for host") - } - if len(policyIDsOfNewlyFailingPoliciesWithScripts) == 0 { - return nil - } - policyIDsOfNewlyFailingPoliciesWithScriptsSet := make(map[uint]struct{}) - for _, policyID := range policyIDsOfNewlyFailingPoliciesWithScripts { - policyIDsOfNewlyFailingPoliciesWithScriptsSet[policyID] = struct{}{} - } - - // Finally filter out policies with scripts that are not newly failing. + // Filter to policies with scripts that are newly failing, using the pre-computed set. var failingPoliciesWithScript []fleet.PolicyScriptData for _, policyWithScript := range policiesWithScript { - if _, ok := policyIDsOfNewlyFailingPoliciesWithScriptsSet[policyWithScript.ID]; ok { + if _, ok := newFailingSet[policyWithScript.ID]; ok { failingPoliciesWithScript = append(failingPoliciesWithScript, policyWithScript) } } + if len(failingPoliciesWithScript) == 0 { + return nil + } for _, failingPolicyWithScript := range failingPoliciesWithScript { policyID := failingPolicyWithScript.ID diff --git a/server/service/osquery_test.go b/server/service/osquery_test.go index e4a476ee60..b067339115 100644 --- a/server/service/osquery_test.go +++ b/server/service/osquery_test.go @@ -3350,7 +3350,7 @@ func TestPolicyQueries(t *testing.T) { } recordedResults := make(map[uint]*bool) ds.RecordPolicyQueryExecutionsFunc = func(ctx context.Context, gotHost *fleet.Host, results map[uint]*bool, updated time.Time, - deferred bool, + deferred bool, newlyPassingPolicyIDs []uint, ) error { recordedResults = results host = gotHost @@ -3661,7 +3661,7 @@ func TestPolicyWebhooks(t *testing.T) { } recordedResults := make(map[uint]*bool) ds.RecordPolicyQueryExecutionsFunc = func(ctx context.Context, gotHost *fleet.Host, results map[uint]*bool, updated time.Time, - deferred bool, + deferred bool, newlyPassingPolicyIDs []uint, ) error { recordedResults = results host = gotHost @@ -3696,12 +3696,30 @@ func TestPolicyWebhooks(t *testing.T) { checkPolicyResults(queries) + // Track FlippingPoliciesForHost calls to verify the deduplication optimization: it should be called exactly once + // per SubmitDistributedQueryResults with ALL policy results, not multiple times with subsets. + var flippingCallCount int + var flippingIncomingResults map[uint]*bool ds.FlippingPoliciesForHostFunc = func(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing []uint, newPassing []uint, err error, ) { + flippingCallCount++ + flippingIncomingResults = incomingResults return []uint{3}, nil, nil } + // Track that pre-computed newlyPassingPolicyIDs is forwarded to RecordPolicyQueryExecutions. + var recordedNewlyPassing []uint + ds.RecordPolicyQueryExecutionsFunc = func(ctx context.Context, gotHost *fleet.Host, results map[uint]*bool, updated time.Time, + deferred bool, newlyPassingPolicyIDs []uint, + ) error { + recordedResults = results + recordedNewlyPassing = newlyPassingPolicyIDs + host = gotHost + return nil + } + + flippingCallCount = 0 // Record a query execution. err = svc.SubmitDistributedQueryResults( ctx, @@ -3726,6 +3744,14 @@ func TestPolicyWebhooks(t *testing.T) { require.NotNil(t, recordedResults[3]) require.False(t, *recordedResults[3]) + // Verify FlippingPoliciesForHost was called exactly once with all 3 policy results. + require.Equal(t, 1, flippingCallCount, "FlippingPoliciesForHost should be called exactly once per SubmitDistributedQueryResults") + require.Len(t, flippingIncomingResults, 3, "FlippingPoliciesForHost should receive all policy results") + // Verify pre-computed newlyPassingPolicyIDs was forwarded (FlippingPoliciesForHost returned nil for newPassing, + // but the caller normalizes it to a non-nil empty slice). + require.NotNil(t, recordedNewlyPassing, "pre-computed newlyPassingPolicyIDs should be forwarded to RecordPolicyQueryExecutions") + require.Empty(t, recordedNewlyPassing) + cmpSets := func(expSets map[uint][]fleet.PolicySetHost) error { actualSets, err := failingPolicySet.ListSets() if err != nil { @@ -3805,9 +3831,12 @@ func TestPolicyWebhooks(t *testing.T) { ds.FlippingPoliciesForHostFunc = func(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing []uint, newPassing []uint, err error, ) { + flippingCallCount++ + flippingIncomingResults = incomingResults return []uint{1}, []uint{3}, nil } + flippingCallCount = 0 // Record another query execution. err = svc.SubmitDistributedQueryResults( ctx, @@ -3832,6 +3861,12 @@ func TestPolicyWebhooks(t *testing.T) { require.NotNil(t, recordedResults[3]) require.True(t, *recordedResults[3]) + // Verify single call and correct forwarding when there are actual flips. + require.Equal(t, 1, flippingCallCount, "FlippingPoliciesForHost should be called exactly once") + require.Len(t, flippingIncomingResults, 3) + require.NotNil(t, recordedNewlyPassing) + require.ElementsMatch(t, []uint{3}, recordedNewlyPassing, "newlyPassingPolicyIDs should be forwarded from FlippingPoliciesForHost") + assert.Eventually(t, func() bool { err = cmpSets(map[uint][]fleet.PolicySetHost{ 1: {{