mirror of
https://github.com/fleetdm/fleet
synced 2026-04-21 13:37:30 +00:00
Removed duplicate FlippingPoliciesForHost DB calls (#42845)
<!-- Add the related story/sub-task/bug number, like Resolves #123, or remove if NA --> **Related issue:** Resolves #42836 This is another hot path optimization. ## Before When a host submits policy results via `SubmitDistributedQueryResults`, the system needed to determine which policies "flipped" (changed from passing to failing or vice versa). Each consumer computed this independently: ``` SubmitDistributedQueryResults(policyResults) | +-- processScriptsForNewlyFailingPolicies | filter to failing policies with scripts | BUILD SUBSET of results | CALL FlippingPoliciesForHost(subset) <-- DB query #1 | convert result to set, filter, queue scripts | +-- processSoftwareForNewlyFailingPolicies | filter to failing policies with installers | BUILD SUBSET of results | CALL FlippingPoliciesForHost(subset) <-- DB query #2 | convert result to set, filter, queue installs | +-- processVPPForNewlyFailingPolicies | filter to failing policies with VPP apps | BUILD SUBSET of results | CALL FlippingPoliciesForHost(subset) <-- DB query #3 | convert result to set, filter, queue VPP | +-- webhook filtering | filter to webhook-enabled policies | CALL FlippingPoliciesForHost(subset) <-- DB query #4 | register flipped policies in Redis | +-- RecordPolicyQueryExecutions CALL FlippingPoliciesForHost(all results) <-- DB query #5 reset attempt counters for newly passing INSERT/UPDATE policy_membership ``` Each `FlippingPoliciesForHost` call runs `SELECT policy_id, passes FROM policy_membership WHERE host_id = ? AND policy_id IN (?)`. All 5 queries hit the same table for the same host before `policy_membership` is updated, so they all see identical state. Each consumer also built intermediate maps to narrow down to its subset before calling `FlippingPoliciesForHost`, then converted the result into yet another set for filtering. This meant 3-4 temporary maps per consumer. 
## After ``` SubmitDistributedQueryResults(policyResults) | CALL FlippingPoliciesForHost(all results) <-- single DB query build newFailingSet, normalize newPassing | +-- processScriptsForNewlyFailingPolicies | filter to failing policies with scripts | CHECK newFailingSet (in-memory map lookup) | queue scripts | +-- processSoftwareForNewlyFailingPolicies | filter to failing policies with installers | CHECK newFailingSet (in-memory map lookup) | queue installs | +-- processVPPForNewlyFailingPolicies | filter to failing policies with VPP apps | CHECK newFailingSet (in-memory map lookup) | queue VPP | +-- webhook filtering | filter to webhook-enabled policies | FILTER newFailing/newPassing by policy IDs (in-memory) | register flipped policies in Redis | +-- RecordPolicyQueryExecutions USE pre-computed newPassing (skip DB query) reset attempt counters for newly passing INSERT/UPDATE policy_membership ``` The intermediate subset maps and per-consumer set conversions are removed. Each process function goes directly from "policies with associated automation" to "is this policy in newFailingSet?" in a single map lookup. # Checklist for submitter If some of the following don't apply, delete the relevant line. - [x] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`. ## Testing - [x] Added/updated automated tests - [x] QA'd all new/changed functionality manually <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Performance Improvements** * Reduced redundant database queries during policy result submissions by computing flipping policies once per host check-in instead of multiple times. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
parent
a7e4557066
commit
8af94af14b
16 changed files with 286 additions and 276 deletions
1
changes/42836-deduplicate-flipping-policies-queries
Normal file
1
changes/42836-deduplicate-flipping-policies-queries
Normal file
|
|
@@ -0,0 +1 @@
|
|||
- Reduced redundant database queries during policy result submission by computing flipping policies once per host check-in instead of multiple times.
|
||||
|
|
@@ -306,7 +306,7 @@ func testConditionalAccessBypassDeviceWithBlockingPolicy(t *testing.T, ds *Datas
|
|||
require.NoError(t, err)
|
||||
|
||||
// Record a failing result for this policy on the host
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: ptr.Bool(false)}, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: new(false)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Bypass should fail because the host has a failing CA-enabled policy
|
||||
|
|
@@ -365,7 +365,7 @@ func testConditionalAccessBypassAllowedWithNonCAFailingCriticalPolicy(t *testing
|
|||
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{
|
||||
caPolicy.ID: ptr.Bool(true), // passing
|
||||
nonCAPolicy.ID: ptr.Bool(false), // failing
|
||||
}, time.Now(), false)
|
||||
}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Bypass must succeed: the only failing policy is not CA-enabled
|
||||
|
|
@@ -415,7 +415,7 @@ func testConditionalAccessBypassAllowedWithCAEnabledNonCriticalPolicy(t *testing
|
|||
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{
|
||||
nonCriticalCAPolicy.ID: ptr.Bool(false), // failing
|
||||
}, time.Now(), false)
|
||||
}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Bypass must succeed: policy is CA-enabled but not critical
|
||||
|
|
|
|||
|
|
@@ -4101,8 +4101,8 @@ func testHostsListByPolicy(t *testing.T, ds *Datastore) {
|
|||
require.Len(t, hosts, 0)
|
||||
|
||||
// Make one host pass the policy and another not pass
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{1: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{1: ptr.Bool(false)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{1: new(true)}, time.Now(), false, nil))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{1: new(false)}, time.Now(), false, nil))
|
||||
|
||||
hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{PolicyIDFilter: &p.ID, PolicyResponseFilter: ptr.Bool(true)}, 1)
|
||||
require.Len(t, hosts, 1)
|
||||
|
|
@@ -5105,18 +5105,18 @@ func testHostsListFailingPolicies(t *testing.T, ds *Datastore) {
|
|||
assert.Zero(t, *h2.HostIssues.CriticalVulnerabilitiesCount)
|
||||
assert.Zero(t, h2.HostIssues.TotalIssuesCount)
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: new(true)}, time.Now(), false, nil))
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(false), p2.ID: ptr.Bool(false)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(false), p2.ID: new(false)}, time.Now(), false, nil))
|
||||
checkHostIssues(t, ds, hosts, filter, h2.ID, 2)
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(true), p2.ID: ptr.Bool(false)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(true), p2.ID: new(false)}, time.Now(), false, nil))
|
||||
checkHostIssues(t, ds, hosts, filter, h2.ID, 1)
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(true), p2.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(true), p2.ID: new(true)}, time.Now(), false, nil))
|
||||
checkHostIssues(t, ds, hosts, filter, h2.ID, 0)
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: new(false)}, time.Now(), false, nil))
|
||||
checkHostIssues(t, ds, hosts, filter, h1.ID, 1)
|
||||
|
||||
checkHostIssuesWithOpts(t, ds, filter, h1.ID, fleet.HostListOptions{DisableIssues: true}, 0)
|
||||
|
|
@@ -5183,8 +5183,8 @@ func testHostsReadsLessRows(t *testing.T, ds *Datastore) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: new(true)}, time.Now(), false, nil))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(false)}, time.Now(), false, nil))
|
||||
|
||||
prevRead := getReads(t, ds)
|
||||
h1WithExtras, err := ds.Host(context.Background(), h1.ID)
|
||||
|
|
@@ -8969,7 +8969,7 @@ func testHostsDeleteHosts(t *testing.T, ds *Datastore) {
|
|||
_, err = ds.writer(context.Background()).Exec(`INSERT INTO query_results (host_id, query_id, last_fetched, data) VALUES (?, ?, ?, ?)`, host.ID, policy.ID, time.Now(), `{"foo": "bar"}`)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{policy.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{policy.ID: new(true)}, time.Now(), false, nil))
|
||||
// Update host_mdm.
|
||||
err = ds.SetOrUpdateMDMData(context.Background(), host.ID, false, true, "foo.mdm.example.com", false, "", "", false)
|
||||
require.NoError(t, err)
|
||||
|
|
@@ -9727,7 +9727,7 @@ func testFailingPoliciesCount(t *testing.T, ds *Datastore) {
|
|||
|
||||
for _, tc := range testCases {
|
||||
if len(tc.policyEx) != 0 {
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, tc.host, tc.policyEx, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, tc.host, tc.policyEx, time.Now(), false, nil))
|
||||
}
|
||||
actual, err := ds.FailingPoliciesCount(ctx, tc.host)
|
||||
require.NoError(t, err)
|
||||
|
|
@@ -9769,7 +9769,7 @@ func testHostsRecordNoPolicies(t *testing.T, ds *Datastore) {
|
|||
assert.Zero(t, h2.HostIssues.TotalIssuesCount)
|
||||
|
||||
policyUpdatedAt := initialTime.Add(1 * time.Hour)
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, nil, policyUpdatedAt, false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, nil, policyUpdatedAt, false, nil))
|
||||
|
||||
hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{}, 2)
|
||||
require.Len(t, hosts, 2)
|
||||
|
|
@@ -9916,7 +9916,7 @@ func testHostOrder(t *testing.T, ds *Datastore) {
|
|||
}
|
||||
require.NoError(
|
||||
t, ds.RecordPolicyQueryExecutions(
|
||||
context.Background(), createdHosts[i], results, time.Now(), false,
|
||||
context.Background(), createdHosts[i], results, time.Now(), false, nil,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
|
@@ -11791,8 +11791,8 @@ func testHostHealth(t *testing.T, ds *Datastore) {
|
|||
failingPolicy, err := ds.NewGlobalPolicy(context.Background(), &u.ID, fleet.PolicyPayload{QueryID: &q.ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h, map[uint]*bool{passingPolicy.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h, map[uint]*bool{failingPolicy.ID: ptr.Bool(false)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h, map[uint]*bool{passingPolicy.ID: new(true)}, time.Now(), false, nil))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h, map[uint]*bool{failingPolicy.ID: new(false)}, time.Now(), false, nil))
|
||||
|
||||
// set up vulnerable software
|
||||
software := []fleet.Software{
|
||||
|
|
@@ -12269,7 +12269,7 @@ func testUpdateHostIssues(t *testing.T, ds *Datastore) {
|
|||
require.NoError(
|
||||
// RecordPolicyQueryExecutions should call UpdateHostIssuesFailingPolicies, so we don't have to
|
||||
t, ds.RecordPolicyQueryExecutions(
|
||||
context.Background(), hosts[i], results, time.Now(), false,
|
||||
context.Background(), hosts[i], results, time.Now(), false, nil,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -1604,9 +1604,9 @@ func testListHostsInLabelIssues(t *testing.T, ds *Datastore) {
|
|||
assert.Zero(t, *h2.HostIssues.CriticalVulnerabilitiesCount)
|
||||
assert.Zero(t, h2.HostIssues.TotalIssuesCount)
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: new(true)}, time.Now(), false, nil))
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(false), p2.ID: ptr.Bool(false)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(false), p2.ID: new(false)}, time.Now(), false, nil))
|
||||
checkLabelHostIssues(t, ds, l1.ID, filter, h2.ID, fleet.HostListOptions{}, 2, 0)
|
||||
|
||||
// Add a critical vulnerability
|
||||
|
|
@@ -1676,13 +1676,13 @@ func testListHostsInLabelIssues(t *testing.T, ds *Datastore) {
|
|||
assert.NoError(t, ds.UpdateHostIssuesVulnerabilities(ctx))
|
||||
checkLabelHostIssues(t, ds, l1.ID, filter, hosts[6].ID, fleet.HostListOptions{}, 0, 4)
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(true), p2.ID: ptr.Bool(false)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(true), p2.ID: new(false)}, time.Now(), false, nil))
|
||||
checkLabelHostIssues(t, ds, l1.ID, filter, h2.ID, fleet.HostListOptions{}, 1, 1)
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(true), p2.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: new(true), p2.ID: new(true)}, time.Now(), false, nil))
|
||||
checkLabelHostIssues(t, ds, l1.ID, filter, h2.ID, fleet.HostListOptions{}, 0, 1)
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: new(false)}, time.Now(), false, nil))
|
||||
checkLabelHostIssues(t, ds, l1.ID, filter, h1.ID, fleet.HostListOptions{}, 1, 1)
|
||||
|
||||
checkLabelHostIssues(t, ds, l1.ID, filter, h1.ID, fleet.HostListOptions{DisableIssues: true}, 0, 0)
|
||||
|
|
|
|||
|
|
@@ -608,12 +608,20 @@ func filterNotExecuted(results map[uint]*bool) map[uint]bool {
|
|||
return filtered
|
||||
}
|
||||
|
||||
func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool) error {
|
||||
func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool, newlyPassingPolicyIDs []uint) error {
|
||||
// Identify policies that flipped failing -> passing for this host using current incoming results.
|
||||
// We compute this before updating policy_membership so we compare against the previous state.
|
||||
_, newPassing, err := ds.FlippingPoliciesForHost(ctx, host.ID, results)
|
||||
if err != nil {
|
||||
return err
|
||||
// When newlyPassingPolicyIDs is non-nil, the caller has already computed flipping policies
|
||||
// (e.g. SubmitDistributedQueryResults computes it once for all consumers) so we reuse that result.
|
||||
var newPassing []uint
|
||||
if newlyPassingPolicyIDs != nil {
|
||||
newPassing = newlyPassingPolicyIDs
|
||||
} else {
|
||||
var err error
|
||||
_, newPassing, err = ds.FlippingPoliciesForHost(ctx, host.ID, results)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(newPassing) > 0 {
|
||||
slices.Sort(newPassing)
|
||||
|
|
@@ -645,7 +653,7 @@ func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *flee
|
|||
// semantically equivalent, even though here it processes a single host and
|
||||
// in async mode it processes a batch of hosts).
|
||||
|
||||
err = ds.withTx(ctx, func(tx sqlx.ExtContext) error {
|
||||
err := ds.withTx(ctx, func(tx sqlx.ExtContext) error {
|
||||
if len(results) > 0 {
|
||||
query := fmt.Sprintf(
|
||||
`INSERT INTO policy_membership (updated_at, policy_id, host_id, passes)
|
||||
|
|
|
|||
|
|
@@ -460,14 +460,14 @@ func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) {
|
|||
require.NotNil(t, p2.AuthorID)
|
||||
assert.Equal(t, user1.ID, *p2.AuthorID)
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: new(true)}, time.Now(), deferred, nil))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: new(true)}, time.Now(), deferred, nil))
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: nil}, time.Now(), deferred))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), deferred))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: nil}, time.Now(), deferred, nil))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: new(false)}, time.Now(), deferred, nil))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: new(true)}, time.Now(), deferred, nil))
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p2.ID: nil}, time.Now(), deferred))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p2.ID: nil}, time.Now(), deferred, nil))
|
||||
|
||||
require.NoError(t, ds.UpdateHostPolicyCounts(ctx))
|
||||
|
||||
|
|
@@ -483,8 +483,8 @@ func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) {
|
|||
assert.Equal(t, uint(0), policies[1].PassingHostCount)
|
||||
assert.Equal(t, uint(0), policies[1].FailingHostCount)
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), deferred))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p2.ID: ptr.Bool(false)}, time.Now(), deferred))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: new(false)}, time.Now(), deferred, nil))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p2.ID: new(false)}, time.Now(), deferred, nil))
|
||||
|
||||
require.NoError(t, ds.UpdateHostPolicyCounts(ctx))
|
||||
|
||||
|
|
@@ -500,6 +500,30 @@ func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) {
|
|||
assert.Equal(t, uint(0), policies[1].PassingHostCount)
|
||||
assert.Equal(t, uint(1), policies[1].FailingHostCount)
|
||||
|
||||
// Test with pre-computed newlyPassingPolicyIDs (non-nil) to exercise the path where
|
||||
// RecordPolicyQueryExecutions skips calling FlippingPoliciesForHost internally.
|
||||
// host1 currently has p.ID=failing, so flipping to passing with pre-computed IDs.
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(
|
||||
ctx, host1, map[uint]*bool{p.ID: new(true)}, time.Now(), deferred, []uint{p.ID},
|
||||
))
|
||||
// Also test with an empty (but non-nil) slice, which means "already computed, no newly passing".
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(
|
||||
ctx, host2, map[uint]*bool{p2.ID: new(true)}, time.Now(), deferred, []uint{},
|
||||
))
|
||||
|
||||
require.NoError(t, ds.UpdateHostPolicyCounts(ctx))
|
||||
policies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, policies, 2)
|
||||
// host1: p now passing again (was failing), host2: p still passing
|
||||
assert.Equal(t, p.ID, policies[0].ID)
|
||||
assert.Equal(t, uint(2), policies[0].PassingHostCount)
|
||||
assert.Equal(t, uint(0), policies[0].FailingHostCount)
|
||||
// host2: p2 now passing (was failing)
|
||||
assert.Equal(t, p2.ID, policies[1].ID)
|
||||
assert.Equal(t, uint(1), policies[1].PassingHostCount)
|
||||
assert.Equal(t, uint(0), policies[1].FailingHostCount)
|
||||
|
||||
policy, err := ds.Policy(ctx, policies[0].ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, policies[0], policy)
|
||||
|
|
@@ -553,9 +577,9 @@ func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// create some policy results
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host3, map[uint]*bool{t1pol.ID: ptr.Bool(true), p.ID: ptr.Bool(true), p2.ID: ptr.Bool(false)}, time.Now(), deferred))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host4, map[uint]*bool{t2pol.ID: ptr.Bool(false), t2pol2.ID: ptr.Bool(true), p.ID: ptr.Bool(false)}, time.Now(), deferred))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host5, map[uint]*bool{t2pol.ID: ptr.Bool(true), t2pol2.ID: ptr.Bool(true), p2.ID: ptr.Bool(true)}, time.Now(), deferred))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host3, map[uint]*bool{t1pol.ID: new(true), p.ID: new(true), p2.ID: new(false)}, time.Now(), deferred, nil))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host4, map[uint]*bool{t2pol.ID: new(false), t2pol2.ID: new(true), p.ID: new(false)}, time.Now(), deferred, nil))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host5, map[uint]*bool{t2pol.ID: new(true), t2pol2.ID: new(true), p2.ID: new(true)}, time.Now(), deferred, nil))
|
||||
|
||||
require.NoError(t, ds.UpdateHostPolicyCounts(ctx))
|
||||
|
||||
|
|
@@ -1077,7 +1101,7 @@ func testListMergedTeamPolicies(t *testing.T, ds *Datastore) {
|
|||
&fleet.Host{OsqueryHostID: ptr.String("host1"), NodeKey: ptr.String(fmt.Sprint("host1", 1)), TeamID: nil})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{gpol.ID: ptr.Bool(true)}, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{gpol.ID: new(true)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ds.UpdateHostPolicyCounts(context.Background())
|
||||
|
|
@@ -1100,7 +1124,7 @@ func testListMergedTeamPolicies(t *testing.T, ds *Datastore) {
|
|||
err = ds.AddHostsToTeam(ctx, fleet.NewAddHostsToTeamParams(&team1.ID, []uint{host.ID}))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{team1policy.ID: ptr.Bool(true)}, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{team1policy.ID: new(true)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ds.UpdateHostPolicyCounts(context.Background())
|
||||
|
|
@@ -1447,7 +1471,7 @@ func testPolicyQueriesForHost(t *testing.T, ds *Datastore) {
|
|||
assert.Equal(t, q.Query, queries[fmt.Sprint(q.ID)])
|
||||
|
||||
// Team policy ran with failing result.
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{tp.ID: ptr.Bool(false), gp.ID: nil}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{tp.ID: new(false), gp.ID: nil}, time.Now(), false, nil))
|
||||
|
||||
policies, err := ds.ListPoliciesForHost(context.Background(), host1)
|
||||
require.NoError(t, err)
|
||||
|
|
@@ -1490,7 +1514,7 @@ func testPolicyQueriesForHost(t *testing.T, ds *Datastore) {
|
|||
assert.Equal(t, "", policies[0].Response)
|
||||
|
||||
// Global policy ran with passing result.
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{gp.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{gp.ID: new(true)}, time.Now(), false, nil))
|
||||
|
||||
policies, err = ds.ListPoliciesForHost(context.Background(), host2)
|
||||
require.NoError(t, err)
|
||||
|
|
@@ -1511,7 +1535,7 @@ func testPolicyQueriesForHost(t *testing.T, ds *Datastore) {
|
|||
require.NoError(t, err)
|
||||
require.NoError(t,
|
||||
ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{uint(id): nil}, //nolint:gosec // dismiss G115
|
||||
time.Now(), false))
|
||||
time.Now(), false, nil))
|
||||
|
||||
policies, err = ds.ListPoliciesForHost(context.Background(), host2)
|
||||
require.NoError(t, err)
|
||||
|
|
@@ -1556,7 +1580,7 @@ func testPoliciesByID(t *testing.T, ds *Datastore) {
|
|||
err = ds.SavePolicy(context.Background(), policy2, false, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{policy1.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{policy1.ID: new(true)}, time.Now(), false, nil))
|
||||
require.NoError(t, ds.UpdateHostPolicyCounts(context.Background()))
|
||||
|
||||
policiesByID, err := ds.PoliciesByID(context.Background(), []uint{1, 2})
|
||||
|
|
@@ -1634,10 +1658,10 @@ func testTeamPolicyTransfer(t *testing.T, ds *Datastore) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{team1Policy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{team1Policy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{team1Policy.ID: new(false), globalPolicy.ID: new(true)}, time.Now(), false, nil))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{team1Policy.ID: new(true), globalPolicy.ID: new(true)}, time.Now(), false, nil))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: new(false), globalPolicy.ID: new(true)}, time.Now(), false, nil))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: new(true), globalPolicy.ID: new(true)}, time.Now(), false, nil))
|
||||
|
||||
require.NoError(t, ds.UpdateHostPolicyCounts(ctx))
|
||||
|
||||
|
|
@@ -1691,7 +1715,7 @@ func testTeamPolicyTransfer(t *testing.T, ds *Datastore) {
|
|||
checkPassingCount(0, 0, 1, 1)
|
||||
|
||||
// team policies are removed if the host is re-enrolled without a team
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: new(true), globalPolicy.ID: new(true)}, time.Now(), false, nil))
|
||||
checkPassingCount(1, 0, 2, 2)
|
||||
|
||||
// all host policies are removed when a host is re-enrolled
|
||||
|
|
@@ -2151,7 +2175,7 @@ func testApplyPolicySpecWithQueryPlatformChanges(t *testing.T, ds *Datastore) {
|
|||
for _, pol := range polsByName {
|
||||
res[pol.ID] = ptr.Bool(false)
|
||||
}
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
for _, h := range globalHosts {
|
||||
|
|
@@ -2159,7 +2183,7 @@ func testApplyPolicySpecWithQueryPlatformChanges(t *testing.T, ds *Datastore) {
|
|||
for _, pol := range globalPolsByName {
|
||||
res[pol.ID] = ptr.Bool(false)
|
||||
}
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
err = ds.UpdateHostPolicyCounts(ctx)
|
||||
|
|
@@ -2529,17 +2553,17 @@ func testCachedPolicyCountDeletesOnPolicyChange(t *testing.T, ds *Datastore) {
|
|||
// teamHost and globalHost fail all policies
|
||||
require.NoError(
|
||||
t, ds.RecordPolicyQueryExecutions(
|
||||
ctx, teamHost, map[uint]*bool{globalPolicy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(false)}, time.Now(), false,
|
||||
ctx, teamHost, map[uint]*bool{globalPolicy.ID: new(false), globalPolicy.ID: new(false)}, time.Now(), false, nil,
|
||||
),
|
||||
)
|
||||
require.NoError(
|
||||
t, ds.RecordPolicyQueryExecutions(
|
||||
ctx, teamHost, map[uint]*bool{teamPolicy.ID: ptr.Bool(false), teamPolicy.ID: ptr.Bool(false)}, time.Now(), false,
|
||||
ctx, teamHost, map[uint]*bool{teamPolicy.ID: new(false), teamPolicy.ID: new(false)}, time.Now(), false, nil,
|
||||
),
|
||||
)
|
||||
require.NoError(
|
||||
t, ds.RecordPolicyQueryExecutions(
|
||||
ctx, globalHost, map[uint]*bool{globalPolicy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(false)}, time.Now(), false,
|
||||
ctx, globalHost, map[uint]*bool{globalPolicy.ID: new(false), globalPolicy.ID: new(false)}, time.Now(), false, nil,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@@ -2723,7 +2747,7 @@ func testFlippingPoliciesForHost(t *testing.T, ds *Datastore) {
|
|||
require.Empty(t, newPassing) // because this would be the first run.
|
||||
|
||||
// Record the above executions.
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// incoming policy 1 with passing result: no => yes
|
||||
|
|
@@ -2738,7 +2762,7 @@ func testFlippingPoliciesForHost(t *testing.T, ds *Datastore) {
|
|||
require.Equal(t, []uint{p1.ID}, newPassing)
|
||||
|
||||
// Record the above executions.
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// incoming policy 1 with passing result: yes => yes
|
||||
|
|
@ -2753,7 +2777,7 @@ func testFlippingPoliciesForHost(t *testing.T, ds *Datastore) {
|
|||
require.Empty(t, newPassing)
|
||||
|
||||
// Record the above executions.
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// incoming policy 1 failed to execute: yes => no
|
||||
|
|
@ -2777,7 +2801,7 @@ func testFlippingPoliciesForHost(t *testing.T, ds *Datastore) {
|
|||
require.Empty(t, newPassing)
|
||||
|
||||
// Record the above executions.
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// incoming pfailed again failed to execute: --- -> ---
|
||||
|
|
@ -2798,7 +2822,7 @@ func testFlippingPoliciesForHost(t *testing.T, ds *Datastore) {
|
|||
require.Empty(t, newPassing)
|
||||
|
||||
// Record the above executions.
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// incoming policy 4 with first new failing result: --- => no
|
||||
|
|
@ -2932,7 +2956,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) {
|
|||
// also record a result for linux policy
|
||||
res[polsByName["t2"].ID] = ptr.Bool(true)
|
||||
}
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
for i, h := range globalHosts {
|
||||
|
|
@ -2943,7 +2967,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) {
|
|||
// also record a result for linux policy
|
||||
res[polsByName["g2"].ID] = ptr.Bool(true)
|
||||
}
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
|
|
@ -3066,9 +3090,9 @@ func testPolicyViolationDays(t *testing.T, ds *Datastore) {
|
|||
require.NoError(t, ds.InitializePolicyViolationDays(ctx)) // sets starting violation count to zero
|
||||
|
||||
// initialize policy statuses: 1 failling, 2 passing
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), hosts[0], map[uint]*bool{pol.ID: ptr.Bool(false)}, then, false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), hosts[1], map[uint]*bool{pol.ID: ptr.Bool(true)}, then, false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), hosts[2], map[uint]*bool{pol.ID: ptr.Bool(true)}, then, false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), hosts[0], map[uint]*bool{pol.ID: new(false)}, then, false, nil))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), hosts[1], map[uint]*bool{pol.ID: new(true)}, then, false, nil))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), hosts[2], map[uint]*bool{pol.ID: new(true)}, then, false, nil))
|
||||
|
||||
// setup db for test: starting counts zero, more than 24h since last updated, one outstanding violation
|
||||
require.NoError(t, setStatsTimestampDB(time.Now().Add(-25*time.Hour)))
|
||||
|
|
@ -3094,7 +3118,7 @@ func testPolicyViolationDays(t *testing.T, ds *Datastore) {
|
|||
// leave counts at zero for next test
|
||||
|
||||
// setup for test: starting count zero, more than 24h since last updated, add second outstanding violation
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: new(false)}, time.Now(), false, nil))
|
||||
require.NoError(t, setStatsTimestampDB(time.Now().Add(-25*time.Hour)))
|
||||
require.NoError(t, ds.IncrementPolicyViolationDays(ctx))
|
||||
actual, possible, err = amountPolicyViolationDaysDB(ctx, ds.reader(ctx))
|
||||
|
|
@ -3106,7 +3130,7 @@ func testPolicyViolationDays(t *testing.T, ds *Datastore) {
|
|||
// leave counts at 2 actual and 3 possible for next test
|
||||
|
||||
// setup for test: starting counts at 2 actual and 3 possible, more than 24h since last updated, resolve one outstaning violation
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: new(true)}, time.Now(), false, nil))
|
||||
require.NoError(t, setStatsTimestampDB(time.Now().Add(-25*time.Hour)))
|
||||
require.NoError(t, ds.IncrementPolicyViolationDays(ctx))
|
||||
actual, possible, err = amountPolicyViolationDaysDB(ctx, ds.reader(ctx))
|
||||
|
|
@ -3196,7 +3220,7 @@ func testPolicyCleanupPolicyMembership(t *testing.T, ds *Datastore) {
|
|||
polsByName["p2"].ID: ptr.Bool(true),
|
||||
polsByName["p3"].ID: ptr.Bool(true),
|
||||
}
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.NoError(t, ds.writer(ctx).Get(&count, "select COUNT(*) from host_issues WHERE total_issues_count > 0"))
|
||||
|
|
@ -3315,7 +3339,7 @@ func testDeleteAllPolicyMemberships(t *testing.T, ds *Datastore) {
|
|||
host,
|
||||
map[uint]*bool{globalPolicy.ID: ptr.Bool(false)},
|
||||
time.Now(),
|
||||
false,
|
||||
false, nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -3378,9 +3402,9 @@ func testOutdatedAutomationBatch(t *testing.T, ds *Datastore) {
|
|||
pol2, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: "policy2"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h1, map[uint]*bool{pol1.ID: ptr.Bool(false), pol2.ID: ptr.Bool(true)}, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h1, map[uint]*bool{pol1.ID: new(false), pol2.ID: new(true)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h2, map[uint]*bool{pol1.ID: ptr.Bool(false), pol2.ID: ptr.Bool(false)}, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h2, map[uint]*bool{pol1.ID: new(false), pol2.ID: new(false)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
batch, err := ds.OutdatedAutomationBatch(ctx)
|
||||
|
|
@ -3615,7 +3639,7 @@ func testUpdatePolicyHostCounts(t *testing.T, ds *Datastore) {
|
|||
res := map[uint]*bool{
|
||||
policy.ID: ptr.Bool(true),
|
||||
}
|
||||
err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
|
|
@ -3671,7 +3695,7 @@ func testUpdatePolicyHostCounts(t *testing.T, ds *Datastore) {
|
|||
res := map[uint]*bool{
|
||||
policy.ID: result,
|
||||
}
|
||||
err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
|
|
@ -3720,7 +3744,7 @@ func testUpdatePolicyHostCounts(t *testing.T, ds *Datastore) {
|
|||
policy.ID: ptr.Bool(false),
|
||||
policy2.ID: ptr.Bool(true),
|
||||
}
|
||||
err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
for _, h := range teamHosts {
|
||||
|
|
@ -3728,7 +3752,7 @@ func testUpdatePolicyHostCounts(t *testing.T, ds *Datastore) {
|
|||
policy.ID: ptr.Bool(false),
|
||||
policy2.ID: ptr.Bool(true),
|
||||
}
|
||||
err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
|
|
@ -4101,25 +4125,25 @@ func testGetTeamHostsPolicyMemberships(t *testing.T, ds *Datastore) {
|
|||
err = ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{
|
||||
team1Policy1.ID: ptr.Bool(true),
|
||||
team1Policy2.ID: ptr.Bool(false),
|
||||
}, time.Now(), false)
|
||||
}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{
|
||||
team2Policy1.ID: ptr.Bool(false),
|
||||
team2Policy2.ID: ptr.Bool(true),
|
||||
}, time.Now(), false)
|
||||
}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host3, map[uint]*bool{
|
||||
team2Policy1.ID: ptr.Bool(true),
|
||||
team2Policy2.ID: ptr.Bool(true),
|
||||
}, time.Now(), false)
|
||||
}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host5, map[uint]*bool{
|
||||
team1Policy1.ID: ptr.Bool(false),
|
||||
team1Policy2.ID: ptr.Bool(false),
|
||||
}, time.Now(), false)
|
||||
}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
team1Policies, err := ds.GetCalendarPolicies(ctx, team1.ID)
|
||||
|
|
@ -4187,7 +4211,7 @@ func testGetTeamHostsPolicyMemberships(t *testing.T, ds *Datastore) {
|
|||
err = ds.RecordPolicyQueryExecutions(ctx, host4, map[uint]*bool{
|
||||
team1Policy1.ID: ptr.Bool(false),
|
||||
team1Policy2.ID: ptr.Bool(false),
|
||||
}, time.Now(), false)
|
||||
}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
hostsTeam1, err = ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team1.ID, []uint{team1Policies[0].ID}, nil)
|
||||
|
|
@ -4271,7 +4295,7 @@ func testGetTeamHostsPolicyMemberships(t *testing.T, ds *Datastore) {
|
|||
ctx, host2, map[uint]*bool{
|
||||
team2Policy1.ID: nil,
|
||||
team2Policy2.ID: nil,
|
||||
}, time.Now(), false,
|
||||
}, time.Now(), false, nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -4301,7 +4325,7 @@ func testGetTeamHostsPolicyMemberships(t *testing.T, ds *Datastore) {
|
|||
err = ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{
|
||||
team2Policy1.ID: ptr.Bool(true),
|
||||
team2Policy2.ID: ptr.Bool(true),
|
||||
}, time.Now(), false)
|
||||
}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
hostsTeam2, err = ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team2.ID, []uint{team2Policies[0].ID, team2Policies[1].ID}, nil)
|
||||
|
|
@ -4396,7 +4420,7 @@ func testGetTeamHostsPolicyMembershipsEmailPriority(t *testing.T, ds *Datastore)
|
|||
})
|
||||
require.NoError(t, err)
|
||||
// Make the host fail the calendar policy so it always appears in results.
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{calendarPolicy.ID: ptr.Bool(false)}, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{calendarPolicy.ID: new(false)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
return h
|
||||
}
|
||||
|
|
@ -5158,7 +5182,7 @@ func testApplyPolicySpecWithInstallers(t *testing.T, ds *Datastore) {
|
|||
err = ds.RecordPolicyQueryExecutions(ctx, host1Team1, map[uint]*bool{
|
||||
policy1Team1.ID: ptr.Bool(false),
|
||||
vppPolicy1Team1.ID: ptr.Bool(false),
|
||||
}, time.Now(), false)
|
||||
}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
err = ds.UpdateHostPolicyCounts(ctx)
|
||||
require.NoError(t, err)
|
||||
|
|
@ -5417,7 +5441,7 @@ func testApplyPolicySpecWithInstallers(t *testing.T, ds *Datastore) {
|
|||
err = ds.RecordPolicyQueryExecutions(ctx, host1Team1, map[uint]*bool{
|
||||
policy1Team1.ID: ptr.Bool(false),
|
||||
vppPolicy1Team1.ID: ptr.Bool(false),
|
||||
}, time.Now(), false)
|
||||
}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
err = ds.UpdateHostPolicyCounts(ctx)
|
||||
require.NoError(t, err)
|
||||
|
|
@ -5604,7 +5628,7 @@ func testTeamPoliciesNoTeam(t *testing.T, ds *Datastore) {
|
|||
globalPolicy2.ID: ptr.Bool(false),
|
||||
policy0NoTeam.ID: ptr.Bool(true),
|
||||
policy3NoTeam.ID: ptr.Bool(false),
|
||||
}, time.Now(), false)
|
||||
}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Results for host1Team1
|
||||
|
|
@ -5612,7 +5636,7 @@ func testTeamPoliciesNoTeam(t *testing.T, ds *Datastore) {
|
|||
globalPolicy1.ID: ptr.Bool(true),
|
||||
globalPolicy2.ID: nil, // failed to execute, e.g. typo on SQL.
|
||||
policy1Team1.ID: ptr.Bool(true),
|
||||
}, time.Now(), false)
|
||||
}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Results for host2Team1
|
||||
|
|
@ -5620,7 +5644,7 @@ func testTeamPoliciesNoTeam(t *testing.T, ds *Datastore) {
|
|||
globalPolicy1.ID: ptr.Bool(false),
|
||||
globalPolicy2.ID: ptr.Bool(true),
|
||||
policy1Team1.ID: ptr.Bool(false),
|
||||
}, time.Now(), false)
|
||||
}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Results for host3Team2
|
||||
|
|
@ -5628,7 +5652,7 @@ func testTeamPoliciesNoTeam(t *testing.T, ds *Datastore) {
|
|||
globalPolicy1.ID: ptr.Bool(true),
|
||||
policy2Team2.ID: ptr.Bool(true),
|
||||
policy4Team2.ID: ptr.Bool(false),
|
||||
}, time.Now(), false)
|
||||
}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Results for host5NoTeam
|
||||
|
|
@ -5637,7 +5661,7 @@ func testTeamPoliciesNoTeam(t *testing.T, ds *Datastore) {
|
|||
globalPolicy2.ID: ptr.Bool(false),
|
||||
policy0NoTeam.ID: ptr.Bool(false),
|
||||
policy3NoTeam.ID: ptr.Bool(false),
|
||||
}, time.Now(), false)
|
||||
}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ds.UpdateHostPolicyCounts(ctx)
|
||||
|
|
@ -6155,7 +6179,7 @@ func testClearAutoInstallPolicyStatusForHost(t *testing.T, ds *Datastore) {
|
|||
policy1.ID: ptr.Bool(true),
|
||||
policy2.ID: ptr.Bool(false), // software isn't installed on host, so Fleet should install it
|
||||
policy3.ID: ptr.Bool(false), // software isn't installed on host, so Fleet should install it
|
||||
}, time.Now(), false)
|
||||
}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
hostPolicies, err := ds.ListPoliciesForHost(ctx, host)
|
||||
|
|
@ -6378,7 +6402,7 @@ func testPolicyLabelMembershipCleanup(t *testing.T, ds *Datastore) {
|
|||
|
||||
// Record policy results for all hosts
|
||||
for _, h := range []*fleet.Host{hostNoLabels, hostLabel1, hostLabel2, hostLabelBoth} {
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: ptr.Bool(true)}, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: new(true)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
|
|
@ -6408,7 +6432,7 @@ func testPolicyLabelMembershipCleanup(t *testing.T, ds *Datastore) {
|
|||
|
||||
// Re-record membership for all hosts to test exclude labels
|
||||
for _, h := range []*fleet.Host{hostNoLabels, hostLabel1, hostLabel2, hostLabelBoth} {
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: ptr.Bool(true)}, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: new(true)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
wantHostsByPol[policy.Name] = []uint{hostNoLabels.ID, hostLabel1.ID, hostLabel2.ID, hostLabelBoth.ID}
|
||||
|
|
@ -6426,7 +6450,7 @@ func testPolicyLabelMembershipCleanup(t *testing.T, ds *Datastore) {
|
|||
// Test ApplyPolicySpecs with label changes
|
||||
// First, re-record membership for all hosts
|
||||
for _, h := range []*fleet.Host{hostNoLabels, hostLabel1, hostLabel2, hostLabelBoth} {
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: ptr.Bool(true)}, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: new(true)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
wantHostsByPol[policy.Name] = []uint{hostNoLabels.ID, hostLabel1.ID, hostLabel2.ID, hostLabelBoth.ID}
|
||||
|
|
@ -6464,7 +6488,7 @@ func testPolicyLabelMembershipCleanup(t *testing.T, ds *Datastore) {
|
|||
|
||||
// Record membership for all hosts with label1
|
||||
for _, h := range []*fleet.Host{hostLabel1, hostLabelBoth, hostWinLabel1, hostMacLabel1} {
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy2.ID: ptr.Bool(true)}, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy2.ID: new(true)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
|
|
@ -6551,8 +6575,8 @@ func testDeletePolicyWithSoftwareActivatesNextActivity(t *testing.T, ds *Datasto
|
|||
require.NoError(t, err)
|
||||
|
||||
// record a failing policy for both hosts, would enqueue the install
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostNoTm, map[uint]*bool{policyNoTm.ID: ptr.Bool(false)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostTm, map[uint]*bool{policyTm.ID: ptr.Bool(false)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostNoTm, map[uint]*bool{policyNoTm.ID: new(false)}, time.Now(), false, nil))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostTm, map[uint]*bool{policyTm.ID: new(false)}, time.Now(), false, nil))
|
||||
|
||||
// simulate the work of "processSoftwareForNewlyFailingPolicies"
|
||||
installUUIDNoTm, err := ds.InsertSoftwareInstallRequest(ctx, hostNoTm.ID, installerIDNoTm,
|
||||
|
|
@ -6645,8 +6669,8 @@ func testDeletePolicyWithScriptActivatesNextActivity(t *testing.T, ds *Datastore
|
|||
|
||||
// record a failing policy for both hosts, would enqueue the associated
|
||||
// scripts
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostNoTm, map[uint]*bool{policyNoTm.ID: ptr.Bool(false)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostTm, map[uint]*bool{policyTm.ID: ptr.Bool(false)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostNoTm, map[uint]*bool{policyNoTm.ID: new(false)}, time.Now(), false, nil))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostTm, map[uint]*bool{policyTm.ID: new(false)}, time.Now(), false, nil))
|
||||
|
||||
// simulate the work of "processScriptsForNewlyFailingPolicies"
|
||||
hsrPolicyNoTm, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
||||
|
|
@ -6713,7 +6737,7 @@ func testSimultaneousSavePolicy(t *testing.T, ds *Datastore) {
|
|||
for _, policy := range policies {
|
||||
host1Results[policy.ID] = ptr.Bool(true)
|
||||
}
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host1, host1Results, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host1, host1Results, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Run simultaneous
|
||||
|
|
@ -6754,7 +6778,7 @@ func testIsPolicyFailing(t *testing.T, ds *Datastore) {
|
|||
|
||||
// Exists with passes = NULL
|
||||
// failing
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: nil}, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: nil}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
isFailing, err = ds.IsPolicyFailing(ctx, policy.ID, host.ID)
|
||||
|
|
@ -6763,7 +6787,7 @@ func testIsPolicyFailing(t *testing.T, ds *Datastore) {
|
|||
|
||||
// exists with passes = false
|
||||
// failing
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: ptr.Bool(false)}, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: new(false)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
isFailing, err = ds.IsPolicyFailing(ctx, policy.ID, host.ID)
|
||||
|
|
@ -6772,7 +6796,7 @@ func testIsPolicyFailing(t *testing.T, ds *Datastore) {
|
|||
|
||||
// exists with passes = true
|
||||
// Not failing
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: ptr.Bool(true)}, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: new(true)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
isFailing, err = ds.IsPolicyFailing(ctx, policy.ID, host.ID)
|
||||
|
|
@ -6821,7 +6845,7 @@ func testResetAttemptsOnFailingToPassingSync(t *testing.T, ds *Datastore) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// p1 will be failing
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{p1.ID: ptr.Bool(false)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{p1.ID: new(false)}, time.Now(), false, nil))
|
||||
|
||||
// Create rows with attempt_number > 0 and attempt_number IS NULL (pending)
|
||||
// p1 - completed attempt
|
||||
|
|
@ -6852,7 +6876,7 @@ func testResetAttemptsOnFailingToPassingSync(t *testing.T, ds *Datastore) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// p1 is now passing
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{p1.ID: ptr.Bool(true), p2.ID: ptr.Bool(true)}, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{p1.ID: new(true), p2.ID: new(true)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// p1 rows should be reset to 0 (both completed and pending)
|
||||
|
|
@ -6903,7 +6927,7 @@ func testResetAttemptsOnFailingToPassingAsync(t *testing.T, ds *Datastore) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// p1 is failing
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{p1.ID: ptr.Bool(false)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{p1.ID: new(false)}, time.Now(), false, nil))
|
||||
|
||||
// Create rows with attempt_number > 0 and attempt_number IS NULL (pending)
|
||||
// p1 - completed attempt
|
||||
|
|
@ -7139,7 +7163,7 @@ func testBatchedPolicyMembershipCleanup(t *testing.T, ds *Datastore) {
|
|||
|
||||
// Record failing results for all hosts so they all have policy_membership rows and host_issues entries.
|
||||
for _, h := range hosts {
|
||||
err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false)
|
||||
err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{pol.ID: new(false)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
|
|
@ -7228,7 +7252,7 @@ func testBatchedPolicyMembershipCleanupOnPolicyUpdate(t *testing.T, ds *Datastor
|
|||
// Record results for all hosts (simulating results arriving before platform filter applied).
|
||||
allHosts := append([]*fleet.Host{winHost}, linuxHosts...)
|
||||
for _, h := range allHosts {
|
||||
err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false)
|
||||
err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{pol.ID: new(false)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
|
|
@ -7293,7 +7317,7 @@ func testBatchedPolicyMembershipCleanupOnPolicyUpdate(t *testing.T, ds *Datastor
|
|||
// Record policy results for all label-test hosts so policy_membership is populated.
|
||||
labelHosts := append([]*fleet.Host{lblHost}, nonLblHosts...)
|
||||
for _, h := range labelHosts {
|
||||
err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{lblPol.ID: ptr.Bool(false)}, time.Now(), false)
|
||||
err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{lblPol.ID: new(false)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
|
|
@ -7356,7 +7380,7 @@ func testApplyPolicySpecsNeedsFullMembershipCleanupFlag(t *testing.T, ds *Datast
|
|||
hosts[i] = h
|
||||
}
|
||||
for _, h := range hosts {
|
||||
err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false)
|
||||
err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{pol.ID: new(false)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
|
|
@ -7416,7 +7440,7 @@ func testCleanupPolicyMembershipCrashRecovery(t *testing.T, ds *Datastore) {
|
|||
recordResults := func(t *testing.T, hosts []*fleet.Host, polID uint) {
|
||||
t.Helper()
|
||||
for _, h := range hosts {
|
||||
err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{polID: ptr.Bool(false)}, time.Now(), false)
|
||||
err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{polID: new(false)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1092,7 +1092,10 @@ type Datastore interface {
|
|||
|
||||
// RecordPolicyQueryExecutions records the execution results of the policies for the given host.
|
||||
// Even if `results` is empty, the host's `policy_updated_at` will be updated.
|
||||
RecordPolicyQueryExecutions(ctx context.Context, host *Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool) error
|
||||
// If newlyPassingPolicyIDs is non-nil, it contains the IDs of policies that flipped from failing to passing
|
||||
// and is used directly instead of calling FlippingPoliciesForHost internally. This allows callers that have
|
||||
// already computed flipping policies to avoid a redundant database query.
|
||||
RecordPolicyQueryExecutions(ctx context.Context, host *Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool, newlyPassingPolicyIDs []uint) error
|
||||
|
||||
// RecordLabelQueryExecutions saves the results of label queries. The results map is a map of label id -> whether or
|
||||
// not the label matches. The time parameter is the timestamp to save with the query execution.
|
||||
|
|
|
|||
|
|
@ -781,7 +781,7 @@ type UpdateHostRefetchCriticalQueriesUntilFunc func(ctx context.Context, hostID
|
|||
|
||||
type FlippingPoliciesForHostFunc func(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing []uint, newPassing []uint, err error)
|
||||
|
||||
type RecordPolicyQueryExecutionsFunc func(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool) error
|
||||
type RecordPolicyQueryExecutionsFunc func(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool, newlyPassingPolicyIDs []uint) error
|
||||
|
||||
type RecordLabelQueryExecutionsFunc func(ctx context.Context, host *fleet.Host, results map[uint]*bool, t time.Time, deferredSaveHost bool) error
|
||||
|
||||
|
|
@ -7260,11 +7260,11 @@ func (s *DataStore) FlippingPoliciesForHost(ctx context.Context, hostID uint, in
|
|||
return s.FlippingPoliciesForHostFunc(ctx, hostID, incomingResults)
|
||||
}
|
||||
|
||||
func (s *DataStore) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool) error {
|
||||
func (s *DataStore) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, updated time.Time, deferredSaveHost bool, newlyPassingPolicyIDs []uint) error {
|
||||
s.mu.Lock()
|
||||
s.RecordPolicyQueryExecutionsFuncInvoked = true
|
||||
s.mu.Unlock()
|
||||
return s.RecordPolicyQueryExecutionsFunc(ctx, host, results, updated, deferredSaveHost)
|
||||
return s.RecordPolicyQueryExecutionsFunc(ctx, host, results, updated, deferredSaveHost, newlyPassingPolicyIDs)
|
||||
}
|
||||
|
||||
func (s *DataStore) RecordLabelQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, t time.Time, deferredSaveHost bool) error {
|
||||
|
|
|
|||
|
|
@ -28,11 +28,11 @@ const (
|
|||
// redis list will be LTRIM'd if there are more policy IDs than this.
|
||||
var maxRedisPolicyResultsPerHost = 1000
|
||||
|
||||
func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool) error {
|
||||
func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool, newlyPassingPolicyIDs []uint) error {
|
||||
cfg := t.taskConfigs[config.AsyncTaskPolicyMembership]
|
||||
if !cfg.Enabled {
|
||||
host.PolicyUpdatedAt = ts
|
||||
return t.datastore.RecordPolicyQueryExecutions(ctx, host, results, ts, deferred)
|
||||
return t.datastore.RecordPolicyQueryExecutions(ctx, host, results, ts, deferred, newlyPassingPolicyIDs)
|
||||
}
|
||||
|
||||
keyList := fmt.Sprintf(policyPassHostKey, host.ID)
|
||||
|
|
|
|||
|
|
@ -321,7 +321,7 @@ func testRecordPolicyQueryExecutionsSync(t *testing.T, ds *mock.Store, pool flee
|
|||
policyReportedAt := task.GetHostPolicyReportedAt(ctx, host)
|
||||
require.True(t, policyReportedAt.Equal(lastYear))
|
||||
|
||||
err := task.RecordPolicyQueryExecutions(ctx, host, results, now, false)
|
||||
err := task.RecordPolicyQueryExecutions(ctx, host, results, now, false, nil)
|
||||
require.NoError(t, err)
|
||||
require.True(t, ds.RecordPolicyQueryExecutionsFuncInvoked)
|
||||
ds.RecordPolicyQueryExecutionsFuncInvoked = false
|
||||
|
|
@ -373,7 +373,7 @@ func testRecordPolicyQueryExecutionsAsync(t *testing.T, ds *mock.Store, pool fle
|
|||
policyReportedAt := task.GetHostPolicyReportedAt(ctx, host)
|
||||
require.True(t, policyReportedAt.Equal(lastYear))
|
||||
|
||||
err := task.RecordPolicyQueryExecutions(ctx, host, results, now, false)
|
||||
err := task.RecordPolicyQueryExecutions(ctx, host, results, now, false, nil)
|
||||
require.NoError(t, err)
|
||||
require.False(t, ds.RecordPolicyQueryExecutionsFuncInvoked)
|
||||
|
||||
|
|
@ -435,7 +435,7 @@ func testRecordPolicyQueryExecutionsNoPoliciesSync(t *testing.T, ds *mock.Store,
|
|||
policyReportedAt := task.GetHostPolicyReportedAt(ctx, host)
|
||||
require.True(t, policyReportedAt.Equal(lastYear))
|
||||
|
||||
err := task.RecordPolicyQueryExecutions(ctx, host, emptyResults, now, false)
|
||||
err := task.RecordPolicyQueryExecutions(ctx, host, emptyResults, now, false, nil)
|
||||
require.NoError(t, err)
|
||||
require.True(t, ds.RecordPolicyQueryExecutionsFuncInvoked)
|
||||
ds.RecordPolicyQueryExecutionsFuncInvoked = false
|
||||
|
|
@ -485,7 +485,7 @@ func testRecordPolicyQueryExecutionsNoPoliciesAsync(t *testing.T, ds *mock.Store
|
|||
policyReportedAt := task.GetHostPolicyReportedAt(ctx, host)
|
||||
require.True(t, policyReportedAt.Equal(lastYear))
|
||||
|
||||
err := task.RecordPolicyQueryExecutions(ctx, host, emptyResults, now, false)
|
||||
err := task.RecordPolicyQueryExecutions(ctx, host, emptyResults, now, false, nil)
|
||||
require.NoError(t, err)
|
||||
require.False(t, ds.RecordPolicyQueryExecutionsFuncInvoked)
|
||||
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ func TestRecord(t *testing.T) {
|
|||
ds.AsyncBatchUpdateLabelTimestampFunc = func(ctx context.Context, ids []uint, ts time.Time) error {
|
||||
return nil
|
||||
}
|
||||
ds.RecordPolicyQueryExecutionsFunc = func(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool) error {
|
||||
ds.RecordPolicyQueryExecutionsFunc = func(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool, newlyPassingPolicyIDs []uint) error {
|
||||
return nil
|
||||
}
|
||||
ds.AsyncBatchInsertPolicyMembershipFunc = func(ctx context.Context, batch []fleet.PolicyMembershipResult) error {
|
||||
|
|
|
|||
|
|
@ -1223,8 +1223,8 @@ func (s *integrationTestSuite) TestGlobalPolicies() {
|
|||
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
|
||||
require.Len(t, listHostsResp.Hosts, 0)
|
||||
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: new(true)}, time.Now(), false, nil))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false, nil))
|
||||
|
||||
listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID)
|
||||
listHostsResp = listHostsResponse{}
|
||||
|
|
@ -1844,7 +1844,7 @@ func (s *integrationTestSuite) TestListHosts() {
|
|||
|
||||
require.NoError(
|
||||
t,
|
||||
s.ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{globalPolicy0.ID: ptr.Bool(false)}, time.Now(), false),
|
||||
s.ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{globalPolicy0.ID: new(false)}, time.Now(), false, nil),
|
||||
)
|
||||
|
||||
resp = listHostsResponse{}
|
||||
|
|
@ -2109,7 +2109,7 @@ func (s *integrationTestSuite) TestListHosts() {
|
|||
for _, host := range hosts {
|
||||
// All hosts pass the globalPolicy1
|
||||
err := s.ds.RecordPolicyQueryExecutions(
|
||||
context.Background(), host, map[uint]*bool{globalPolicy1.ID: ptr.Bool(true)}, time.Now(), false,
|
||||
context.Background(), host, map[uint]*bool{globalPolicy1.ID: new(true)}, time.Now(), false, nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
|
@ -2927,8 +2927,8 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() {
|
|||
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
|
||||
require.Len(t, listHostsResp.Hosts, 0)
|
||||
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: new(true)}, time.Now(), false, nil))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false, nil))
|
||||
|
||||
listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID)
|
||||
listHostsResp = listHostsResponse{}
|
||||
|
|
@ -2978,12 +2978,12 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() {
|
|||
// Record query executions
|
||||
require.NoError(
|
||||
t, s.ds.RecordPolicyQueryExecutions(
|
||||
context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now(), false,
|
||||
context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: new(true)}, time.Now(), false, nil,
|
||||
),
|
||||
)
|
||||
require.NoError(
|
||||
t, s.ds.RecordPolicyQueryExecutions(
|
||||
context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false,
|
||||
context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false, nil,
|
||||
),
|
||||
)
|
||||
// Update policy stats
|
||||
|
|
@ -3169,8 +3169,8 @@ func (s *integrationTestSuite) TestTeamPoliciesProprietary() {
|
|||
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
|
||||
require.Len(t, listHostsResp.Hosts, 0)
|
||||
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: new(true)}, time.Now(), false, nil))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false, nil))
|
||||
|
||||
listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?team_id=%d&policy_id=%d&policy_response=passing", team1.ID, policiesResponse.Policies[0].ID)
|
||||
listHostsResp = listHostsResponse{}
|
||||
|
|
@ -3384,6 +3384,7 @@ func (s *integrationTestSuite) TestHostDetailsPolicies() {
|
|||
map[uint]*bool{gpResp.Policy.ID: ptr.Bool(true)},
|
||||
time.Now(),
|
||||
false,
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -5417,7 +5418,7 @@ func (s *integrationTestSuite) TestListHostsByLabel() {
|
|||
require.NotNil(t, gpResp.Policy)
|
||||
require.NoError(
|
||||
t,
|
||||
s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{gpResp.Policy.ID: ptr.Bool(false)}, time.Now(), false),
|
||||
s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{gpResp.Policy.ID: new(false)}, time.Now(), false, nil),
|
||||
)
|
||||
|
||||
// Add MDM info
|
||||
|
|
@ -9311,7 +9312,7 @@ func (s *integrationTestSuite) TestReenrollHostCleansPolicies() {
|
|||
// create a policy and make the host fail it
|
||||
pol, err := s.ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: t.Name(), Query: "SELECT 1", Platform: host.FleetPlatform()})
|
||||
require.NoError(t, err)
|
||||
err = s.ds.RecordPolicyQueryExecutions(ctx, &fleet.Host{ID: host.ID}, map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false)
|
||||
err = s.ds.RecordPolicyQueryExecutions(ctx, &fleet.Host{ID: host.ID}, map[uint]*bool{pol.ID: new(false)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// refetch the host details
|
||||
|
|
@ -10052,7 +10053,7 @@ func (s *integrationTestSuite) TestHostsReportDownload() {
|
|||
// create a policy and make host[1] fail that policy
|
||||
pol, err := s.ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: t.Name(), Query: "SELECT 1"})
|
||||
require.NoError(t, err)
|
||||
err = s.ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false)
|
||||
err = s.ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: new(false)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// create some device mappings for host[2]
|
||||
|
|
@ -12474,20 +12475,20 @@ func (s *integrationTestSuite) TestHostsReportWithPolicyResults() {
|
|||
|
||||
for i, host := range hosts {
|
||||
// All hosts pass the globalPolicy0
|
||||
err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy0.ID: ptr.Bool(true)}, time.Now(), false)
|
||||
err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy0.ID: new(true)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
if i%2 == 0 {
|
||||
// Half of the hosts pass the globalPolicy1 and fail the globalPolicy2
|
||||
err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy1.ID: ptr.Bool(true)}, time.Now(), false)
|
||||
err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy1.ID: new(true)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
err = s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy2.ID: ptr.Bool(false)}, time.Now(), false)
|
||||
err = s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy2.ID: new(false)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
// Half of the hosts pass the globalPolicy2 and fail the globalPolicy1
|
||||
err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy1.ID: ptr.Bool(false)}, time.Now(), false)
|
||||
err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy1.ID: new(false)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
err = s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy2.ID: ptr.Bool(true)}, time.Now(), false)
|
||||
err = s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy2.ID: new(true)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
|
@ -13593,8 +13594,8 @@ func (s *integrationTestSuite) TestHostHealth() {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{failingPolicy.ID: ptr.Bool(false)}, time.Now(), false))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{passingPolicy.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{failingPolicy.ID: new(false)}, time.Now(), false, nil))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{passingPolicy.ID: new(true)}, time.Now(), false, nil))
|
||||
|
||||
require.NoError(t, s.ds.SetOrUpdateHostDisksEncryption(context.Background(), host.ID, true))
|
||||
|
||||
|
|
|
|||
|
|
@ -3017,7 +3017,7 @@ func (s *integrationEnterpriseTestSuite) TestNoTeamFailingPolicyWebhookTrigger()
|
|||
noTeamPol1.ID: ptr.Bool(false), // Fails and is in webhook config
|
||||
noTeamPol2.ID: ptr.Bool(false), // Fails and is in webhook config
|
||||
noTeamPol3.ID: ptr.Bool(false), // Fails but NOT in webhook config
|
||||
}, time.Now(), false)
|
||||
}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initially, OutdatedAutomationBatch should be empty (policies haven't been triggered for automation yet)
|
||||
|
|
@ -4055,7 +4055,7 @@ func (s *integrationEnterpriseTestSuite) TestListDevicePolicies() {
|
|||
|
||||
// add a policy execution
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(ctx, host,
|
||||
map[uint]*bool{gpResp.Policy.ID: ptr.Bool(false)}, time.Now(), false))
|
||||
map[uint]*bool{gpResp.Policy.ID: new(false)}, time.Now(), false, nil))
|
||||
|
||||
// add a policy to team
|
||||
oldToken := s.token
|
||||
|
|
@ -5303,7 +5303,7 @@ func (s *integrationEnterpriseTestSuite) TestListHosts() {
|
|||
require.NoError(
|
||||
t, s.ds.RecordPolicyQueryExecutions(
|
||||
ctx, host1,
|
||||
map[uint]*bool{gpResp.Policy.ID: ptr.Bool(false)}, time.Now(), false,
|
||||
map[uint]*bool{gpResp.Policy.ID: new(false)}, time.Now(), false, nil,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -5568,10 +5568,10 @@ func (s *integrationEnterpriseTestSuite) TestHostHealth() {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{failingGlobalPolicy.ID: ptr.Bool(false)}, time.Now(), false))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{passingGlobalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{failingTeamPolicy.ID: ptr.Bool(false)}, time.Now(), false))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{passingTeamPolicy.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{failingGlobalPolicy.ID: new(false)}, time.Now(), false, nil))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{passingGlobalPolicy.ID: new(true)}, time.Now(), false, nil))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{failingTeamPolicy.ID: new(false)}, time.Now(), false, nil))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{passingTeamPolicy.ID: new(true)}, time.Now(), false, nil))
|
||||
|
||||
hh := getHostHealthResponse{}
|
||||
s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/hosts/%d/health", host.ID), nil, http.StatusOK, &hh)
|
||||
|
|
@ -6288,7 +6288,7 @@ func (s *integrationEnterpriseTestSuite) TestResetAutomation() {
|
|||
createPol1.Policy.ID: ptr.Bool(false),
|
||||
createPol2.Policy.ID: ptr.Bool(false),
|
||||
createPol3.Policy.ID: ptr.Bool(false), // This policy is not activated for automation in config.
|
||||
}, time.Now(), false)
|
||||
}, time.Now(), false, nil)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
pfs, err := s.ds.OutdatedAutomationBatch(ctx)
|
||||
|
|
@ -7417,7 +7417,7 @@ func (s *integrationEnterpriseTestSuite) TestDesktopEndpointWithInvalidPolicy()
|
|||
Critical: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{policy.ID: nil}, time.Now(), false))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{policy.ID: nil}, time.Now(), false, nil))
|
||||
|
||||
// Any 'invalid' policies should be ignored.
|
||||
desktopRes := fleetDesktopResponse{}
|
||||
|
|
@ -25216,7 +25216,7 @@ FqU+KJOed6qlzj7qy+u5l6CQeajLGdjUxFlFyw==
|
|||
require.NoError(t, err)
|
||||
|
||||
// Record a failing result for this policy on the host
|
||||
err = s.ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: ptr.Bool(false)}, time.Now(), false)
|
||||
err = s.ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: new(false)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Bypass should fail with 400 Bad Request
|
||||
|
|
@ -25261,7 +25261,7 @@ FqU+KJOed6qlzj7qy+u5l6CQeajLGdjUxFlFyw==
|
|||
err = s.ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{
|
||||
caPolicy.ID: ptr.Bool(true), // passing
|
||||
nonCAPolicy.ID: ptr.Bool(false), // failing
|
||||
}, time.Now(), false)
|
||||
}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Bypass must succeed: the only failing policy is not CA-enabled
|
||||
|
|
|
|||
|
|
@ -924,7 +924,7 @@ func TestSoftwareInstallReplicaLag(t *testing.T) {
|
|||
opts.RunReplication("software_installers", "software_titles")
|
||||
|
||||
// Mark policy as failing for the host
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: ptr.Bool(false)}, time.Now(), false)
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: new(false)}, time.Now(), false, nil)
|
||||
require.NoError(t, err)
|
||||
opts.RunReplication("policy_membership")
|
||||
|
||||
|
|
|
|||
|
|
@ -1209,11 +1209,27 @@ func (svc *Service) SubmitDistributedQueryResults(
|
|||
}
|
||||
|
||||
if len(policyResults) > 0 {
|
||||
// Compute flipping policies once for all consumers. This replaces up to 5 individual calls to
|
||||
// FlippingPoliciesForHost with a single database query.
|
||||
newFailing, newPassing, err := svc.ds.FlippingPoliciesForHost(ctx, host.ID, policyResults)
|
||||
if err != nil {
|
||||
logging.WithErr(ctx, err)
|
||||
}
|
||||
// Ensure newPassing is non-nil so RecordPolicyQueryExecutions can distinguish "pre-computed with zero results"
|
||||
// from "not pre-computed" (nil means compute it yourself).
|
||||
if newPassing == nil {
|
||||
newPassing = []uint{}
|
||||
}
|
||||
newFailingSet := make(map[uint]struct{}, len(newFailing))
|
||||
for _, id := range newFailing {
|
||||
newFailingSet[id] = struct{}{}
|
||||
}
|
||||
|
||||
if err := processCalendarPolicies(ctx, svc.ds, ac, host, policyResults, svc.logger); err != nil {
|
||||
logging.WithErr(ctx, err)
|
||||
}
|
||||
|
||||
if err := svc.processScriptsForNewlyFailingPolicies(ctx, host.ID, host.TeamID, host.Platform, host.OrbitNodeKey, host.ScriptsEnabled, policyResults); err != nil {
|
||||
if err := svc.processScriptsForNewlyFailingPolicies(ctx, host.ID, host.TeamID, host.Platform, host.OrbitNodeKey, host.ScriptsEnabled, policyResults, newFailingSet); err != nil {
|
||||
logging.WithErr(ctx, err)
|
||||
}
|
||||
|
||||
|
|
@ -1226,18 +1242,18 @@ func (svc *Service) SubmitDistributedQueryResults(
|
|||
if host.Platform == "darwin" && svc.EnterpriseOverrides != nil {
|
||||
// NOTE: if the installers for the policies here are not scoped to the host via labels, we update the policy status here to stop it from showing up as "failed" in the
|
||||
// host details.
|
||||
if err := svc.processVPPForNewlyFailingPolicies(ctx, host.ID, host.TeamID, host.Platform, policyResults); err != nil {
|
||||
if err := svc.processVPPForNewlyFailingPolicies(ctx, host.ID, host.TeamID, host.Platform, policyResults, newFailingSet); err != nil {
|
||||
logging.WithErr(ctx, err)
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: if the installers for the policies here are not scoped to the host via labels, we update the policy status here to stop it from showing up as "failed" in the
|
||||
// host details.
|
||||
if err := svc.processSoftwareForNewlyFailingPolicies(ctx, host.ID, host.TeamID, host.Platform, host.OrbitNodeKey, policyResults); err != nil {
|
||||
if err := svc.processSoftwareForNewlyFailingPolicies(ctx, host.ID, host.TeamID, host.Platform, host.OrbitNodeKey, policyResults, newFailingSet); err != nil {
|
||||
logging.WithErr(ctx, err)
|
||||
}
|
||||
|
||||
// filter policy results for webhooks
|
||||
// Filter policy results for webhooks using pre-computed flipping sets.
|
||||
var policyIDs []uint
|
||||
if globalPolicyAutomationsEnabled(ac.WebhookSettings, ac.Integrations) {
|
||||
policyIDs = append(policyIDs, ac.WebhookSettings.FailingPoliciesWebhook.PolicyIDs...)
|
||||
|
|
@ -1256,12 +1272,13 @@ func (svc *Service) SubmitDistributedQueryResults(
|
|||
|
||||
filteredResults := filterPolicyResults(policyResults, policyIDs)
|
||||
if len(filteredResults) > 0 {
|
||||
if failingPolicies, passingPolicies, err := svc.ds.FlippingPoliciesForHost(ctx, host.ID, filteredResults); err != nil {
|
||||
logging.WithErr(ctx, err)
|
||||
} else {
|
||||
// Filter the pre-computed flipping results to only webhook-enabled policies.
|
||||
webhookFailing := filterByPolicyIDs(newFailing, filteredResults)
|
||||
webhookPassing := filterByPolicyIDs(newPassing, filteredResults)
|
||||
if len(webhookFailing) > 0 || len(webhookPassing) > 0 {
|
||||
// Register the flipped policies on a goroutine to not block the hosts on redis requests.
|
||||
go func() {
|
||||
if err := svc.registerFlippedPolicies(ctx, host.ID, host.Hostname, host.DisplayName(), failingPolicies, passingPolicies); err != nil {
|
||||
if err := svc.registerFlippedPolicies(ctx, host.ID, host.Hostname, host.DisplayName(), webhookFailing, webhookPassing); err != nil {
|
||||
logging.WithErr(ctx, err)
|
||||
}
|
||||
}()
|
||||
|
|
@ -1275,12 +1292,12 @@ func (svc *Service) SubmitDistributedQueryResults(
|
|||
// maybe we should impose restrictions between async collection interval
|
||||
// and policy update interval?
|
||||
|
||||
if err := svc.task.RecordPolicyQueryExecutions(ctx, host, policyResults, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost); err != nil {
|
||||
if err := svc.task.RecordPolicyQueryExecutions(ctx, host, policyResults, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost, newPassing); err != nil {
|
||||
logging.WithErr(ctx, err)
|
||||
}
|
||||
} else if hostWithoutPolicies {
|
||||
// RecordPolicyQueryExecutions called with results=nil will still update the host's policy_updated_at column.
|
||||
if err := svc.task.RecordPolicyQueryExecutions(ctx, host, nil, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost); err != nil {
|
||||
if err := svc.task.RecordPolicyQueryExecutions(ctx, host, nil, svc.clock.Now(), ac.ServerSettings.DeferredSaveHost, []uint{}); err != nil {
|
||||
logging.WithErr(ctx, err)
|
||||
}
|
||||
}
|
||||
|
|
@ -1945,6 +1962,18 @@ func filterPolicyResults(incoming map[uint]*bool, webhookPolicies []uint) map[ui
|
|||
return filtered
|
||||
}
|
||||
|
||||
// filterByPolicyIDs returns only the policy IDs from ids that are present in allowedResults and have a non-nil result
|
||||
// (i.e., the policy actually executed). This matches the behavior of FlippingPoliciesForHost which ignores nil results.
|
||||
func filterByPolicyIDs(ids []uint, allowedResults map[uint]*bool) []uint {
|
||||
var filtered []uint
|
||||
for _, id := range ids {
|
||||
if val, ok := allowedResults[id]; ok && val != nil {
|
||||
filtered = append(filtered, id)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func (svc *Service) registerFlippedPolicies(ctx context.Context, hostID uint, hostname, displayName string, newFailing, newPassing []uint) error {
|
||||
host := fleet.PolicySetHost{
|
||||
ID: hostID,
|
||||
|
|
@ -1971,6 +2000,7 @@ func (svc *Service) processSoftwareForNewlyFailingPolicies(
|
|||
hostPlatform string,
|
||||
hostOrbitNodeKey *string,
|
||||
incomingPolicyResults map[uint]*bool,
|
||||
newFailingSet map[uint]struct{},
|
||||
) error {
|
||||
if hostOrbitNodeKey == nil || *hostOrbitNodeKey == "" {
|
||||
// We do not want to queue software installations on vanilla osquery hosts.
|
||||
|
|
@ -1986,15 +2016,13 @@ func (svc *Service) processSoftwareForNewlyFailingPolicies(
|
|||
|
||||
// Filter out results that are not failures (we are only interested on failing policies,
|
||||
// we don't care about passing policies or policies that failed to execute).
|
||||
incomingFailingPolicies := make(map[uint]*bool)
|
||||
var incomingFailingPoliciesIDs []uint
|
||||
for policyID, policyResult := range incomingPolicyResults {
|
||||
if policyResult != nil && !*policyResult {
|
||||
incomingFailingPolicies[policyID] = policyResult
|
||||
incomingFailingPoliciesIDs = append(incomingFailingPoliciesIDs, policyID)
|
||||
}
|
||||
}
|
||||
if len(incomingFailingPolicies) == 0 {
|
||||
if len(incomingFailingPoliciesIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -2007,44 +2035,16 @@ func (svc *Service) processSoftwareForNewlyFailingPolicies(
|
|||
return nil
|
||||
}
|
||||
|
||||
// Filter out results of policies that are not associated to installers.
|
||||
policiesWithInstallersMap := make(map[uint]fleet.PolicySoftwareInstallerData)
|
||||
for _, policyWithInstaller := range policiesWithInstaller {
|
||||
policiesWithInstallersMap[policyWithInstaller.ID] = policyWithInstaller
|
||||
}
|
||||
policyResultsOfPoliciesWithInstallers := make(map[uint]*bool)
|
||||
for policyID, passes := range incomingFailingPolicies {
|
||||
if _, ok := policiesWithInstallersMap[policyID]; !ok {
|
||||
continue
|
||||
}
|
||||
policyResultsOfPoliciesWithInstallers[policyID] = passes
|
||||
}
|
||||
if len(policyResultsOfPoliciesWithInstallers) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the policies associated with installers that are flipping from passing to failing on this host.
|
||||
policyIDsOfNewlyFailingPoliciesWithInstallers, _, err := svc.ds.FlippingPoliciesForHost(
|
||||
ctx, hostID, policyResultsOfPoliciesWithInstallers,
|
||||
)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "failed to get flipping policies for host")
|
||||
}
|
||||
if len(policyIDsOfNewlyFailingPoliciesWithInstallers) == 0 {
|
||||
return nil
|
||||
}
|
||||
policyIDsOfNewlyFailingPoliciesWithInstallersSet := make(map[uint]struct{})
|
||||
for _, policyID := range policyIDsOfNewlyFailingPoliciesWithInstallers {
|
||||
policyIDsOfNewlyFailingPoliciesWithInstallersSet[policyID] = struct{}{}
|
||||
}
|
||||
|
||||
// Finally filter out policies with installers that are not newly failing.
|
||||
// Filter to policies with installers that are newly failing, using the pre-computed set.
|
||||
var failingPoliciesWithInstaller []fleet.PolicySoftwareInstallerData
|
||||
for _, policyWithInstaller := range policiesWithInstaller {
|
||||
if _, ok := policyIDsOfNewlyFailingPoliciesWithInstallersSet[policyWithInstaller.ID]; ok {
|
||||
if _, ok := newFailingSet[policyWithInstaller.ID]; ok {
|
||||
failingPoliciesWithInstaller = append(failingPoliciesWithInstaller, policyWithInstaller)
|
||||
}
|
||||
}
|
||||
if len(failingPoliciesWithInstaller) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, failingPolicyWithInstaller := range failingPoliciesWithInstaller {
|
||||
policyID := failingPolicyWithInstaller.ID
|
||||
|
|
@ -2119,6 +2119,7 @@ func (svc *Service) processVPPForNewlyFailingPolicies(
|
|||
hostTeamID *uint,
|
||||
hostPlatform string,
|
||||
incomingPolicyResults map[uint]*bool,
|
||||
newFailingSet map[uint]struct{},
|
||||
) error {
|
||||
var policyTeamID uint
|
||||
if hostTeamID == nil {
|
||||
|
|
@ -2129,15 +2130,13 @@ func (svc *Service) processVPPForNewlyFailingPolicies(
|
|||
|
||||
// Filter out results that are not failures (we are only interested on failing policies,
|
||||
// we don't care about passing policies or policies that failed to execute).
|
||||
incomingFailingPolicies := make(map[uint]*bool)
|
||||
var incomingFailingPoliciesIDs []uint
|
||||
for policyID, policyResult := range incomingPolicyResults {
|
||||
if policyResult != nil && !*policyResult {
|
||||
incomingFailingPolicies[policyID] = policyResult
|
||||
incomingFailingPoliciesIDs = append(incomingFailingPoliciesIDs, policyID)
|
||||
}
|
||||
}
|
||||
if len(incomingFailingPolicies) == 0 {
|
||||
if len(incomingFailingPoliciesIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -2150,45 +2149,13 @@ func (svc *Service) processVPPForNewlyFailingPolicies(
|
|||
return nil
|
||||
}
|
||||
|
||||
// Filter out results of policies that are not associated to VPP apps.
|
||||
policiesWithVPPMap := make(map[uint]fleet.PolicyVPPData)
|
||||
for _, policyWithVPP := range policiesWithVPP {
|
||||
policiesWithVPPMap[policyWithVPP.ID] = policyWithVPP
|
||||
}
|
||||
policyResultsOfPoliciesWithVPP := make(map[uint]*bool)
|
||||
for policyID, passes := range incomingFailingPolicies {
|
||||
if _, ok := policiesWithVPPMap[policyID]; !ok {
|
||||
continue
|
||||
}
|
||||
policyResultsOfPoliciesWithVPP[policyID] = passes
|
||||
}
|
||||
if len(policyResultsOfPoliciesWithVPP) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the policies associated with VPP apps that are flipping from passing to failing on this host.
|
||||
policyIDsOfNewlyFailingPoliciesWithVPP, _, err := svc.ds.FlippingPoliciesForHost(
|
||||
ctx, hostID, policyResultsOfPoliciesWithVPP,
|
||||
)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "failed to get flipping policies for host")
|
||||
}
|
||||
if len(policyIDsOfNewlyFailingPoliciesWithVPP) == 0 {
|
||||
return nil
|
||||
}
|
||||
policyIDsOfNewlyFailingPoliciesWithVPPSet := make(map[uint]struct{})
|
||||
for _, policyID := range policyIDsOfNewlyFailingPoliciesWithVPP {
|
||||
policyIDsOfNewlyFailingPoliciesWithVPPSet[policyID] = struct{}{}
|
||||
}
|
||||
|
||||
// Finally filter out policies with VPP apps that are not newly failing.
|
||||
// Filter to policies with VPP apps that are newly failing, using the pre-computed set.
|
||||
var failingPoliciesWithVPP []fleet.PolicyVPPData
|
||||
for _, policyWithVPP := range policiesWithVPP {
|
||||
if _, ok := policyIDsOfNewlyFailingPoliciesWithVPPSet[policyWithVPP.ID]; ok {
|
||||
if _, ok := newFailingSet[policyWithVPP.ID]; ok {
|
||||
failingPoliciesWithVPP = append(failingPoliciesWithVPP, policyWithVPP)
|
||||
}
|
||||
}
|
||||
|
||||
if len(failingPoliciesWithVPP) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -2269,6 +2236,7 @@ func (svc *Service) processScriptsForNewlyFailingPolicies(
|
|||
hostOrbitNodeKey *string,
|
||||
hostScriptsEnabled *bool,
|
||||
incomingPolicyResults map[uint]*bool,
|
||||
newFailingSet map[uint]struct{},
|
||||
) error {
|
||||
if hostOrbitNodeKey == nil || *hostOrbitNodeKey == "" {
|
||||
return nil // vanilla osquery hosts can't run scripts
|
||||
|
|
@ -2297,15 +2265,13 @@ func (svc *Service) processScriptsForNewlyFailingPolicies(
|
|||
|
||||
// Filter out results that are not failures (we are only interested on failing policies,
|
||||
// we don't care about passing policies or policies that failed to execute).
|
||||
incomingFailingPolicies := make(map[uint]*bool)
|
||||
var incomingFailingPoliciesIDs []uint
|
||||
for policyID, policyResult := range incomingPolicyResults {
|
||||
if policyResult != nil && !*policyResult {
|
||||
incomingFailingPolicies[policyID] = policyResult
|
||||
incomingFailingPoliciesIDs = append(incomingFailingPoliciesIDs, policyID)
|
||||
}
|
||||
}
|
||||
if len(incomingFailingPolicies) == 0 {
|
||||
if len(incomingFailingPoliciesIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -2318,44 +2284,16 @@ func (svc *Service) processScriptsForNewlyFailingPolicies(
|
|||
return nil
|
||||
}
|
||||
|
||||
// Filter out results of policies that are not associated to scripts.
|
||||
policiesWithScriptsMap := make(map[uint]fleet.PolicyScriptData)
|
||||
for _, policyWithScript := range policiesWithScript {
|
||||
policiesWithScriptsMap[policyWithScript.ID] = policyWithScript
|
||||
}
|
||||
policyResultsOfPoliciesWithScripts := make(map[uint]*bool)
|
||||
for policyID, passes := range incomingFailingPolicies {
|
||||
if _, ok := policiesWithScriptsMap[policyID]; !ok {
|
||||
continue
|
||||
}
|
||||
policyResultsOfPoliciesWithScripts[policyID] = passes
|
||||
}
|
||||
if len(policyResultsOfPoliciesWithScripts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the policies associated with scripts that are flipping from passing to failing on this host.
|
||||
policyIDsOfNewlyFailingPoliciesWithScripts, _, err := svc.ds.FlippingPoliciesForHost(
|
||||
ctx, hostID, policyResultsOfPoliciesWithScripts,
|
||||
)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "failed to get flipping policies for host")
|
||||
}
|
||||
if len(policyIDsOfNewlyFailingPoliciesWithScripts) == 0 {
|
||||
return nil
|
||||
}
|
||||
policyIDsOfNewlyFailingPoliciesWithScriptsSet := make(map[uint]struct{})
|
||||
for _, policyID := range policyIDsOfNewlyFailingPoliciesWithScripts {
|
||||
policyIDsOfNewlyFailingPoliciesWithScriptsSet[policyID] = struct{}{}
|
||||
}
|
||||
|
||||
// Finally filter out policies with scripts that are not newly failing.
|
||||
// Filter to policies with scripts that are newly failing, using the pre-computed set.
|
||||
var failingPoliciesWithScript []fleet.PolicyScriptData
|
||||
for _, policyWithScript := range policiesWithScript {
|
||||
if _, ok := policyIDsOfNewlyFailingPoliciesWithScriptsSet[policyWithScript.ID]; ok {
|
||||
if _, ok := newFailingSet[policyWithScript.ID]; ok {
|
||||
failingPoliciesWithScript = append(failingPoliciesWithScript, policyWithScript)
|
||||
}
|
||||
}
|
||||
if len(failingPoliciesWithScript) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, failingPolicyWithScript := range failingPoliciesWithScript {
|
||||
policyID := failingPolicyWithScript.ID
|
||||
|
|
|
|||
|
|
@ -3350,7 +3350,7 @@ func TestPolicyQueries(t *testing.T) {
|
|||
}
|
||||
recordedResults := make(map[uint]*bool)
|
||||
ds.RecordPolicyQueryExecutionsFunc = func(ctx context.Context, gotHost *fleet.Host, results map[uint]*bool, updated time.Time,
|
||||
deferred bool,
|
||||
deferred bool, newlyPassingPolicyIDs []uint,
|
||||
) error {
|
||||
recordedResults = results
|
||||
host = gotHost
|
||||
|
|
@ -3661,7 +3661,7 @@ func TestPolicyWebhooks(t *testing.T) {
|
|||
}
|
||||
recordedResults := make(map[uint]*bool)
|
||||
ds.RecordPolicyQueryExecutionsFunc = func(ctx context.Context, gotHost *fleet.Host, results map[uint]*bool, updated time.Time,
|
||||
deferred bool,
|
||||
deferred bool, newlyPassingPolicyIDs []uint,
|
||||
) error {
|
||||
recordedResults = results
|
||||
host = gotHost
|
||||
|
|
@ -3696,12 +3696,30 @@ func TestPolicyWebhooks(t *testing.T) {
|
|||
|
||||
checkPolicyResults(queries)
|
||||
|
||||
// Track FlippingPoliciesForHost calls to verify the deduplication optimization: it should be called exactly once
|
||||
// per SubmitDistributedQueryResults with ALL policy results, not multiple times with subsets.
|
||||
var flippingCallCount int
|
||||
var flippingIncomingResults map[uint]*bool
|
||||
ds.FlippingPoliciesForHostFunc = func(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing []uint, newPassing []uint,
|
||||
err error,
|
||||
) {
|
||||
flippingCallCount++
|
||||
flippingIncomingResults = incomingResults
|
||||
return []uint{3}, nil, nil
|
||||
}
|
||||
|
||||
// Track that pre-computed newlyPassingPolicyIDs is forwarded to RecordPolicyQueryExecutions.
|
||||
var recordedNewlyPassing []uint
|
||||
ds.RecordPolicyQueryExecutionsFunc = func(ctx context.Context, gotHost *fleet.Host, results map[uint]*bool, updated time.Time,
|
||||
deferred bool, newlyPassingPolicyIDs []uint,
|
||||
) error {
|
||||
recordedResults = results
|
||||
recordedNewlyPassing = newlyPassingPolicyIDs
|
||||
host = gotHost
|
||||
return nil
|
||||
}
|
||||
|
||||
flippingCallCount = 0
|
||||
// Record a query execution.
|
||||
err = svc.SubmitDistributedQueryResults(
|
||||
ctx,
|
||||
|
|
@ -3726,6 +3744,14 @@ func TestPolicyWebhooks(t *testing.T) {
|
|||
require.NotNil(t, recordedResults[3])
|
||||
require.False(t, *recordedResults[3])
|
||||
|
||||
// Verify FlippingPoliciesForHost was called exactly once with all 3 policy results.
|
||||
require.Equal(t, 1, flippingCallCount, "FlippingPoliciesForHost should be called exactly once per SubmitDistributedQueryResults")
|
||||
require.Len(t, flippingIncomingResults, 3, "FlippingPoliciesForHost should receive all policy results")
|
||||
// Verify pre-computed newlyPassingPolicyIDs was forwarded (FlippingPoliciesForHost returned nil for newPassing,
|
||||
// but the caller normalizes it to a non-nil empty slice).
|
||||
require.NotNil(t, recordedNewlyPassing, "pre-computed newlyPassingPolicyIDs should be forwarded to RecordPolicyQueryExecutions")
|
||||
require.Empty(t, recordedNewlyPassing)
|
||||
|
||||
cmpSets := func(expSets map[uint][]fleet.PolicySetHost) error {
|
||||
actualSets, err := failingPolicySet.ListSets()
|
||||
if err != nil {
|
||||
|
|
@ -3805,9 +3831,12 @@ func TestPolicyWebhooks(t *testing.T) {
|
|||
ds.FlippingPoliciesForHostFunc = func(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing []uint, newPassing []uint,
|
||||
err error,
|
||||
) {
|
||||
flippingCallCount++
|
||||
flippingIncomingResults = incomingResults
|
||||
return []uint{1}, []uint{3}, nil
|
||||
}
|
||||
|
||||
flippingCallCount = 0
|
||||
// Record another query execution.
|
||||
err = svc.SubmitDistributedQueryResults(
|
||||
ctx,
|
||||
|
|
@ -3832,6 +3861,12 @@ func TestPolicyWebhooks(t *testing.T) {
|
|||
require.NotNil(t, recordedResults[3])
|
||||
require.True(t, *recordedResults[3])
|
||||
|
||||
// Verify single call and correct forwarding when there are actual flips.
|
||||
require.Equal(t, 1, flippingCallCount, "FlippingPoliciesForHost should be called exactly once")
|
||||
require.Len(t, flippingIncomingResults, 3)
|
||||
require.NotNil(t, recordedNewlyPassing)
|
||||
require.ElementsMatch(t, []uint{3}, recordedNewlyPassing, "newlyPassingPolicyIDs should be forwarded from FlippingPoliciesForHost")
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
err = cmpSets(map[uint][]fleet.PolicySetHost{
|
||||
1: {{
|
||||
|
|
|
|||
Loading…
Reference in a new issue