diff --git a/server/datastore/mysql/hosts.go b/server/datastore/mysql/hosts.go index 849861b932..9050d8928d 100644 --- a/server/datastore/mysql/hosts.go +++ b/server/datastore/mysql/hosts.go @@ -27,6 +27,7 @@ import ( // Since many hosts may have issues, we need to batch the inserts of host issues. // This is a variable, so it can be adjusted during unit testing. var hostIssuesInsertBatchSize = 10000 +var hostIssuesUpdateFailingPoliciesBatchSize = 10000 // A large number of hosts could be changing teams at once, so we need to batch this operation to prevent excessive locks var addHostsToTeamBatchSize = 10000 @@ -5269,37 +5270,71 @@ func updateHostIssuesFailingPolicies(ctx context.Context, tx sqlx.ExecerContext, return nil } - masterStmt := ` + // For 1 host, we use a single statement to update the host_issues entry. + if len(hostIDs) == 1 { + stmt := ` + INSERT INTO host_issues (host_id, failing_policies_count, total_issues_count) + SELECT host_id.id, COALESCE(SUM(!pm.passes), 0), COALESCE(SUM(!pm.passes), 0) + FROM policy_membership pm + RIGHT JOIN (SELECT ? as id) as host_id + ON pm.host_id = host_id.id + GROUP BY host_id.id + ON DUPLICATE KEY UPDATE + failing_policies_count = VALUES(failing_policies_count), + total_issues_count = VALUES(failing_policies_count) + critical_vulnerabilities_count` + if _, err := tx.ExecContext(ctx, stmt, hostIDs[0]); err != nil { + return ctxerr.Wrap(ctx, err, "updating failing policies in host issues for one host") + } + return nil + } + + // Clear host_issues entries for hosts that are not in policy_membership + clearStmt := ` + UPDATE host_issues + SET failing_policies_count = 0, total_issues_count = critical_vulnerabilities_count + WHERE host_id NOT IN ( + SELECT host_id + FROM policy_membership + WHERE host_id IN (?) + ) AND host_id IN (?)` + + // Insert/update host_issues entries for hosts that are in policy_membership. + // Initially, these two statements were combined into one statement using `SELECT ? 
AS id UNION ALL` approach to include the host IDs that + // were not in policy_membership (similar to how the above query for 1 host works). However, in load testing we saw an error: Thread stack overrun: 242191 bytes used of a 262144 byte stack + insertStmt := ` INSERT INTO host_issues (host_id, failing_policies_count, total_issues_count) - SELECT host_ids.id, COALESCE(SUM(!pm.passes), 0), COALESCE(SUM(!pm.passes), 0) + SELECT pm.host_id, COALESCE(SUM(!pm.passes), 0), COALESCE(SUM(!pm.passes), 0) FROM policy_membership pm - RIGHT JOIN (%s) as host_ids - ON pm.host_id = host_ids.id - GROUP BY host_ids.id + WHERE pm.host_id IN (?) + GROUP BY pm.host_id ON DUPLICATE KEY UPDATE failing_policies_count = VALUES(failing_policies_count), total_issues_count = VALUES(failing_policies_count) + critical_vulnerabilities_count` // Large number of hosts could be impacted, so we update their host issues entries in batches to reduce lock time. - for i := 0; i < len(hostIDs); i += hostIssuesInsertBatchSize { + for i := 0; i < len(hostIDs); i += hostIssuesUpdateFailingPoliciesBatchSize { start := i - end := i + hostIssuesInsertBatchSize + end := i + hostIssuesUpdateFailingPoliciesBatchSize if end > len(hostIDs) { end = len(hostIDs) } - totalToProcess := end - start hostIDsBatch := hostIDs[start:end] - inlineTable := strings.TrimSuffix( - strings.Repeat("SELECT ? 
AS id UNION ALL ", totalToProcess), " UNION ALL ", - ) - - args := make([]interface{}, totalToProcess) - for j := range hostIDsBatch { - args[j] = hostIDsBatch[j] + // Zero out failing policies count for hosts that are not in policy_membership + stmt, args, err := sqlx.In(clearStmt, hostIDsBatch, hostIDsBatch) + if err != nil { + return ctxerr.Wrap(ctx, err, "building IN statement for clearing host failing policy issues") } - if _, err := tx.ExecContext(ctx, fmt.Sprintf(masterStmt, inlineTable), args...); err != nil { - return ctxerr.Wrap(ctx, err, "update failing policies in host issues") + if _, err := tx.ExecContext(ctx, stmt, args...); err != nil { + return ctxerr.Wrap(ctx, err, "clearing failing policies in host issues") + } + // Update failing policies count for hosts that are in policy_membership + stmt, args, err = sqlx.In(insertStmt, hostIDsBatch) + if err != nil { + return ctxerr.Wrap(ctx, err, "building IN statement for updating host failing policy issues") + } + if _, err := tx.ExecContext(ctx, stmt, args...); err != nil { + return ctxerr.Wrap(ctx, err, "updating failing policies in host issues") } } return nil diff --git a/server/datastore/mysql/hosts_test.go b/server/datastore/mysql/hosts_test.go index d9881e17f1..428f6ec39b 100644 --- a/server/datastore/mysql/hosts_test.go +++ b/server/datastore/mysql/hosts_test.go @@ -9381,12 +9381,15 @@ func testUpdateHostIssues(t *testing.T, ds *Datastore) { // Test with small batch size and premium license ctx = license.NewContext(ctx, &fleet.LicenseInfo{Tier: fleet.TierPremium}) insertBatchSizeOrig := hostIssuesInsertBatchSize + updateBatchSizeOrig := hostIssuesUpdateFailingPoliciesBatchSize t.Cleanup( func() { hostIssuesInsertBatchSize = insertBatchSizeOrig + hostIssuesUpdateFailingPoliciesBatchSize = updateBatchSizeOrig }, ) hostIssuesInsertBatchSize = 2 + hostIssuesUpdateFailingPoliciesBatchSize = 2 assert.NoError(t, ds.UpdateHostIssuesFailingPolicies(ctx, hostIDs)) assert.NoError(t, 
ds.UpdateHostIssuesVulnerabilities(ctx))