From 18e7c8e2362589d712b9b9750e20125bbf109df4 Mon Sep 17 00:00:00 2001 From: Martin Angers Date: Wed, 23 Feb 2022 16:10:37 -0500 Subject: [PATCH] Implement cron cleanup job of policy membership when policy platform is updated (#4331) --- cmd/fleet/serve.go | 5 + server/datastore/mysql/policies.go | 63 ++++++- server/datastore/mysql/policies_test.go | 224 ++++++++++++++++++++---- server/fleet/datastore.go | 2 + server/mock/datastore_mock.go | 10 ++ 5 files changed, 267 insertions(+), 37 deletions(-) diff --git a/cmd/fleet/serve.go b/cmd/fleet/serve.go index a39dbce56a..03d9f7671a 100644 --- a/cmd/fleet/serve.go +++ b/cmd/fleet/serve.go @@ -656,6 +656,11 @@ func cronCleanups(ctx context.Context, ds fleet.Datastore, logger kitlog.Logger, level.Error(logger).Log("err", "aggregating munki and mdm data", "details", err) sentry.CaptureException(err) } + err = ds.CleanupPolicyMembership(ctx, time.Now()) + if err != nil { + level.Error(logger).Log("err", "cleanup policy membership", "details", err) + sentry.CaptureException(err) + } err = trySendStatistics(ctx, ds, fleet.StatisticsFrequency, "https://fleetdm.com/api/v1/webhooks/receive-usage-analytics", license) if err != nil { diff --git a/server/datastore/mysql/policies.go b/server/datastore/mysql/policies.go index 9cd7aada86..e3d0b673e7 100644 --- a/server/datastore/mysql/policies.go +++ b/server/datastore/mysql/policies.go @@ -98,7 +98,7 @@ func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy) error { return ctxerr.Wrap(ctx, notFound("Policy").WithID(p.ID)) } - return cleanupPolicyMembership(ctx, ds.writer, p.ID, p.Platform) + return cleanupPolicyMembershipOnPolicyUpdate(ctx, ds.writer, p.ID, p.Platform) } // FlippingPoliciesForHost fetches previous policy membership results and returns: @@ -472,7 +472,7 @@ func (ds *Datastore) ApplyPolicySpecs(ctx context.Context, authorID uint, specs // when the upsert results in an UPDATE that *did* change some values, // it returns the updated ID as last inserted id. if lastID, _ := res.LastInsertId(); lastID > 0 { - if err := cleanupPolicyMembership(ctx, tx, uint(lastID), spec.Platform); err != nil { + if err := cleanupPolicyMembershipOnPolicyUpdate(ctx, tx, uint(lastID), spec.Platform); err != nil { return err } } @@ -536,7 +536,7 @@ func (ds *Datastore) AsyncBatchUpdatePolicyTimestamp(ctx context.Context, ids [] }) } -func cleanupPolicyMembership(ctx context.Context, db sqlx.ExecerContext, policyID uint, platforms string) error { +func cleanupPolicyMembershipOnPolicyUpdate(ctx context.Context, db sqlx.ExecerContext, policyID uint, platforms string) error { if platforms == "" { // all platforms allowed, nothing to clean up return nil @@ -564,3 +564,60 @@ func cleanupPolicyMembership(ctx context.Context, db sqlx.ExecerContext, policyI _, err := db.ExecContext(ctx, delStmt, policyID, strings.Join(expandedPlatforms, ",")) return ctxerr.Wrap(ctx, err, "cleanup policy membership") } + +// CleanupPolicyMembership deletes the host's membership from policies that +// have been updated recently if those hosts don't meet the policy's criteria +// anymore (e.g. if the policy's platforms has been updated from "any" - the +// empty string - to "windows", this would delete that policy's membership rows +// for any non-windows host). +func (ds *Datastore) CleanupPolicyMembership(ctx context.Context, now time.Time) error { + const ( + recentlyUpdatedPoliciesInterval = 24 * time.Hour + + updatedPoliciesStmt = ` + SELECT + p.id, + p.platforms + FROM + policies p + WHERE + p.updated_at >= DATE_SUB(?, INTERVAL ? SECOND) AND + p.created_at < p.updated_at` // ignore newly created + + deleteMembershipStmt = ` + DELETE + pm + FROM + policy_membership pm + INNER JOIN + hosts h + ON + pm.host_id = h.id + WHERE + pm.policy_id = ? AND + FIND_IN_SET(h.platform, ?) = 0` + ) + + var pols []*fleet.Policy + if err := sqlx.SelectContext(ctx, ds.reader, &pols, updatedPoliciesStmt, now, int(recentlyUpdatedPoliciesInterval.Seconds())); err != nil { + return ctxerr.Wrap(ctx, err, "select recently updated policies") + } + + for _, pol := range pols { + if pol.Platform == "" { + continue + } + + var expandedPlatforms []string + splitPlatforms := strings.Split(pol.Platform, ",") + for _, platform := range splitPlatforms { + expandedPlatforms = append(expandedPlatforms, fleet.ExpandPlatform(strings.TrimSpace(platform))...) + } + + if _, err := ds.writer.ExecContext(ctx, deleteMembershipStmt, pol.ID, strings.Join(expandedPlatforms, ",")); err != nil { + return ctxerr.Wrapf(ctx, err, "delete outdated hosts membership for policy: %d; platforms: %v", pol.ID, expandedPlatforms) + } + } + + return nil +} diff --git a/server/datastore/mysql/policies_test.go b/server/datastore/mysql/policies_test.go index 2688edf5de..9977619bfd 100644 --- a/server/datastore/mysql/policies_test.go +++ b/server/datastore/mysql/policies_test.go @@ -42,6 +42,7 @@ func TestPolicies(t *testing.T) { {"DelUser", testPoliciesDelUser}, {"FlippingPoliciesForHost", testFlippingPoliciesForHost}, {"PlatformUpdate", testPolicyPlatformUpdate}, + {"CleanupPolicyMembership", testPolicyCleanupPolicyMembership}, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -1560,43 +1561,13 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) { require.NoError(t, err) } - policyIDs := make([]uint, 0, len(polsByName)) - for _, pol := range polsByName { - policyIDs = append(policyIDs, pol.ID) - } - loadMembershipStmt, args, err := sqlx.In(`SELECT policy_id, host_id FROM policy_membership WHERE policy_id IN (?)`, policyIDs) - require.NoError(t, err) - - assertPolicyMembership := func(want map[string][]uint) { - type polHostIDs struct { - PolicyID uint `db:"policy_id"` - HostID uint `db:"host_id"` - } - var rows []polHostIDs - err := ds.writer.SelectContext(ctx, &rows, loadMembershipStmt, args...) - require.NoError(t, err) - - // index the host IDs by policy ID - hostIDsByPolID := make(map[uint][]uint, len(policyIDs)) - for _, row := range rows { - hostIDsByPolID[row.PolicyID] = append(hostIDsByPolID[row.PolicyID], row.HostID) - } - - // assert that they match the expected list of hosts by policy - for polNm, hostIDs := range want { - polID := polsByName[polNm].ID - got := hostIDsByPolID[polID] - require.ElementsMatch(t, hostIDs, got) - } - } - wantHostsByPol := map[string][]uint{ "g1": {globalHosts[hostWin].ID, globalHosts[hostMac].ID, globalHosts[hostDeb].ID, globalHosts[hostLin].ID}, "g2": {globalHosts[hostDeb].ID, globalHosts[hostLin].ID}, "t1": {teamHosts[hostWin].ID, teamHosts[hostMac].ID, teamHosts[hostDeb].ID, teamHosts[hostLin].ID}, "t2": {teamHosts[hostDeb].ID, teamHosts[hostLin].ID}, } - assertPolicyMembership(wantHostsByPol) + assertPolicyMembership(t, ds, polsByName, wantHostsByPol) // update global policy g1 from any => linux g1 := polsByName["g1"] @@ -1605,7 +1576,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) { err = ds.SavePolicy(ctx, g1) require.NoError(t, err) wantHostsByPol["g1"] = []uint{globalHosts[hostDeb].ID, globalHosts[hostLin].ID} - assertPolicyMembership(wantHostsByPol) + assertPolicyMembership(t, ds, polsByName, wantHostsByPol) // update team policy t1 from any => windows, darwin t1 := polsByName["t1"] @@ -1614,7 +1585,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) { err = ds.SavePolicy(ctx, t1) require.NoError(t, err) wantHostsByPol["t1"] = []uint{teamHosts[hostWin].ID, teamHosts[hostMac].ID} - assertPolicyMembership(wantHostsByPol) + assertPolicyMembership(t, ds, polsByName, wantHostsByPol) // update g2 from linux => any, t2 from linux => debian, via ApplySpecs t2, g2 := polsByName["t2"], polsByName["g2"] @@ -1629,5 +1600,190 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) { // nothing should've changed for g2 (platform changed to any, so nothing to cleanup), // while t2 should now only accept debian wantHostsByPol["t2"] = []uint{teamHosts[hostDeb].ID} - assertPolicyMembership(wantHostsByPol) + assertPolicyMembership(t, ds, polsByName, wantHostsByPol) +} + +func assertPolicyMembership(t *testing.T, ds *Datastore, polsByName map[string]*fleet.Policy, wantPolNameToHostIDs map[string][]uint) { + policyIDs := make([]uint, 0, len(polsByName)) + for _, pol := range polsByName { + policyIDs = append(policyIDs, pol.ID) + } + loadMembershipStmt, args, err := sqlx.In(`SELECT policy_id, host_id FROM policy_membership WHERE policy_id IN (?)`, policyIDs) + require.NoError(t, err) + + type polHostIDs struct { + PolicyID uint `db:"policy_id"` + HostID uint `db:"host_id"` + } + var rows []polHostIDs + err = ds.writer.SelectContext(context.Background(), &rows, loadMembershipStmt, args...) + require.NoError(t, err) + + // index the host IDs by policy ID + hostIDsByPolID := make(map[uint][]uint, len(policyIDs)) + for _, row := range rows { + hostIDsByPolID[row.PolicyID] = append(hostIDsByPolID[row.PolicyID], row.HostID) + } + + // assert that they match the expected list of hosts by policy + for polNm, hostIDs := range wantPolNameToHostIDs { + pol, ok := polsByName[polNm] + if !ok { + require.Len(t, hostIDs, 0) + continue + } + got := hostIDsByPolID[pol.ID] + require.ElementsMatch(t, hostIDs, got) + } +} + +func testPolicyCleanupPolicyMembership(t *testing.T, ds *Datastore) { + ctx := context.Background() + user := test.NewUser(t, ds, "Bob", "bob@example.com", true) + + // create hosts with different platforms + hostWin, hostMac, hostDeb, hostLin := 0, 1, 2, 3 + platforms := []string{"windows", "darwin", "debian", "linux"} + hosts := make([]*fleet.Host, len(platforms)) + for i, pl := range platforms { + id := fmt.Sprintf("%s-%d", strings.ReplaceAll(t.Name(), "/", "_"), i) + h, err := ds.NewHost(ctx, &fleet.Host{ + OsqueryHostID: id, + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now(), + NodeKey: id, + UUID: id, + Hostname: id, + Platform: pl, + }) + require.NoError(t, err) + hosts[i] = h + } + + // create some policies, using direct insert statements to control the timestamps + createPolStmt := `INSERT INTO policies (name, query, description, author_id, platforms, created_at, updated_at) + VALUES (?, ?, '', ?, ?, ?, ?)` + + jan2020 := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + feb2020 := time.Date(2020, 2, 1, 0, 0, 0, 0, time.UTC) + mar2020 := time.Date(2020, 3, 1, 0, 0, 0, 0, time.UTC) + apr2020 := time.Date(2020, 4, 1, 0, 0, 0, 0, time.UTC) + may2020 := time.Date(2020, 5, 1, 0, 0, 0, 0, time.UTC) + pols := make([]*fleet.Policy, 3) + for i, dt := range []time.Time{jan2020, feb2020, mar2020} { + res, err := ds.writer.ExecContext(ctx, createPolStmt, "p"+strconv.Itoa(i+1), "select 1", user.ID, "", dt, dt) + require.NoError(t, err) + id, _ := res.LastInsertId() + pol, err := ds.Policy(ctx, uint(id)) + require.NoError(t, err) + pols[i] = pol + } + // index the policies by name for easier access in the rest of the test + polsByName := make(map[string]*fleet.Policy, len(pols)) + for _, pol := range pols { + polsByName[pol.Name] = pol + } + + wantHostsByPol := map[string][]uint{ + "p1": {}, + "p2": {}, + "p3": {}, + } + // no recently updated policies + err := ds.CleanupPolicyMembership(ctx, time.Now()) + require.NoError(t, err) + assertPolicyMembership(t, ds, polsByName, wantHostsByPol) + + // record results for each policy, all hosts, even if invalid for the policy + for _, h := range hosts { + res := map[uint]*bool{ + polsByName["p1"].ID: ptr.Bool(true), + polsByName["p2"].ID: ptr.Bool(true), + polsByName["p3"].ID: ptr.Bool(true), + } + err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false) + require.NoError(t, err) + } + + // no recently updated policies, so no host gets cleaned up + wantHostsByPol = map[string][]uint{ + "p1": {hosts[hostWin].ID, hosts[hostMac].ID, hosts[hostDeb].ID, hosts[hostLin].ID}, + "p2": {hosts[hostWin].ID, hosts[hostMac].ID, hosts[hostDeb].ID, hosts[hostLin].ID}, + "p3": {hosts[hostWin].ID, hosts[hostMac].ID, hosts[hostDeb].ID, hosts[hostLin].ID}, + } + err = ds.CleanupPolicyMembership(ctx, time.Now()) + require.NoError(t, err) + assertPolicyMembership(t, ds, polsByName, wantHostsByPol) + + // update policy p1, but do not change the platform (still any) + pols[0].Description = "updated" + updatePolicyWithTimestamp(t, ds, pols[0], feb2020) + err = ds.CleanupPolicyMembership(ctx, time.Now()) + require.NoError(t, err) + assertPolicyMembership(t, ds, polsByName, wantHostsByPol) + + // update policy p1 to "windows", but cleanup with a timestamp of apr2020, so + // not "recently updated", no changes + pols[0].Platform = "windows" + updatePolicyWithTimestamp(t, ds, pols[0], mar2020) + err = ds.CleanupPolicyMembership(ctx, apr2020) + require.NoError(t, err) + assertPolicyMembership(t, ds, polsByName, wantHostsByPol) + + // now cleanup with a timestamp of mar2020+1h, so "recently updated", only windows + // hosts are kept + err = ds.CleanupPolicyMembership(ctx, mar2020.Add(time.Hour)) + require.NoError(t, err) + wantHostsByPol["p1"] = []uint{hosts[hostWin].ID} + assertPolicyMembership(t, ds, polsByName, wantHostsByPol) + + // update policy p2 to "linux,darwin", but cleanup with a timestamp of just over 24h, so + // not "recently updated", no changes + pols[1].Platform = "linux,darwin" + updatePolicyWithTimestamp(t, ds, pols[1], mar2020) + err = ds.CleanupPolicyMembership(ctx, mar2020.Add(25*time.Hour)) + require.NoError(t, err) + assertPolicyMembership(t, ds, polsByName, wantHostsByPol) + + // now cleanup with a timestamp of just under 24h, so it is "recently updated" + err = ds.CleanupPolicyMembership(ctx, mar2020.Add(23*time.Hour)) + require.NoError(t, err) + wantHostsByPol["p2"] = []uint{hosts[hostMac].ID, hosts[hostDeb].ID, hosts[hostLin].ID} + assertPolicyMembership(t, ds, polsByName, wantHostsByPol) + + // update policy p2 to just "linux", p3 to "debian", both get cleaned up (using apr2020 + // because p3 was created with mar2020, so it will not be detected as updated if we use + // that same timestamp for the update). + pols[1].Platform = "linux" + updatePolicyWithTimestamp(t, ds, pols[1], apr2020) + pols[2].Platform = "debian" + updatePolicyWithTimestamp(t, ds, pols[2], apr2020) + err = ds.CleanupPolicyMembership(ctx, apr2020.Add(time.Hour)) + require.NoError(t, err) + wantHostsByPol["p2"] = []uint{hosts[hostDeb].ID, hosts[hostLin].ID} + wantHostsByPol["p3"] = []uint{hosts[hostDeb].ID} + assertPolicyMembership(t, ds, polsByName, wantHostsByPol) + + // cleaning up again 1h later doesn't change anything + err = ds.CleanupPolicyMembership(ctx, apr2020.Add(2*time.Hour)) + require.NoError(t, err) + assertPolicyMembership(t, ds, polsByName, wantHostsByPol) + + // update policy p1 to allow any, doesn't clean up anything + pols[0].Platform = "" + updatePolicyWithTimestamp(t, ds, pols[0], may2020) + err = ds.CleanupPolicyMembership(ctx, may2020.Add(time.Hour)) + require.NoError(t, err) + assertPolicyMembership(t, ds, polsByName, wantHostsByPol) +} + +func updatePolicyWithTimestamp(t *testing.T, ds *Datastore, p *fleet.Policy, ts time.Time) { + sql := ` + UPDATE policies + SET name = ?, query = ?, description = ?, resolution = ?, platforms = ?, updated_at = ? + WHERE id = ?` + _, err := ds.writer.ExecContext(context.Background(), sql, p.Name, p.Query, p.Description, p.Resolution, p.Platform, ts, p.ID) + require.NoError(t, err) } diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index 2bc9772173..3dfc42c98b 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -398,6 +398,8 @@ type Datastore interface { DeleteTeamPolicies(ctx context.Context, teamID uint, ids []uint) ([]uint, error) TeamPolicy(ctx context.Context, teamID uint, policyID uint) (*Policy, error) + CleanupPolicyMembership(ctx context.Context, now time.Time) error + /////////////////////////////////////////////////////////////////////////////// // Locking diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index 801425065d..a350b5b933 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -326,6 +326,8 @@ type DeleteTeamPoliciesFunc func(ctx context.Context, teamID uint, ids []uint) ( type TeamPolicyFunc func(ctx context.Context, teamID uint, policyID uint) (*fleet.Policy, error) +type CleanupPolicyMembershipFunc func(ctx context.Context, now time.Time) error + type LockFunc func(ctx context.Context, name string, owner string, expiration time.Duration) (bool, error) type UnlockFunc func(ctx context.Context, name string, owner string) error @@ -852,6 +854,9 @@ type DataStore struct { TeamPolicyFunc TeamPolicyFunc TeamPolicyFuncInvoked bool + CleanupPolicyMembershipFunc CleanupPolicyMembershipFunc + CleanupPolicyMembershipFuncInvoked bool + LockFunc LockFunc LockFuncInvoked bool @@ -1719,6 +1724,11 @@ func (s *DataStore) TeamPolicy(ctx context.Context, teamID uint, policyID uint) return s.TeamPolicyFunc(ctx, teamID, policyID) } +func (s *DataStore) CleanupPolicyMembership(ctx context.Context, now time.Time) error { + s.CleanupPolicyMembershipFuncInvoked = true + return s.CleanupPolicyMembershipFunc(ctx, now) +} + func (s *DataStore) Lock(ctx context.Context, name string, owner string, expiration time.Duration) (bool, error) { s.LockFuncInvoked = true return s.LockFunc(ctx, name, owner, expiration)