From 9c714c544dc583321c9bb02e0349de5a1287d978 Mon Sep 17 00:00:00 2001 From: Victor Lyuboslavsky Date: Mon, 6 May 2024 09:48:37 -0500 Subject: [PATCH] Optimized policy_stats updates to NOT lock the policy_membership table (#18720) #16562 Optimized policy_stats updates to NOT lock the policy_membership table. This should improve performance on deployments with many global policies and team hosts. The original implementation that used INSERT ... SELECT (SELECT COUNT(*)) ... caused performance issues. Given 50 global policies, 10 teams, and 10,000 hosts per team, the INSERT query took 30-60 seconds to complete. Since it was an INSERT query, it blocked other hosts from updating their policy results in policy_membership. Now, we separate the INSERT from the SELECT, since SELECT by itself does not block other hosts from updating their policy results. In addition, we process one global policy at a time, which reduces the time to complete the SELECT query to <2 seconds, and limits the memory usage. To reduce locking, we do not wrap the SELECT and INSERT in a transaction. This means that the INSERT may fail if the policy was deleted by a parallel process. Also, the INSERT may overwrite stats that a parallel process has just cleared. This is acceptable, since these are very rare cases. If the INSERT fails, we log the error and proceed. # Checklist for submitter If some of the following don't apply, delete the relevant line. - [x] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`. See [Changes files](https://fleetdm.com/docs/contributing/committing-changes#changes-files) for more information.
- [x] Added/updated tests - [x] Manual QA for all new/changed functionality --- changes/16562-policy_stats-lock | 1 + server/datastore/mysql/policies.go | 111 ++++++++++++----- server/datastore/mysql/policies_test.go | 151 ++++++++++++++++++++---- 3 files changed, 212 insertions(+), 51 deletions(-) create mode 100644 changes/16562-policy_stats-lock diff --git a/changes/16562-policy_stats-lock b/changes/16562-policy_stats-lock new file mode 100644 index 0000000000..6c9b551037 --- /dev/null +++ b/changes/16562-policy_stats-lock @@ -0,0 +1 @@ +Optimized policy_stats updates to NOT lock the policy_membership table. This should improve performance on deployments with a large number of global policies and team hosts. diff --git a/server/datastore/mysql/policies.go b/server/datastore/mysql/policies.go index 0f07441110..3701c84b48 100644 --- a/server/datastore/mysql/policies.go +++ b/server/datastore/mysql/policies.go @@ -1245,40 +1245,95 @@ func (ds *Datastore) UpdateHostPolicyCounts(ctx context.Context) error { // NOTE these queries are duplicated in the below migration. Updates // to these queries should be reflected there as well. // https://github.com/fleetdm/fleet/blob/main/server/datastore/mysql/migrations/tables/20231215122713_InsertPolicyStatsData.go#L12 + // This implementation should be functionally equivalent to the migration. 
// Update Counts for Inherited Global Policies for each Team - _, err := ds.writer(ctx).ExecContext(ctx, ` - INSERT INTO policy_stats (policy_id, inherited_team_id, passing_host_count, failing_host_count) - SELECT - p.id, - t.id AS inherited_team_id, - ( - SELECT COUNT(*) - FROM policy_membership pm - INNER JOIN hosts h ON pm.host_id = h.id - WHERE pm.policy_id = p.id AND pm.passes = true AND h.team_id = t.id - ) AS passing_host_count, - ( - SELECT COUNT(*) - FROM policy_membership pm - INNER JOIN hosts h ON pm.host_id = h.id - WHERE pm.policy_id = p.id AND pm.passes = false AND h.team_id = t.id - ) AS failing_host_count - FROM policies p - CROSS JOIN teams t - WHERE p.team_id IS NULL - GROUP BY p.id, t.id - ON DUPLICATE KEY UPDATE - updated_at = NOW(), - passing_host_count = VALUES(passing_host_count), - failing_host_count = VALUES(failing_host_count); - `) + // The original implementation that used INSERT ... SELECT (SELECT COUNT(*)) ... caused performance issues. + // Given 50 global policies, 10 teams, and 10,000 hosts per team, the INSERT query took 30-60 seconds to complete. + // Since it was an INSERT query, it blocked other hosts from updating their policy results in policy_membership. + + // Now, we separate the INSERT from the SELECT, since SELECT by itself does not block other hosts from updating their policy results. + // In addition, we process one global policy at a time, which reduces the time to complete the SELECT query to <2 seconds, and limits the memory usage. + // We are not using a transaction to reduce locks. This means that INSERT may fail if the policy was deleted by a parallel process. + // Also, the INSERT may overwrite a clearing of the stats. This is acceptable, since these are very rare cases. We log and proceed in that case. 
+ + db := ds.writer(ctx) + + // Inherited policies are only relevant for teams, so we check whether we have teams + var hasTeams bool + err := sqlx.GetContext(ctx, db, &hasTeams, `SELECT 1 FROM teams`) if err != nil { - return ctxerr.Wrap(ctx, err, "update host policy counts for inherited global policies") + if errors.Is(err, sql.ErrNoRows) { + // No teams, so no inherited policies + hasTeams = false + } else { + return ctxerr.Wrap(ctx, err, "count teams") + } + } + + if hasTeams { + globalPolicies, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{}) + if err != nil { + return ctxerr.Wrap(ctx, err, "list global policies") + } + type policyStat struct { + PolicyID uint `db:"policy_id"` + InheritedTeamID uint `db:"inherited_team_id"` + PassingHostCount uint `db:"passing_host_count"` + FailingHostCount uint `db:"failing_host_count"` + } + var policyStats []policyStat + for _, policy := range globalPolicies { + selectStmt := `SELECT + p.id as policy_id, + t.id AS inherited_team_id, + ( + SELECT COUNT(*) + FROM policy_membership pm + INNER JOIN hosts h ON pm.host_id = h.id + WHERE pm.policy_id = p.id AND pm.passes = true AND h.team_id = t.id + ) AS passing_host_count, + ( + SELECT COUNT(*) + FROM policy_membership pm + INNER JOIN hosts h ON pm.host_id = h.id + WHERE pm.policy_id = p.id AND pm.passes = false AND h.team_id = t.id + ) AS failing_host_count + FROM policies p + CROSS JOIN teams t + WHERE p.team_id IS NULL AND p.id = ? + GROUP BY t.id, p.id` + err = sqlx.SelectContext(ctx, db, &policyStats, selectStmt, policy.ID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, sql.ErrNoRows) { + // Policy or team was deleted by a parallel process. We proceed. + level.Error(ds.logger).Log( + "msg", "policy not found for inherited global policies. 
Was policy or team(s) deleted?", "policy_id", policy.ID, + ) + continue + } + return ctxerr.Wrap(ctx, err, "select policy counts for inherited global policies") + } + insertStmt := `INSERT INTO policy_stats (policy_id, inherited_team_id, passing_host_count, failing_host_count) + VALUES (:policy_id, :inherited_team_id, :passing_host_count, :failing_host_count) + ON DUPLICATE KEY UPDATE + updated_at = NOW(), + passing_host_count = VALUES(passing_host_count), + failing_host_count = VALUES(failing_host_count)` + _, err = sqlx.NamedExecContext(ctx, db, insertStmt, policyStats) + if err != nil { + // INSERT may fail due to rare race conditions. We log and proceed. + level.Error(ds.logger).Log( + "msg", "insert policy stats for inherited global policies. Was policy deleted?", "policy_id", policy.ID, "err", err, + ) + } + } } // Update Counts for Global and Team Policies - _, err = ds.writer(ctx).ExecContext(ctx, ` + // The performance of this query is linear with the number of policies. + _, err = db.ExecContext( + ctx, ` INSERT INTO policy_stats (policy_id, inherited_team_id, passing_host_count, failing_host_count) SELECT p.id, diff --git a/server/datastore/mysql/policies_test.go b/server/datastore/mysql/policies_test.go index 90d9015489..1e6980736e 100644 --- a/server/datastore/mysql/policies_test.go +++ b/server/datastore/mysql/policies_test.go @@ -2933,34 +2933,18 @@ func testUpdatePolicyHostCounts(t *testing.T, ds *Datastore) { policy, err := ds.NewGlobalPolicy(context.Background(), nil, fleet.PolicyPayload{Name: "global policy 1"}) require.NoError(t, err) - team, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team1"}) - require.NoError(t, err) - - // create 4 team hosts - var teamHosts []*fleet.Host - for i := 0; i < 4; i++ { - h, err := ds.NewHost(context.Background(), &fleet.Host{OsqueryHostID: ptr.String(fmt.Sprintf("host%d", i)), NodeKey: ptr.String(fmt.Sprintf("host%d", i)), TeamID: &team.ID}) - require.NoError(t, err) - teamHosts = 
append(teamHosts, h) - } - // create 4 global hosts var globalHosts []*fleet.Host - for i := 4; i < 8; i++ { - h, err := ds.NewHost(context.Background(), &fleet.Host{OsqueryHostID: ptr.String(fmt.Sprintf("host%d", i)), NodeKey: ptr.String(fmt.Sprintf("host%d", i)), TeamID: nil}) + for i := 100; i < 104; i++ { + h, err := ds.NewHost( + context.Background(), + &fleet.Host{OsqueryHostID: ptr.String(fmt.Sprintf("host%d", i)), NodeKey: ptr.String(fmt.Sprintf("host%d", i)), TeamID: nil}, + ) require.NoError(t, err) globalHosts = append(globalHosts, h) } - // add policy responses - for _, h := range teamHosts { - res := map[uint]*bool{ - policy.ID: ptr.Bool(true), - } - err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false) - require.NoError(t, err) - } - + // add policy responses to global hosts for _, h := range globalHosts { res := map[uint]*bool{ policy.ID: ptr.Bool(true), @@ -2986,7 +2970,7 @@ func testUpdatePolicyHostCounts(t *testing.T, ds *Datastore) { policy, err = ds.Policy(context.Background(), policy.ID) require.NoError(t, err) require.Equal(t, uint(0), policy.FailingHostCount) - require.Equal(t, uint(8), policy.PassingHostCount) + require.Equal(t, uint(4), policy.PassingHostCount) require.NotNil(t, policy.HostCountUpdatedAt) assert.True( t, policy.HostCountUpdatedAt.Compare(now) >= 0, fmt.Sprintf("reference:%v HostCountUpdatedAt:%v", now, *policy.HostCountUpdatedAt), @@ -2994,6 +2978,127 @@ func testUpdatePolicyHostCounts(t *testing.T, ds *Datastore) { assert.True( t, policy.HostCountUpdatedAt.Compare(later) < 0, fmt.Sprintf("later:%v HostCountUpdatedAt:%v", later, *policy.HostCountUpdatedAt), ) + + team, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team1"}) + require.NoError(t, err) + + // create 4 team hosts + var teamHosts []*fleet.Host + for i := 0; i < 4; i++ { + h, err := ds.NewHost(context.Background(), &fleet.Host{OsqueryHostID: ptr.String(fmt.Sprintf("host%d", i)), NodeKey: ptr.String(fmt.Sprintf("host%d", 
i)), TeamID: &team.ID}) + require.NoError(t, err) + teamHosts = append(teamHosts, h) + } + + // add policy responses to team hosts + for _, h := range teamHosts { + var result *bool + switch h.ID % 5 { + case 0, 1: // 2 fails + result = ptr.Bool(false) + case 2: // 1 pass + result = ptr.Bool(true) + default: + // remain null + } + + res := map[uint]*bool{ + policy.ID: result, + } + err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false) + require.NoError(t, err) + } + + // update policy host counts + now = time.Now().Truncate(time.Second) + later = now.Add(10 * time.Second) + err = ds.UpdateHostPolicyCounts(context.Background()) + require.NoError(t, err) + + // check policy host counts + policy, err = ds.Policy(context.Background(), policy.ID) + require.NoError(t, err) + require.Equal(t, uint(2), policy.FailingHostCount) + require.Equal(t, uint(5), policy.PassingHostCount) + require.NotNil(t, policy.HostCountUpdatedAt) + assert.True( + t, policy.HostCountUpdatedAt.Compare(now) >= 0, fmt.Sprintf("reference:%v HostCountUpdatedAt:%v", now, *policy.HostCountUpdatedAt), + ) + assert.True( + t, policy.HostCountUpdatedAt.Compare(later) < 0, fmt.Sprintf("later:%v HostCountUpdatedAt:%v", later, *policy.HostCountUpdatedAt), + ) + + // new global policy + policy2, err := ds.NewGlobalPolicy(context.Background(), nil, fleet.PolicyPayload{Name: "global policy 2"}) + require.NoError(t, err) + + // new team + team2, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team2"}) + require.NoError(t, err) + + // create 4 team2 hosts + for i := 4; i < 8; i++ { + h, err := ds.NewHost( + context.Background(), &fleet.Host{ + OsqueryHostID: ptr.String(fmt.Sprintf("host%d", i)), NodeKey: ptr.String(fmt.Sprintf("host%d", i)), TeamID: &team2.ID, + }, + ) + require.NoError(t, err) + teamHosts = append(teamHosts, h) + } + + // Update policy results for all hosts. 
+ // All fail policy 1, all pass policy 2 + for _, h := range globalHosts { + res := map[uint]*bool{ + policy.ID: ptr.Bool(false), + policy2.ID: ptr.Bool(true), + } + err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false) + require.NoError(t, err) + } + for _, h := range teamHosts { + res := map[uint]*bool{ + policy.ID: ptr.Bool(false), + policy2.ID: ptr.Bool(true), + } + err = ds.RecordPolicyQueryExecutions(context.Background(), h, res, time.Now(), false) + require.NoError(t, err) + } + + // update policy host counts + now = time.Now().Truncate(time.Second) + later = now.Add(10 * time.Second) + err = ds.UpdateHostPolicyCounts(context.Background()) + require.NoError(t, err) + + // check policy 1 host counts + policy, err = ds.Policy(context.Background(), policy.ID) + require.NoError(t, err) + require.Equal(t, uint(12), policy.FailingHostCount) + require.Equal(t, uint(0), policy.PassingHostCount) + require.NotNil(t, policy.HostCountUpdatedAt) + assert.True( + t, policy.HostCountUpdatedAt.Compare(now) >= 0, fmt.Sprintf("reference:%v HostCountUpdatedAt:%v", now, *policy.HostCountUpdatedAt), + ) + assert.True( + t, policy.HostCountUpdatedAt.Compare(later) < 0, fmt.Sprintf("later:%v HostCountUpdatedAt:%v", later, *policy.HostCountUpdatedAt), + ) + + // check policy 2 host counts + policy2, err = ds.Policy(context.Background(), policy2.ID) + require.NoError(t, err) + require.Equal(t, uint(0), policy2.FailingHostCount) + require.Equal(t, uint(12), policy2.PassingHostCount) + require.NotNil(t, policy2.HostCountUpdatedAt) + assert.True( + t, policy2.HostCountUpdatedAt.Compare(now) >= 0, + fmt.Sprintf("reference:%v HostCountUpdatedAt:%v", now, *policy2.HostCountUpdatedAt), + ) + assert.True( + t, policy2.HostCountUpdatedAt.Compare(later) < 0, fmt.Sprintf("later:%v HostCountUpdatedAt:%v", later, *policy2.HostCountUpdatedAt), + ) + } func testPoliciesNameUnicode(t *testing.T, ds *Datastore) {