diff --git a/server/datastore/mysql/policies.go b/server/datastore/mysql/policies.go index e6618a92d7..f2f1985a69 100644 --- a/server/datastore/mysql/policies.go +++ b/server/datastore/mysql/policies.go @@ -133,7 +133,7 @@ func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemo } if shouldRemoveAllPolicyMemberships { - return cleanupPolicyMembership(ctx, ds.writer(ctx), p.ID) + return ds.cleanupPolicyMembershipForPolicy(ctx, p.ID) } return cleanupPolicyMembershipOnPolicyUpdate(ctx, ds.writer(ctx), p.ID, p.Platform) } @@ -759,21 +759,37 @@ func cleanupPolicyMembershipOnPolicyUpdate(ctx context.Context, db sqlx.ExecerCo // cleanupPolicyMembership is similar to cleanupPolicyMembershipOnPolicyUpdate but without the platform constraints. // Used when we want to remove all policy membership. -func cleanupPolicyMembership(ctx context.Context, db sqlx.ExecerContext, policyID uint) error { +func (ds *Datastore) cleanupPolicyMembershipForPolicy(ctx context.Context, policyID uint) error { + // delete all policy memberships for the policy delStmt := ` - DELETE - pm - FROM - policy_membership pm - LEFT JOIN - hosts h - ON - pm.host_id = h.id - WHERE - pm.policy_id = ?` + DELETE + pm + FROM + policy_membership pm + LEFT JOIN + hosts h + ON + pm.host_id = h.id + WHERE + pm.policy_id = ? + ` - _, err := db.ExecContext(ctx, delStmt, policyID) - return ctxerr.Wrap(ctx, err, "cleanup policy membership") + _, err := ds.writer(ctx).ExecContext(ctx, delStmt, policyID) + if err != nil { + return ctxerr.Wrap(ctx, err, "cleanup policy membership") + } + + // delete all policy stats for the policy + // wrapping in a retry to avoid deadlocks with the cleanups_then_aggregation cron job + err = ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { + _, err := tx.ExecContext(ctx, `DELETE FROM policy_stats WHERE policy_id = ?`, policyID) + return err + }) + if err != nil { + return ctxerr.Wrap(ctx, err, "cleanup policy stats") + } + + return nil } // CleanupPolicyMembership deletes the host's membership from policies that diff --git a/server/datastore/mysql/policies_test.go b/server/datastore/mysql/policies_test.go index 3529f18a6e..101eb6f925 100644 --- a/server/datastore/mysql/policies_test.go +++ b/server/datastore/mysql/policies_test.go @@ -54,6 +54,7 @@ func TestPolicies(t *testing.T) { {"TestListTeamPoliciesCanPaginate", testListTeamPoliciesCanPaginate}, {"TestCountPolicies", testCountPolicies}, {"TestUpdatePolicyHostCounts", testUpdatePolicyHostCounts}, + {"TestCachedPolicyCountDeletesOnPolicyChange", testCachedPolicyCountDeletesOnPolicyChange}, {"TestPoliciesListOptions", testPoliciesListOptions}, } for _, c := range cases { @@ -1463,6 +1464,99 @@ func testPoliciesSave(t *testing.T, ds *Datastore) { require.Len(t, rows, 0) } +func testCachedPolicyCountDeletesOnPolicyChange(t *testing.T, ds *Datastore) { + ctx := context.Background() + user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true) + team1, err := ds.NewTeam(ctx, &fleet.Team{Name: t.Name() + "team1"}) + require.NoError(t, err) + + teamHost, err := ds.NewHost(ctx, &fleet.Host{ + OsqueryHostID: ptr.String("test-1"), + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now(), + NodeKey: ptr.String("test-1"), + UUID: "test-1", + Hostname: "foo.local", + Platform: "windows", + }) + require.NoError(t, err) + require.NoError(t, ds.AddHostsToTeam(ctx, &team1.ID, []uint{teamHost.ID})) + + globalHost, err := ds.NewHost(ctx, &fleet.Host{ + OsqueryHostID: ptr.String("test-2"), + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now(), + NodeKey: ptr.String("test-2"), + UUID: "test-2", + Hostname: "foo.local", + Platform: "windows", + }) + require.NoError(t, err) + + globalPolicy, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{ + Name: "global query", + Query: "select 1;", + Description: "global query desc", + Resolution: "global query resolution", + }) + require.NoError(t, err) + + teamPolicy, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{ + Name: "team query", + Query: "select 1;", + Description: "team query desc", + Resolution: "team query resolution", + }) + require.NoError(t, err) + + // teamHost and globalHost pass all policies + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, teamHost, map[uint]*bool{globalPolicy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, teamHost, map[uint]*bool{teamPolicy.ID: ptr.Bool(true), teamPolicy.ID: ptr.Bool(true)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, globalHost, map[uint]*bool{globalPolicy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false)) + + err = ds.UpdateHostPolicyCounts(ctx) + require.NoError(t, err) + + globalPolicy, err = ds.Policy(ctx, globalPolicy.ID) + require.NoError(t, err) + assert.Equal(t, uint(2), globalPolicy.PassingHostCount) + teamPolicies, inheritedPolicies, err := ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}) + require.NoError(t, err) + require.Len(t, teamPolicies, 1) + require.Len(t, inheritedPolicies, 1) + assert.Equal(t, uint(1), teamPolicies[0].PassingHostCount) + assert.Equal(t, uint(1), inheritedPolicies[0].PassingHostCount) + + // Update the global policy sql to trigger a cache invalidation + err = ds.SavePolicy(ctx, globalPolicy, true) + require.NoError(t, err) + + globalPolicy, err = ds.Policy(ctx, globalPolicy.ID) + require.NoError(t, err) + assert.Equal(t, uint(0), globalPolicy.PassingHostCount) + teamPolicies, inheritedPolicies, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}) + require.NoError(t, err) + require.Len(t, teamPolicies, 1) + require.Len(t, inheritedPolicies, 1) + assert.Equal(t, uint(1), teamPolicies[0].PassingHostCount) + assert.Equal(t, uint(0), inheritedPolicies[0].PassingHostCount) + + // Update the team policy sql to trigger a cache invalidation + err = ds.SavePolicy(ctx, teamPolicy, true) + require.NoError(t, err) + + teamPolicies, inheritedPolicies, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}) + require.NoError(t, err) + require.Len(t, teamPolicies, 1) + require.Len(t, inheritedPolicies, 1) + assert.Equal(t, uint(0), teamPolicies[0].PassingHostCount) + assert.Equal(t, uint(0), inheritedPolicies[0].PassingHostCount) +} + func testPoliciesDelUser(t *testing.T, ds *Datastore) { user1 := test.NewUser(t, ds, "User1", "user1@example.com", true) user2 := test.NewUser(t, ds, "User2", "user2@example.com", true)