Improve contention around policy_membership table (35484) (#40853)

Resolves #35484

Concurrent execution of GitOps apply runs and
RecordPolicyQueryExecutions led to database locking issues when the
policy_membership table was large. This occurred because the cleanup
process (DELETE operations) was bundled within the same transaction as
the GitOps policy updates. To resolve this, the deletion logic has been
batched and moved outside the primary GitOps transaction, reducing lock
contention.
This commit is contained in:
Juan Fernandez 2026-03-16 15:12:25 -04:00 committed by GitHub
parent 700370a298
commit 139e365d42
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 704 additions and 168 deletions

View file

@ -0,0 +1 @@
* Fixed database locking issues on the policy_membership table by batching cleanup DELETE operations and moving them outside the primary GitOps apply transaction.

View file

@ -0,0 +1,20 @@
package tables
import "database/sql"
// init registers this migration's up and down functions with MigrationClient.
func init() {
MigrationClient.AddMigration(Up_20260316000001, Down_20260316000001)
}
// Up_20260316000001 adds the needs_full_membership_cleanup column to the
// policies table. The column is a NOT NULL TINYINT(1) that defaults to 0, and
// the ALTER statement requests ALGORITHM=INSTANT. Any error returned by the
// statement is passed back to the caller unchanged.
func Up_20260316000001(tx *sql.Tx) error {
	const addColumnStmt = `
ALTER TABLE policies
ADD COLUMN needs_full_membership_cleanup TINYINT(1) NOT NULL DEFAULT 0,
ALGORITHM=INSTANT
`
	if _, err := tx.Exec(addColumnStmt); err != nil {
		return err
	}
	return nil
}
// Down_20260316000001 is a no-op: it does not use tx and always returns nil,
// so the column added by Up_20260316000001 is left in place.
func Down_20260316000001(tx *sql.Tx) error {
return nil
}

View file

@ -61,6 +61,21 @@ var (
var policySearchColumns = []string{"p.name"}
// policyMembershipDeleteBatchSize is the maximum number of host IDs selected and
// deleted per batch when cleaning up policy_membership rows (it is bound to the
// LIMIT of each batch SELECT). Exposed as a var (not a const) so that
// integration tests can use a smaller value to exercise multi-batch code paths.
// Do not mutate concurrently (no synchronization); tests that override this must
// not use t.Parallel().
var policyMembershipDeleteBatchSize = 500
// policyCleanupArgs holds the arguments needed to run cleanupPolicy for one
// policy. ApplyPolicySpecs collects one value per applied policy inside its
// transaction and runs the cleanups after the transaction has committed.
type policyCleanupArgs struct {
// policyID is the ID of the policy whose memberships are cleaned up.
policyID uint
// platform is the platform string from the policy spec.
platform string
// shouldRemoveAllPolicyMemberships makes cleanupPolicy call
// cleanupPolicyMembershipForPolicy, which removes every membership row of
// the policy.
shouldRemoveAllPolicyMemberships bool
// removePolicyStats makes cleanupPolicy also delete the policy's stats.
removePolicyStats bool
}
func (ds *Datastore) NewGlobalPolicy(ctx context.Context, authorID *uint, args fleet.PolicyPayload) (*fleet.Policy, error) {
var newPolicy *fleet.Policy
@ -479,6 +494,7 @@ func cleanupPolicy(
removePolicyStats bool, logger *slog.Logger,
) error {
var err error
if shouldRemoveAllPolicyMemberships {
err = cleanupPolicyMembershipForPolicy(ctx, queryerContext, extContext, policyID)
} else {
@ -487,6 +503,7 @@ func cleanupPolicy(
if err != nil {
return err
}
if removePolicyStats {
// delete all policy stats for the policy
fn := func(tx sqlx.ExtContext) error {
@ -627,15 +644,14 @@ func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *flee
// semantically equivalent, even though here it processes a single host and
// in async mode it processes a batch of hosts).
err = ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
err = ds.withTx(ctx, func(tx sqlx.ExtContext) error {
if len(results) > 0 {
query := fmt.Sprintf(
`INSERT INTO policy_membership (updated_at, policy_id, host_id, passes)
VALUES %s ON DUPLICATE KEY UPDATE updated_at=VALUES(updated_at), passes=VALUES(passes)`,
VALUES %s ON DUPLICATE KEY UPDATE updated_at=VALUES(updated_at), passes=VALUES(passes)`,
strings.Join(bindvars, ","),
)
_, err := tx.ExecContext(ctx, query, vals...)
if err != nil {
if _, err := tx.ExecContext(ctx, query, vals...); err != nil {
return ctxerr.Wrapf(ctx, err, "insert policy_membership (%v)", vals)
}
@ -658,14 +674,11 @@ func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *flee
}
}
}
// if we are deferring host updates, we return at this point and do the change outside of the tx
if deferredSaveHost {
return nil
}
if _, err := tx.ExecContext(ctx, `UPDATE hosts SET policy_updated_at = ? WHERE id=?`, updated, host.ID); err != nil {
return ctxerr.Wrap(ctx, err, "updating hosts policy updated at")
if !deferredSaveHost {
if _, err := tx.ExecContext(ctx, `UPDATE hosts SET policy_updated_at = ? WHERE id=?`, updated, host.ID); err != nil {
return ctxerr.Wrap(ctx, err, "updating hosts policy updated at")
}
}
return nil
})
@ -1334,14 +1347,15 @@ func (ds *Datastore) ApplyPolicySpecs(ctx context.Context, authorID uint, specs
}
}
// Get the query and platforms of the current policies so that we can check if query or platform changed later, if needed
// Get the query and platforms of the current policies so that we can check if the query or platform changed later, if needed
type policyLite struct {
Name string `db:"name"`
Query string `db:"query"`
Platforms string `db:"platforms"`
SoftwareInstallerID *uint `db:"software_installer_id"`
VPPAppsTeamsID *uint `db:"vpp_apps_teams_id"`
ScriptID *uint `db:"script_id"`
Name string `db:"name"`
Query string `db:"query"`
Platforms string `db:"platforms"`
SoftwareInstallerID *uint `db:"software_installer_id"`
VPPAppsTeamsID *uint `db:"vpp_apps_teams_id"`
ScriptID *uint `db:"script_id"`
NeedsFullMembershipCleanup bool `db:"needs_full_membership_cleanup"`
}
teamIDToPoliciesByName := make(map[*uint]map[string]policyLite, len(teamIDToPolicies))
for teamID, teamPolicySpecs := range teamIDToPolicies {
@ -1355,10 +1369,10 @@ func (ds *Datastore) ApplyPolicySpecs(ctx context.Context, authorID uint, specs
var args []interface{}
var err error
if teamID == nil {
query, args, err = sqlx.In("SELECT name, query, platforms, software_installer_id, vpp_apps_teams_id, script_id FROM policies WHERE team_id IS NULL AND name IN (?)", policyNames)
query, args, err = sqlx.In("SELECT name, query, platforms, software_installer_id, vpp_apps_teams_id, script_id, needs_full_membership_cleanup FROM policies WHERE team_id IS NULL AND name IN (?)", policyNames)
} else {
query, args, err = sqlx.In(
"SELECT name, query, platforms, software_installer_id, vpp_apps_teams_id, script_id FROM policies WHERE team_id = ? AND name IN (?)", *teamID, policyNames,
"SELECT name, query, platforms, software_installer_id, vpp_apps_teams_id, script_id, needs_full_membership_cleanup FROM policies WHERE team_id = ? AND name IN (?)", *teamID, policyNames,
)
}
if err != nil {
@ -1374,7 +1388,11 @@ func (ds *Datastore) ApplyPolicySpecs(ctx context.Context, authorID uint, specs
}
}
return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
var pendingCleanups []policyCleanupArgs
err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
// Reset on retry so we don't accumulate duplicate cleanup entries.
pendingCleanups = pendingCleanups[:0]
query := fmt.Sprintf(
`
INSERT INTO policies (
@ -1515,6 +1533,14 @@ func (ds *Datastore) ApplyPolicySpecs(ctx context.Context, authorID uint, specs
return ctxerr.Wrap(ctx, err, "select policies id")
}
}
policyID := uint(lastID) //nolint:gosec // dismiss G115
// If a previous cleanup was interrupted (e.g. server crash between
// transaction commit and cleanup completion), re-trigger cleanup on any GitOps retry,
// even if the policy itself didn't change this run.
if prev, ok := teamIDToPoliciesByName[teamID][spec.Name]; ok && prev.NeedsFullMembershipCleanup {
shouldRemoveAllPolicyMemberships = true
}
// Create LabelIdents to send to updatePolicyLabelsTx.
// Right now we only need the names.
@ -1527,9 +1553,10 @@ func (ds *Datastore) ApplyPolicySpecs(ctx context.Context, authorID uint, specs
for _, labelExclude := range spec.LabelsExcludeAny {
labelsExcludeAnyIdents = append(labelsExcludeAnyIdents, fleet.LabelIdent{LabelName: labelExclude})
}
err = updatePolicyLabelsTx(ctx, tx, &fleet.Policy{
PolicyData: fleet.PolicyData{
ID: uint(lastID), //nolint:gosec // dismiss G115
ID: policyID,
LabelsIncludeAny: labelsIncludeAnyIdents,
LabelsExcludeAny: labelsExcludeAnyIdents,
},
@ -1538,22 +1565,62 @@ func (ds *Datastore) ApplyPolicySpecs(ctx context.Context, authorID uint, specs
return ctxerr.Wrap(ctx, err, "exec policies update labels")
}
// Run cleanup after labels are updated so the cleanup function can
// query the current label criteria from the database.
// Always run cleanup since labels may have changed even if the main policy
// fields didn't (the cleanup function is safe to call and will only delete
// memberships that don't match current criteria).
if err = cleanupPolicy(
ctx, tx, tx, uint(lastID), spec.Platform, shouldRemoveAllPolicyMemberships, //nolint:gosec // dismiss G115
removePolicyStats, ds.logger,
); err != nil {
return err
// Mark this policy for cleanup, so that the policy cleanup cron job can pick it up
// in case we fail and don't retry
if shouldRemoveAllPolicyMemberships {
if _, err := tx.ExecContext(ctx,
`UPDATE policies SET needs_full_membership_cleanup = 1 WHERE id = ?`,
policyID); err != nil {
return ctxerr.Wrap(ctx, err, "setting needs_full_membership_cleanup flag")
}
}
// Defer cleanup outside the transaction to avoid long-held row locks on
// policy_membership.
pendingCleanups = append(pendingCleanups, policyCleanupArgs{
policyID: policyID,
platform: spec.Platform,
shouldRemoveAllPolicyMemberships: shouldRemoveAllPolicyMemberships,
removePolicyStats: removePolicyStats,
})
}
}
return nil
})
if err != nil {
// It's OK to return 'early' if err, since the policy cleanup
// cron job can pick up any remnants
return err
}
// Run cleanup after labels are updated so the cleanup function can
// query the current label criteria from the database.
// Always run cleanup since labels may have changed even if the main policy
// fields didn't (the cleanup function is safe to call and will only delete
// memberships that don't match current criteria).
dbCtx := ds.writer(ctx)
for _, args := range pendingCleanups {
if err := cleanupPolicy(
ctx,
dbCtx,
dbCtx,
args.policyID,
args.platform,
args.shouldRemoveAllPolicyMemberships,
args.removePolicyStats,
ds.logger,
); err != nil {
return err
}
if args.shouldRemoveAllPolicyMemberships {
if _, err := ds.writer(ctx).ExecContext(ctx,
`UPDATE policies SET needs_full_membership_cleanup = 0 WHERE id = ?`,
args.policyID); err != nil {
return ctxerr.Wrap(ctx, err, "clearing needs_full_membership_cleanup flag")
}
}
}
return nil
}
func amountPoliciesDB(ctx context.Context, db sqlx.QueryerContext) (int, error) {
@ -1739,109 +1806,111 @@ func cleanupConditionalAccessOnTeamChange(ctx context.Context, tx sqlx.ExtContex
func cleanupPolicyMembershipOnPolicyUpdate(
ctx context.Context, queryerContext sqlx.QueryerContext, db sqlx.ExecerContext, policyID uint, platforms string,
) error {
var allHostIDs []uint
// Clean up hosts that don't match the platform criteria
// Clean up hosts that don't match the platform criteria.
// Page through rows using the (policy_id, host_id) PK as a cursor so each SELECT+DELETE
// batch holds its locks for as short a time as possible.
if platforms != "" {
delStmt := `
DELETE
pm
FROM
policy_membership pm
LEFT JOIN
hosts h
ON
pm.host_id = h.id
WHERE
pm.policy_id = ? AND
( h.id IS NULL OR
FIND_IN_SET(h.platform, ?) = 0 )`
selectStmt := `
SELECT DISTINCT
h.id
FROM
policy_membership pm
INNER JOIN
hosts h
ON
pm.host_id = h.id
WHERE
pm.policy_id = ? AND
FIND_IN_SET(h.platform, ?) = 0`
var expandedPlatforms []string
for platform := range strings.SplitSeq(platforms, ",") {
expandedPlatforms = append(expandedPlatforms, fleet.ExpandPlatform(strings.TrimSpace(platform))...)
}
expandedPlatformsStr := strings.Join(expandedPlatforms, ",")
// Find the impacted host IDs, so we can update their host issues entries
var hostIDs []uint
err := sqlx.SelectContext(ctx, queryerContext, &hostIDs, selectStmt, policyID, strings.Join(expandedPlatforms, ","))
if err != nil {
return ctxerr.Wrap(ctx, err, "select hosts to cleanup policy membership for platform")
var afterHostID uint
for {
var batchHostIDs []uint
err := sqlx.SelectContext(ctx, queryerContext, &batchHostIDs, `
SELECT pm.host_id
FROM policy_membership pm
INNER JOIN hosts h ON pm.host_id = h.id
WHERE pm.policy_id = ? AND FIND_IN_SET(h.platform, ?) = 0
AND pm.host_id > ?
ORDER BY pm.host_id ASC
LIMIT ?`, policyID, expandedPlatformsStr, afterHostID, policyMembershipDeleteBatchSize)
if err != nil {
return ctxerr.Wrap(ctx, err, "select batch of hosts to cleanup policy membership for platform")
}
if len(batchHostIDs) == 0 {
break
}
batchStmt, args, err := sqlx.In(
`DELETE FROM policy_membership WHERE policy_id = ? AND host_id IN (?)`,
policyID, batchHostIDs,
)
if err != nil {
return ctxerr.Wrap(ctx, err, "building batch delete for platform policy membership")
}
if _, err = db.ExecContext(ctx, batchStmt, args...); err != nil {
return ctxerr.Wrap(ctx, err, "batch cleanup policy membership for platform")
}
if err := updateHostIssuesFailingPolicies(ctx, db, batchHostIDs); err != nil {
return err
}
afterHostID = batchHostIDs[len(batchHostIDs)-1]
}
allHostIDs = append(allHostIDs, hostIDs...)
_, err = db.ExecContext(ctx, delStmt, policyID, strings.Join(expandedPlatforms, ","))
if err != nil {
return ctxerr.Wrap(ctx, err, "cleanup policy membership for platform")
// Clean up orphaned memberships (host_id refs to deleted hosts, not covered by INNER JOIN above)
if _, err := db.ExecContext(ctx, `
DELETE pm FROM policy_membership pm
LEFT JOIN hosts h ON pm.host_id = h.id
WHERE pm.policy_id = ? AND h.id IS NULL`, policyID); err != nil {
return ctxerr.Wrap(ctx, err, "cleanup orphaned policy membership for platform")
}
}
labelQuery := `
FROM
policy_membership pm
WHERE
pm.policy_id = ?
AND NOT (
(
-- If the policy has no include labels, all hosts match this part.
NOT EXISTS (
SELECT 1 FROM policy_labels pl
WHERE pl.policy_id = pm.policy_id AND pl.exclude = 0
)
-- If the policy has include labels, the host must be in at least one of them.
OR EXISTS (
SELECT 1 FROM policy_labels pl
JOIN label_membership lm ON lm.label_id = pl.label_id AND lm.host_id = pm.host_id
WHERE pl.policy_id = pm.policy_id AND pl.exclude = 0
)
)
-- If the policy has exclude labels, the host must not be in any of them.
AND NOT EXISTS (
SELECT 1 FROM policy_labels pl
JOIN label_membership lm ON lm.label_id = pl.label_id AND lm.host_id = pm.host_id
WHERE pl.policy_id = pm.policy_id AND pl.exclude = 1
)
)`
// Clean up hosts that don't match the label criteria.
var afterLabelHostID uint
for {
var batchHostIDs []uint
err := sqlx.SelectContext(ctx, queryerContext, &batchHostIDs, `
SELECT pm.host_id
FROM policy_membership pm
WHERE pm.policy_id = ?
AND pm.host_id > ?
AND NOT (
(
-- If the policy has no include labels, all hosts match this part.
NOT EXISTS (
SELECT 1 FROM policy_labels pl
WHERE pl.policy_id = pm.policy_id AND pl.exclude = 0
)
-- If the policy has include labels, the host must be in at least one of them.
OR EXISTS (
SELECT 1 FROM policy_labels pl
JOIN label_membership lm ON lm.label_id = pl.label_id AND lm.host_id = pm.host_id
WHERE pl.policy_id = pm.policy_id AND pl.exclude = 0
)
)
-- If the policy has exclude labels, the host must not be in any of them.
AND NOT EXISTS (
SELECT 1 FROM policy_labels pl
JOIN label_membership lm ON lm.label_id = pl.label_id AND lm.host_id = pm.host_id
WHERE pl.policy_id = pm.policy_id AND pl.exclude = 1
)
)
ORDER BY pm.host_id ASC
LIMIT ?`, policyID, afterLabelHostID, policyMembershipDeleteBatchSize)
if err != nil {
return ctxerr.Wrap(ctx, err, "select batch of hosts to cleanup policy membership for labels")
}
if len(batchHostIDs) == 0 {
break
}
// Find the impacted host IDs, so we can update their host issues entries.
labelSelectStmt := `
SELECT DISTINCT
pm.host_id
` + labelQuery
// Delete memberships for hosts that don't match the label criteria.
labelDelStmt := `
DELETE pm
` + labelQuery
var labelHostIDs []uint
err := sqlx.SelectContext(ctx, queryerContext, &labelHostIDs, labelSelectStmt, policyID)
if err != nil {
return ctxerr.Wrap(ctx, err, "select hosts to cleanup policy membership for labels")
}
allHostIDs = append(allHostIDs, labelHostIDs...)
_, err = db.ExecContext(ctx, labelDelStmt, policyID)
if err != nil {
return ctxerr.Wrap(ctx, err, "cleanup policy membership for labels")
}
// Update host issues entries. This method is rarely called, so performance should not be a concern.
if err = updateHostIssuesFailingPolicies(ctx, db, allHostIDs); err != nil {
return err
batchStmt, args, err := sqlx.In(
`DELETE FROM policy_membership WHERE policy_id = ? AND host_id IN (?)`,
policyID, batchHostIDs,
)
if err != nil {
return ctxerr.Wrap(ctx, err, "building batch delete for label policy membership")
}
if _, err = db.ExecContext(ctx, batchStmt, args...); err != nil {
return ctxerr.Wrap(ctx, err, "batch cleanup policy membership for labels")
}
if err := updateHostIssuesFailingPolicies(ctx, db, batchHostIDs); err != nil {
return err
}
afterLabelHostID = batchHostIDs[len(batchHostIDs)-1]
}
return nil
@ -1850,49 +1919,52 @@ func cleanupPolicyMembershipOnPolicyUpdate(
// cleanupPolicyMembershipForPolicy is similar to cleanupPolicyMembershipOnPolicyUpdate but without the platform and label constraints.
// Used when we want to remove all policy memberships of a policy.
func cleanupPolicyMembershipForPolicy(
ctx context.Context, queryerContext sqlx.QueryerContext, exec sqlx.ExecerContext, policyID uint,
ctx context.Context,
queryerContext sqlx.QueryerContext,
exec sqlx.ExecerContext,
policyID uint,
) error {
selectStmt := `
SELECT DISTINCT
h.id
FROM
policy_membership pm
INNER JOIN
hosts h
ON
pm.host_id = h.id
WHERE
pm.policy_id = ?`
// Page through policy_membership using (policy_id, host_id) as a cursor. Selecting and deleting one
// batch at a time means we never load all host IDs into memory at once, and each DELETE holds
// its row-locks for an extremely short period of time.
var afterHostID uint
for {
var batchHostIDs []uint
err := sqlx.SelectContext(ctx, queryerContext, &batchHostIDs, `
SELECT pm.host_id
FROM policy_membership pm
INNER JOIN hosts h ON pm.host_id = h.id
WHERE pm.policy_id = ? AND pm.host_id > ?
ORDER BY pm.host_id ASC
LIMIT ?`, policyID, afterHostID, policyMembershipDeleteBatchSize)
if err != nil {
return ctxerr.Wrap(ctx, err, "select batch of hosts for policy membership cleanup")
}
if len(batchHostIDs) == 0 {
break
}
// delete all policy memberships for the policy
delStmt := `
DELETE
pm
FROM
policy_membership pm
LEFT JOIN
hosts h
ON
pm.host_id = h.id
WHERE
pm.policy_id = ?
`
// Find the impacted host IDs, so we can update their host issues entries
var hostIDs []uint
err := sqlx.SelectContext(ctx, queryerContext, &hostIDs, selectStmt, policyID)
if err != nil {
return ctxerr.Wrap(ctx, err, "select hosts to cleanup policy membership for policy")
batchStmt, args, err := sqlx.In(
`DELETE FROM policy_membership WHERE policy_id = ? AND host_id IN (?)`,
policyID, batchHostIDs,
)
if err != nil {
return ctxerr.Wrap(ctx, err, "building batch delete for policy membership")
}
if _, err = exec.ExecContext(ctx, batchStmt, args...); err != nil {
return ctxerr.Wrap(ctx, err, "batch cleanup policy membership")
}
if err := updateHostIssuesFailingPolicies(ctx, exec, batchHostIDs); err != nil {
return err
}
afterHostID = batchHostIDs[len(batchHostIDs)-1]
}
_, err = exec.ExecContext(ctx, delStmt, policyID)
if err != nil {
return ctxerr.Wrap(ctx, err, "cleanup policy membership")
}
// Update host issues entries. This method is rarely called, so performance should not be a concern.
if err = updateHostIssuesFailingPolicies(ctx, exec, hostIDs); err != nil {
return err
// Clean up orphaned memberships (host_id refs to deleted hosts, not covered by INNER JOIN above)
if _, err := exec.ExecContext(ctx, `
DELETE pm FROM policy_membership pm
LEFT JOIN hosts h ON pm.host_id = h.id
WHERE pm.policy_id = ? AND h.id IS NULL`, policyID); err != nil {
return ctxerr.Wrap(ctx, err, "cleanup orphaned policy membership")
}
return nil
@ -1900,7 +1972,7 @@ func cleanupPolicyMembershipForPolicy(
// CleanupPolicyMembership deletes the host's membership from policies that
// have been updated recently if those hosts don't meet the policy's criteria
// anymore (e.g. if the policy's platforms has been updated from "any" - the
// anymore (e.g. if the policy's platforms have been updated from "any" - the
// empty string - to "windows", this would delete that policy's membership rows
// for any non-windows host).
func (ds *Datastore) CleanupPolicyMembership(ctx context.Context, now time.Time) error {
@ -1930,6 +2002,25 @@ func (ds *Datastore) CleanupPolicyMembership(ctx context.Context, now time.Time)
}
}
// GitOps runs the policy membership cleanup outside the apply transaction. The following is a fail-safe
// in case that cleanup could not complete due to a server crash or another unexpected event.
var fullCleanupPolIDs []uint
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &fullCleanupPolIDs,
`SELECT id FROM policies WHERE needs_full_membership_cleanup = 1`,
); err != nil {
return ctxerr.Wrap(ctx, err, "select policies needing full membership cleanup")
}
for _, polID := range fullCleanupPolIDs {
if err := cleanupPolicyMembershipForPolicy(ctx, ds.reader(ctx), ds.writer(ctx), polID); err != nil {
return ctxerr.Wrapf(ctx, err, "full membership cleanup for policy %d", polID)
}
if _, err := ds.writer(ctx).ExecContext(ctx,
`UPDATE policies SET needs_full_membership_cleanup = 0 WHERE id = ?`, polID,
); err != nil {
return ctxerr.Wrapf(ctx, err, "clear full membership cleanup flag for policy %d", polID)
}
}
return nil
}

View file

@ -88,6 +88,10 @@ func TestPolicies(t *testing.T) {
{"PolicyModificationResetsAttemptNumber", testPolicyModificationResetsAttemptNumber},
{"TeamPatchPolicy", testTeamPatchPolicy},
{"TeamPolicyAutomationFilter", testTeamPolicyAutomationFilter},
{"BatchedPolicyMembershipCleanup", testBatchedPolicyMembershipCleanup},
{"BatchedPolicyMembershipCleanupOnPolicyUpdate", testBatchedPolicyMembershipCleanupOnPolicyUpdate},
{"ApplyPolicySpecsNeedsFullMembershipCleanupFlag", testApplyPolicySpecsNeedsFullMembershipCleanupFlag},
{"CleanupPolicyMembershipCrashRecovery", testCleanupPolicyMembershipCrashRecovery},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
@ -7096,6 +7100,425 @@ func testPolicyModificationResetsAttemptNumber(t *testing.T, ds *Datastore) {
require.Equal(t, int64(0), *scriptResults[1].AttemptNumber)
}
// testBatchedPolicyMembershipCleanup runs cleanupPolicyMembershipForPolicy with a
// batch size (2) smaller than the number of member hosts (5), so the cleanup needs
// several batches. It then checks that every policy_membership row of the policy
// is gone and that none of the hosts still has a positive total_issues_count in
// host_issues.
func testBatchedPolicyMembershipCleanup(t *testing.T, ds *Datastore) {
	const hostCount = 5

	ctx := context.Background()
	admin := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	// Shrink the batch size for the duration of this test and restore it afterwards.
	savedBatchSize := policyMembershipDeleteBatchSize
	policyMembershipDeleteBatchSize = 2
	t.Cleanup(func() { policyMembershipDeleteBatchSize = savedBatchSize })

	policy := newTestPolicy(t, ds, admin, "batch cleanup policy", "", nil)

	// Create the hosts, remembering their IDs so the host_issues assertions can be
	// scoped to them (rows left behind by other tests are ignored).
	hosts := make([]*fleet.Host, 0, hostCount)
	hostIDs := make([]uint, 0, hostCount)
	for i := 0; i < hostCount; i++ {
		id := fmt.Sprintf("batch-cleanup-%d", i)
		host, err := ds.NewHost(ctx, &fleet.Host{
			OsqueryHostID:   &id,
			DetailUpdatedAt: time.Now(),
			LabelUpdatedAt:  time.Now(),
			PolicyUpdatedAt: time.Now(),
			SeenTime:        time.Now(),
			NodeKey:         &id,
			UUID:            id,
			Hostname:        id,
			Platform:        "linux",
		})
		require.NoError(t, err)
		hosts = append(hosts, host)
		hostIDs = append(hostIDs, host.ID)
	}

	// A failing result per host creates both a policy_membership row and a host_issues entry.
	for _, host := range hosts {
		require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: ptr.Bool(false)}, time.Now(), false))
	}

	issuesQuery, issuesArgs, err := sqlx.In(
		`SELECT COUNT(*) FROM host_issues WHERE host_id IN (?) AND total_issues_count > 0`, hostIDs,
	)
	require.NoError(t, err)

	membershipCount := func() int {
		t.Helper()
		var n int
		require.NoError(t, ds.writer(ctx).Get(&n, `SELECT COUNT(*) FROM policy_membership WHERE policy_id = ?`, policy.ID))
		return n
	}
	hostsWithIssues := func() int {
		t.Helper()
		var n int
		require.NoError(t, ds.writer(ctx).Get(&n, issuesQuery, issuesArgs...))
		return n
	}

	// Precondition: every host is a member and has a failing policy issue.
	require.Equal(t, hostCount, membershipCount())
	require.Equal(t, hostCount, hostsWithIssues())

	// Call the full cleanup directly; ApplyPolicySpecs requests it when
	// shouldRemoveAllPolicyMemberships is true.
	require.NoError(t, cleanupPolicyMembershipForPolicy(ctx, ds.reader(ctx), ds.writer(ctx), policy.ID))

	// No membership rows and no failing-policy issues may remain.
	assert.Zero(t, membershipCount())
	assert.Zero(t, hostsWithIssues())
}
// testBatchedPolicyMembershipCleanupOnPolicyUpdate runs
// cleanupPolicyMembershipOnPolicyUpdate with a batch size of 2 against five
// non-matching hosts plus one matching host, first for the platform criteria and
// then for the label criteria. In both parts only the matching host may keep its
// policy_membership row.
func testBatchedPolicyMembershipCleanupOnPolicyUpdate(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	admin := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	// Five rows to delete with a batch size of 2 forces several batches.
	savedBatchSize := policyMembershipDeleteBatchSize
	policyMembershipDeleteBatchSize = 2
	t.Cleanup(func() { policyMembershipDeleteBatchSize = savedBatchSize })

	// newHost creates a host whose identifiers are all derived from id.
	newHost := func(id, platform string) *fleet.Host {
		t.Helper()
		host, err := ds.NewHost(ctx, &fleet.Host{
			OsqueryHostID:   &id,
			DetailUpdatedAt: time.Now(),
			LabelUpdatedAt:  time.Now(),
			PolicyUpdatedAt: time.Now(),
			SeenTime:        time.Now(),
			NodeKey:         &id,
			UUID:            id,
			Hostname:        id,
			Platform:        platform,
		})
		require.NoError(t, err)
		return host
	}
	// newLinuxHosts creates n linux hosts named by applying idFormat to 0..n-1.
	newLinuxHosts := func(idFormat string, n int) []*fleet.Host {
		t.Helper()
		created := make([]*fleet.Host, n)
		for i := range created {
			created[i] = newHost(fmt.Sprintf(idFormat, i), "linux")
		}
		return created
	}
	// recordFailing stores a failing result for policyID on each host.
	recordFailing := func(policyID uint, hosts []*fleet.Host) {
		t.Helper()
		for _, host := range hosts {
			require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policyID: ptr.Bool(false)}, time.Now(), false))
		}
	}
	memberCount := func(policyID uint) int {
		t.Helper()
		var n int
		require.NoError(t, ds.writer(ctx).Get(&n, `SELECT COUNT(*) FROM policy_membership WHERE policy_id = ?`, policyID))
		return n
	}
	memberHostIDs := func(policyID uint) []uint {
		t.Helper()
		var ids []uint
		require.NoError(t, ds.writer(ctx).Select(&ids, `SELECT host_id FROM policy_membership WHERE policy_id = ?`, policyID))
		return ids
	}

	// ── Part 1: platform-based cleanup ──────────────────────────────────────

	// A windows-only policy, one windows host (must stay) and five linux hosts
	// (wrong platform, must go).
	pol := newTestPolicy(t, ds, admin, "batch platform cleanup", "windows", nil)
	winHost := newHost("batch-win-0", "windows")
	linuxHosts := newLinuxHosts("batch-lin-%d", 5)

	// Results are recorded for every host, including those on the wrong platform.
	recordFailing(pol.ID, append([]*fleet.Host{winHost}, linuxHosts...))
	require.Equal(t, 6, memberCount(pol.ID))

	// Run the platform-aware cleanup with the policy's own platform string.
	require.NoError(t, cleanupPolicyMembershipOnPolicyUpdate(ctx, ds.reader(ctx), ds.writer(ctx), pol.ID, pol.Platform))

	require.ElementsMatch(t, []uint{winHost.ID}, memberHostIDs(pol.ID))

	// ── Part 2: label-based cleanup ─────────────────────────────────────────

	inclLabel, err := ds.NewLabel(ctx, &fleet.Label{Name: "batch-incl-label"})
	require.NoError(t, err)

	// One host in the label (must stay) and five outside it (must go, in
	// several batches of 2).
	lblHost := newHost("batch-lbl-0", "linux")
	require.NoError(t, ds.AddLabelsToHost(ctx, lblHost.ID, []uint{inclLabel.ID}))
	nonLblHosts := newLinuxHosts("batch-nonlbl-%d", 5)

	// A policy without platform restriction that includes only the label above.
	lblPol := newTestPolicy(t, ds, admin, "batch label cleanup", "", nil)
	lblPol.LabelsIncludeAny = []fleet.LabelIdent{{LabelName: inclLabel.Name}}
	require.NoError(t, ds.SavePolicy(ctx, lblPol, false, false))

	recordFailing(lblPol.ID, append([]*fleet.Host{lblHost}, nonLblHosts...))
	require.Equal(t, 6, memberCount(lblPol.ID))

	// An empty platform string skips the platform branch, leaving only the label branch.
	require.NoError(t, cleanupPolicyMembershipOnPolicyUpdate(ctx, ds.reader(ctx), ds.writer(ctx), lblPol.ID, "" /* no platform filter */))

	require.ElementsMatch(t, []uint{lblHost.ID}, memberHostIDs(lblPol.ID))
}
// testApplyPolicySpecsNeedsFullMembershipCleanupFlag applies a policy spec, gives
// the policy three failing memberships, then re-applies the spec with a different
// query. After the second apply it expects:
//   - needs_full_membership_cleanup to read 0 for the policy, and
//   - no policy_membership rows to remain for the policy.
func testApplyPolicySpecsNeedsFullMembershipCleanupFlag(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	admin := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	// applyWithQuery applies the test policy spec with the given query text.
	applyWithQuery := func(query string) {
		t.Helper()
		require.NoError(t, ds.ApplyPolicySpecs(ctx, admin.ID, []*fleet.PolicySpec{
			{Name: "flag test policy", Query: query, Platform: "", Type: fleet.PolicyTypeDynamic},
		}))
	}
	membershipCount := func(policyID uint) int {
		t.Helper()
		var n int
		require.NoError(t, ds.writer(ctx).Get(&n, `SELECT COUNT(*) FROM policy_membership WHERE policy_id = ?`, policyID))
		return n
	}

	// First apply creates the policy.
	applyWithQuery("select 1;")

	// Look the policy up by name so other global policies do not affect the test.
	globalPolicies, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
	require.NoError(t, err)
	var pol *fleet.Policy
	for i := range globalPolicies {
		if globalPolicies[i].Name == "flag test policy" {
			pol = globalPolicies[i]
			break
		}
	}
	require.NotNil(t, pol, "policy 'flag test policy' not found")

	// Three linux hosts, each with a failing result for the policy.
	hosts := make([]*fleet.Host, 0, 3)
	for i := 0; i < 3; i++ {
		id := fmt.Sprintf("flag-test-%d", i)
		host, err := ds.NewHost(ctx, &fleet.Host{
			OsqueryHostID:   &id,
			DetailUpdatedAt: time.Now(),
			LabelUpdatedAt:  time.Now(),
			PolicyUpdatedAt: time.Now(),
			SeenTime:        time.Now(),
			NodeKey:         &id,
			UUID:            id,
			Hostname:        id,
			Platform:        "linux",
		})
		require.NoError(t, err)
		hosts = append(hosts, host)
	}
	for _, host := range hosts {
		require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false))
	}
	require.Equal(t, 3, membershipCount(pol.ID))

	// Second apply changes the query text of the existing policy.
	applyWithQuery("select 2;")

	// The flag must read 0 once ApplyPolicySpecs has returned successfully.
	var flagVal int
	require.NoError(t, ds.writer(ctx).Get(&flagVal,
		`SELECT needs_full_membership_cleanup FROM policies WHERE id = ?`, pol.ID))
	assert.Zero(t, flagVal, "needs_full_membership_cleanup must be cleared after successful cleanup")

	// Every membership row of the policy must have been removed.
	assert.Zero(t, membershipCount(pol.ID))
}
// testCleanupPolicyMembershipCrashRecovery checks that a policy left with
// needs_full_membership_cleanup = 1 (as if a previous cleanup was interrupted by a
// crash or error after its transaction committed) is eventually fully cleaned up.
// Each scenario simulates the interruption by setting the flag directly in the DB:
//
//   - GitOps retry: re-applying the same spec through ApplyPolicySpecs must remove all
//     policy_membership rows and clear the flag, without involving the cron.
//   - Cron safety net: CleanupPolicyMembership must remove all policy_membership rows
//     and clear the flag when no GitOps retry happens.
//   - Flag-only leftover: when the membership rows are already gone and only the flag
//     remains set, CleanupPolicyMembership must succeed and clear the flag.
func testCleanupPolicyMembershipCrashRecovery(t *testing.T, ds *Datastore) {
	const (
		countMembershipStmt = `SELECT COUNT(*) FROM policy_membership WHERE policy_id = ?`
		readFlagStmt        = `SELECT needs_full_membership_cleanup FROM policies WHERE id = ?`
		raiseFlagStmt       = `UPDATE policies SET needs_full_membership_cleanup = 1 WHERE id = ?`
	)

	ctx := context.Background()
	admin := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	// createHosts inserts n linux hosts whose identifiers are "<prefix>-<index>".
	createHosts := func(t *testing.T, n int, prefix string) []*fleet.Host {
		t.Helper()
		created := make([]*fleet.Host, 0, n)
		for idx := 0; idx < n; idx++ {
			ident := fmt.Sprintf("%s-%d", prefix, idx)
			host, err := ds.NewHost(ctx, &fleet.Host{
				OsqueryHostID:   &ident,
				DetailUpdatedAt: time.Now(),
				LabelUpdatedAt:  time.Now(),
				PolicyUpdatedAt: time.Now(),
				SeenTime:        time.Now(),
				NodeKey:         &ident,
				UUID:            ident,
				Hostname:        ident,
				Platform:        "linux",
			})
			require.NoError(t, err)
			created = append(created, host)
		}
		return created
	}

	// recordFailing stores a failing result of the given policy for every host,
	// which produces one policy_membership row per host.
	recordFailing := func(t *testing.T, targets []*fleet.Host, policyID uint) {
		t.Helper()
		for _, host := range targets {
			results := map[uint]*bool{policyID: ptr.Bool(false)}
			require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, results, time.Now(), false))
		}
	}

	// membershipCount returns the number of policy_membership rows of the policy.
	membershipCount := func(t *testing.T, policyID uint) int {
		t.Helper()
		var n int
		require.NoError(t, ds.writer(ctx).Get(&n, countMembershipStmt, policyID))
		return n
	}

	// cleanupFlag returns the current needs_full_membership_cleanup value of the policy.
	cleanupFlag := func(t *testing.T, policyID uint) int {
		t.Helper()
		var flag int
		require.NoError(t, ds.writer(ctx).Get(&flag, readFlagStmt, policyID))
		return flag
	}

	// raiseFlag sets needs_full_membership_cleanup = 1 directly, bypassing the
	// datastore API, to mimic a cleanup that never finished.
	raiseFlag := func(t *testing.T, policyID uint) {
		t.Helper()
		_, err := ds.writer(ctx).ExecContext(ctx, raiseFlagStmt, policyID)
		require.NoError(t, err)
	}

	t.Run("gitops retry re-triggers cleanup", func(t *testing.T) {
		// A fresh spec slice is built for every apply so both calls receive
		// independent values.
		retrySpecs := func() []*fleet.PolicySpec {
			return []*fleet.PolicySpec{
				{Name: "retry recovery policy", Query: "select 1;", Type: fleet.PolicyTypeDynamic},
			}
		}

		// The policy has to exist before membership rows can be recorded for it.
		require.NoError(t, ds.ApplyPolicySpecs(ctx, admin.ID, retrySpecs()))

		// Look the policy up by name among the global policies.
		globalPolicies, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
		require.NoError(t, err)
		var pol *fleet.Policy
		for _, candidate := range globalPolicies {
			if candidate.Name != "retry recovery policy" {
				continue
			}
			pol = candidate
			break
		}
		require.NotNil(t, pol)

		// Seed four membership rows.
		recordFailing(t, createHosts(t, 4, "retry-recovery"), pol.ID)
		require.Equal(t, 4, membershipCount(t, pol.ID))

		// Pretend the TX committed with the flag set but the cleanup never ran.
		raiseFlag(t, pol.ID)

		// Applying the very same spec again is expected to notice the flag and
		// perform the full cleanup on its own.
		require.NoError(t, ds.ApplyPolicySpecs(ctx, admin.ID, retrySpecs()))

		assert.Zero(t, cleanupFlag(t, pol.ID), "flag must be cleared by the GitOps retry")
		assert.Zero(t, membershipCount(t, pol.ID), "all policy_membership rows must be removed by the GitOps retry")
	})

	t.Run("cron cleans up when no gitops retry", func(t *testing.T) {
		pol := newTestPolicy(t, ds, admin, "cron recovery policy", "", nil)

		// Seed four membership rows.
		recordFailing(t, createHosts(t, 4, "cron-recovery"), pol.ID)
		require.Equal(t, 4, membershipCount(t, pol.ID))

		// Interrupted cleanup: flag is set while the membership rows are still there.
		raiseFlag(t, pol.ID)

		// The cron entry point is expected to pick the flag up and finish the job.
		require.NoError(t, ds.CleanupPolicyMembership(ctx, time.Now()))

		assert.Zero(t, cleanupFlag(t, pol.ID), "flag must be cleared by CleanupPolicyMembership")
		assert.Zero(t, membershipCount(t, pol.ID), "all policy_membership rows must be cleaned up by the cron safety net")
	})

	t.Run("cron clears flag when cleanup already completed", func(t *testing.T) {
		// Scenario: the rows were already deleted by the cleanup, but the server
		// went down before needs_full_membership_cleanup was reset to 0. The cron
		// run then has nothing to delete and only needs to clear the flag.
		pol := newTestPolicy(t, ds, admin, "flag-only recovery policy", "", nil)

		// No results were ever recorded for this policy, so no rows exist.
		require.Zero(t, membershipCount(t, pol.ID), "precondition: no membership rows")

		// Only the flag is left behind.
		raiseFlag(t, pol.ID)

		// The cron run must not fail on the empty cleanup.
		require.NoError(t, ds.CleanupPolicyMembership(ctx, time.Now()))

		assert.Zero(t, cleanupFlag(t, pol.ID), "flag must be cleared even when no membership rows remain")
	})
}
func testTeamPatchPolicy(t *testing.T, ds *Datastore) {
ctx := context.Background()
user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)

File diff suppressed because one or more lines are too long