From d0f0d3d017c0a11717b5595e93b74458d8f403cf Mon Sep 17 00:00:00 2001 From: Victor Lyuboslavsky Date: Mon, 29 Apr 2024 10:20:59 -0500 Subject: [PATCH] When updating a policy's 'platform' field, the aggregated policy stats are now cleared. (#18415) #18157 When updating a policy's 'platform' field, the aggregated policy stats are now cleared. # Checklist for submitter - [x] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`. See [Changes files](https://fleetdm.com/docs/contributing/committing-changes#changes-files) for more information. - [x] Added/updated tests - [x] Manual QA for all new/changed functionality --- changes/18157-update-platform-policy-stats | 1 + server/datastore/mysql/policies.go | 176 +++++++++--- server/datastore/mysql/policies_test.go | 320 ++++++++++++++++++++- server/fleet/datastore.go | 2 +- server/mock/datastore_mock.go | 6 +- server/service/global_policies_test.go | 2 +- server/service/integration_core_test.go | 59 ++++ server/service/team_policies.go | 17 +- server/service/team_policies_test.go | 2 +- 9 files changed, 527 insertions(+), 58 deletions(-) create mode 100644 changes/18157-update-platform-policy-stats diff --git a/changes/18157-update-platform-policy-stats b/changes/18157-update-platform-policy-stats new file mode 100644 index 0000000000..fdaa87d56d --- /dev/null +++ b/changes/18157-update-platform-policy-stats @@ -0,0 +1 @@ +When updating a policy's 'platform' field, the aggregated policy stats are now cleared. diff --git a/server/datastore/mysql/policies.go b/server/datastore/mysql/policies.go index d80dd2d60b..1ad7ef8f47 100644 --- a/server/datastore/mysql/policies.go +++ b/server/datastore/mysql/policies.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "sort" "strings" @@ -14,6 +15,7 @@ import ( "github.com/doug-martin/goqu/v9" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/fleet" + kitlog "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" "github.com/jmoiron/sqlx" ) @@ -110,8 +112,8 @@ func policyDB(ctx context.Context, q sqlx.QueryerContext, id uint, teamID *uint) // SavePolicy updates some fields of the given policy on the datastore. // -// Currently SavePolicy does not allow updating the team of an existing policy. -func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool) error { +// Currently, SavePolicy does not allow updating the team of an existing policy. +func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool, removePolicyStats bool) error { // We must normalize the name for full Unicode support (Unicode equivalence). p.Name = norm.NFC.String(p.Name) sql := ` @@ -133,10 +135,39 @@ func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemo return ctxerr.Wrap(ctx, notFound("Policy").WithID(p.ID)) } + return cleanupPolicy(ctx, ds.writer(ctx), p.ID, p.Platform, shouldRemoveAllPolicyMemberships, removePolicyStats, ds.logger) +} + +func cleanupPolicy( + ctx context.Context, extContext sqlx.ExtContext, policyID uint, policyPlatform string, shouldRemoveAllPolicyMemberships bool, + removePolicyStats bool, logger kitlog.Logger, +) error { + var err error if shouldRemoveAllPolicyMemberships { - return ds.cleanupPolicyMembershipForPolicy(ctx, p.ID) + err = cleanupPolicyMembershipForPolicy(ctx, extContext, policyID) + } else { + err = cleanupPolicyMembershipOnPolicyUpdate(ctx, extContext, policyID, policyPlatform) } - return cleanupPolicyMembershipOnPolicyUpdate(ctx, ds.writer(ctx), p.ID, p.Platform) + if err != nil { + return err + } + if removePolicyStats { + // delete all policy stats for the policy + fn := func(tx sqlx.ExtContext) error { + _, err := tx.ExecContext(ctx, `DELETE FROM policy_stats WHERE policy_id = ?`, policyID) + return err + } + if _, isDB := extContext.(*sqlx.DB); isDB { + // wrapping in a retry to avoid deadlocks with the cleanups_then_aggregation cron job + err = withRetryTxx(ctx, extContext.(*sqlx.DB), fn, logger) + } else { + err = fn(extContext) + } + if err != nil { + return ctxerr.Wrap(ctx, err, "cleanup policy stats") + } + } + return nil } // FlippingPoliciesForHost fetches previous policy membership results and returns: @@ -576,8 +607,74 @@ func (ds *Datastore) TeamPolicy(ctx context.Context, teamID uint, policyID uint) // NOTE: Similar to ApplyQueries, ApplyPolicySpecs will update the author_id of the policies // that are updated. // -// Currently ApplyPolicySpecs does not allow updating the team of an existing policy. +// Currently, ApplyPolicySpecs does not allow updating the team of an existing policy. func (ds *Datastore) ApplyPolicySpecs(ctx context.Context, authorID uint, specs []*fleet.PolicySpec) error { + // Use the same DB for all operations in this method for performance + queryerContext := ds.writer(ctx) + + // Preprocess specs and group them by team + teamNameToID := make(map[string]uint, 1) + teamIDToPolicies := make(map[uint][]*fleet.PolicySpec, 1) + + // Get the team IDs + for _, spec := range specs { + // We must normalize the name for full Unicode support (Unicode equivalence). + spec.Name = norm.NFC.String(spec.Name) + spec.Team = norm.NFC.String(spec.Team) + teamID, ok := teamNameToID[spec.Team] + if !ok { + if spec.Team != "" { + // if team name is not empty, it must have a team ID; otherwise teamID defaults to 0 value + err := sqlx.GetContext(ctx, queryerContext, &teamID, `SELECT id FROM teams WHERE name = ?`, spec.Team) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return ctxerr.Wrap(ctx, notFound("Team").WithName(spec.Team), "get team id") + } + return ctxerr.Wrap(ctx, err, "get team id") + } + } + teamNameToID[spec.Team] = teamID + } + teamIDToPolicies[teamID] = append(teamIDToPolicies[teamID], spec) + } + + // Get the query and platforms of the current policies so that we can check if query or platform changed later, if needed + type policyLite struct { + Name string `db:"name"` + Query string `db:"query"` + Platforms string `db:"platforms"` + } + teamIDToPoliciesByName := make(map[uint]map[string]policyLite, len(teamIDToPolicies)) + for teamID, teamPolicySpecs := range teamIDToPolicies { + teamIDToPoliciesByName[teamID] = make(map[string]policyLite, len(teamPolicySpecs)) + policyNames := make([]string, 0, len(teamPolicySpecs)) + for _, spec := range teamPolicySpecs { + policyNames = append(policyNames, spec.Name) + } + + var query string + var args []interface{} + var err error + if teamID == 0 { + query, args, err = sqlx.In("SELECT name, query, platforms FROM policies WHERE team_id IS NULL AND name IN (?)", policyNames) + } else { + query, args, err = sqlx.In( + "SELECT name, query, platforms FROM policies WHERE team_id = ? AND name IN (?)", &teamID, policyNames, + ) + } + if err != nil { + return ctxerr.Wrap(ctx, err, "building query to get policies by name") + } + policies := make([]policyLite, 0, len(teamPolicySpecs)) + err = sqlx.SelectContext(ctx, queryerContext, &policies, query, args...) + if err != nil { + return ctxerr.Wrap(ctx, err, "getting policies by name") + } + for _, p := range policies { + teamIDToPoliciesByName[teamID][p.Name] = p + } + } + return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { query := fmt.Sprintf( ` @@ -592,7 +689,7 @@ func (ds *Datastore) ApplyPolicySpecs(ctx context.Context, authorID uint, specs critical, calendar_events_enabled, checksum - ) VALUES ( ?, ?, ?, ?, ?, (SELECT IFNULL(MIN(id), NULL) FROM teams WHERE name = ?), ?, ?, ?, %s) + ) VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ?, %s) ON DUPLICATE KEY UPDATE query = VALUES(query), description = VALUES(description), @@ -603,24 +700,45 @@ func (ds *Datastore) ApplyPolicySpecs(ctx context.Context, authorID uint, specs calendar_events_enabled = VALUES(calendar_events_enabled) `, policiesChecksumComputedColumn(), ) - for _, spec := range specs { - - // We must normalize the name for full Unicode support (Unicode equivalence). - spec.Name = norm.NFC.String(spec.Name) - res, err := tx.ExecContext(ctx, - query, spec.Name, spec.Query, spec.Description, authorID, spec.Resolution, spec.Team, spec.Platform, spec.Critical, - spec.CalendarEventsEnabled, - ) - if err != nil { - return ctxerr.Wrap(ctx, err, "exec ApplyPolicySpecs insert") + for teamID, teamPolicySpecs := range teamIDToPolicies { + var teamIDPtr *uint + if teamID != 0 { + teamIDPtr = &teamID } + for _, spec := range teamPolicySpecs { - if insertOnDuplicateDidUpdate(res) { - // when the upsert results in an UPDATE that *did* change some values, - // it returns the updated ID as last inserted id. - if lastID, _ := res.LastInsertId(); lastID > 0 { - if err := cleanupPolicyMembershipOnPolicyUpdate(ctx, tx, uint(lastID), spec.Platform); err != nil { - return err + res, err := tx.ExecContext( + ctx, + query, spec.Name, spec.Query, spec.Description, authorID, spec.Resolution, teamIDPtr, spec.Platform, spec.Critical, + spec.CalendarEventsEnabled, + ) + if err != nil { + return ctxerr.Wrap(ctx, err, "exec ApplyPolicySpecs insert") + } + + if insertOnDuplicateDidUpdate(res) { + // when the upsert results in an UPDATE that *did* change some values, + // it returns the updated ID as last inserted id. + if lastID, _ := res.LastInsertId(); lastID > 0 { + var ( + shouldRemoveAllPolicyMemberships bool + removePolicyStats bool + ) + // Figure out if the query or platform changed + if prev, ok := teamIDToPoliciesByName[teamID][spec.Name]; ok { + switch { + case prev.Query != spec.Query: + shouldRemoveAllPolicyMemberships = true + removePolicyStats = true + case prev.Platforms != spec.Platform: + removePolicyStats = true + } + } + if err = cleanupPolicy( + ctx, tx, uint(lastID), spec.Platform, shouldRemoveAllPolicyMemberships, removePolicyStats, ds.logger, + ); err != nil { + return err + } } } } @@ -739,7 +857,7 @@ func cleanupPolicyMembershipOnPolicyUpdate(ctx context.Context, db sqlx.ExecerCo // cleanupPolicyMembership is similar to cleanupPolicyMembershipOnPolicyUpdate but without the platform constraints. // Used when we want to remove all policy membership. -func (ds *Datastore) cleanupPolicyMembershipForPolicy(ctx context.Context, policyID uint) error { +func cleanupPolicyMembershipForPolicy(ctx context.Context, exec sqlx.ExecerContext, policyID uint) error { // delete all policy memberships for the policy delStmt := ` DELETE @@ -754,21 +872,11 @@ func (ds *Datastore) cleanupPolicyMembershipForPolicy(ctx context.Context, polic pm.policy_id = ? ` - _, err := ds.writer(ctx).ExecContext(ctx, delStmt, policyID) + _, err := exec.ExecContext(ctx, delStmt, policyID) if err != nil { return ctxerr.Wrap(ctx, err, "cleanup policy membership") } - // delete all policy stats for the policy - // wrapping in a retry to avoid deadlocks with the cleanups_then_aggregation cron job - err = ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { - _, err := tx.ExecContext(ctx, `DELETE FROM policy_stats WHERE policy_id = ?`, policyID) - return err - }) - if err != nil { - return ctxerr.Wrap(ctx, err, "cleanup policy stats") - } - return nil } diff --git a/server/datastore/mysql/policies_test.go b/server/datastore/mysql/policies_test.go index 045d850fa1..90d9015489 100644 --- a/server/datastore/mysql/policies_test.go +++ b/server/datastore/mysql/policies_test.go @@ -40,6 +40,7 @@ func TestPolicies(t *testing.T) { {"PoliciesByID", testPoliciesByID}, {"TeamPolicyTransfer", testTeamPolicyTransfer}, {"ApplyPolicySpec", testApplyPolicySpec}, + {"ApplyPolicySpecWithQueryPlatformChanges", testApplyPolicySpecWithQueryPlatformChanges}, {"Save", testPoliciesSave}, {"DelUser", testPoliciesDelUser}, {"FlippingPoliciesForHost", testFlippingPoliciesForHost}, @@ -1400,6 +1401,298 @@ func testApplyPolicySpec(t *testing.T, ds *Datastore) { })) } +func testApplyPolicySpecWithQueryPlatformChanges(t *testing.T, ds *Datastore) { + ctx := context.Background() + unicode, _ := strconv.Unquote(`"\uAC00"`) // 가 + unicodeEq, _ := strconv.Unquote(`"\u1100\u1161"`) // ᄀ + ᅡ + + user1 := test.NewUser(t, ds, "User1", "user1@example.com", true) + team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1" + unicode}) + require.NoError(t, err) + + globalNames := []string{"global query1" + unicode, "global query2" + unicode, "global query3" + unicode} + teamNames := []string{"team query1", "team query2", "team query3"} + require.NoError( + t, ds.ApplyPolicySpecs( + ctx, user1.ID, []*fleet.PolicySpec{ + { + Name: globalNames[0], + Query: "select 1;", + Team: "", + Platform: "", + }, + { + Name: globalNames[1], + Query: "select 2;", + Team: "", + Platform: "darwin", + }, + { + Name: globalNames[2], + Query: "select 3;", + Team: "", + Platform: "darwin,linux", + }, + { + Name: teamNames[0], + Query: "select 1;", + Team: "team1" + unicode, + Platform: "", + }, + { + Name: teamNames[1], + Query: "select 2;", + Team: "team1" + unicode, + Platform: "darwin", + }, + { + Name: teamNames[2], + Query: "select 3;", + Team: "team1" + unicodeEq, + Platform: "darwin,linux", + }, + }, + ), + ) + + // create hosts with different platforms, for that team + const hostWin, hostMac, hostDeb, hostLin = 0, 1, 2, 3 + platforms := []string{"windows", "darwin", "debian", "linux"} + teamHosts := make([]*fleet.Host, len(platforms)) + for i, pl := range platforms { + id := fmt.Sprintf("%s-%d", strings.ReplaceAll(t.Name(), "/", "_"), i) + h, err := ds.NewHost( + ctx, &fleet.Host{ + OsqueryHostID: &id, + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now(), + NodeKey: &id, + UUID: id, + Hostname: id, + Platform: pl, + TeamID: ptr.Uint(team1.ID), + }, + ) + require.NoError(t, err) + teamHosts[i] = h + } + + // create hosts with different platforms, without team + globalHosts := make([]*fleet.Host, len(platforms)) + for i, pl := range platforms { + id := fmt.Sprintf("g%s-%d", strings.ReplaceAll(t.Name(), "/", "_"), i) + h, err := ds.NewHost( + ctx, &fleet.Host{ + OsqueryHostID: &id, + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now(), + NodeKey: &id, + UUID: id, + Hostname: id, + Platform: pl, + }, + ) + require.NoError(t, err) + globalHosts[i] = h + } + + // load the global policies + gPolicies, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{}) + require.NoError(t, err) + require.Len(t, gPolicies, 3) + // load the team policies + tPolicies, _, err := ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}) + require.NoError(t, err) + require.Len(t, tPolicies, 3) + + // index the policies by name for easier access in the rest of the test + polsByName := make(map[string]*fleet.Policy, len(gPolicies)+len(tPolicies)) + globalPolsByName := make(map[string]*fleet.Policy, len(gPolicies)) + for _, pol := range tPolicies { + polsByName[pol.Name] = pol + } + for _, pol := range gPolicies { + globalPolsByName[pol.Name] = pol + polsByName[pol.Name] = pol + } + + // record some results for each policy + // Note: we are adding results to hosts that shouldn't have results, based on their platform. + for _, h := range teamHosts { + res := make(map[uint]*bool, len(polsByName)) + for _, pol := range polsByName { + res[pol.ID] = ptr.Bool(false) + } + err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false) + require.NoError(t, err) + } + for _, h := range globalHosts { + res := make(map[uint]*bool, len(globalPolsByName)) + for _, pol := range globalPolsByName { + res[pol.ID] = ptr.Bool(false) + } + err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false) + require.NoError(t, err) + } + err = ds.UpdateHostPolicyCounts(ctx) + require.NoError(t, err) + + // Update host failure counts and ensure they are correct + teamHosts, err = ds.UpdatePolicyFailureCountsForHosts(ctx, teamHosts) + require.NoError(t, err) + assert.Equal(t, 6, teamHosts[hostWin].FailingPoliciesCount) + assert.Equal(t, 6, teamHosts[hostMac].FailingPoliciesCount) + assert.Equal(t, 6, teamHosts[hostDeb].FailingPoliciesCount) + assert.Equal(t, 6, teamHosts[hostLin].FailingPoliciesCount) + globalHosts, err = ds.UpdatePolicyFailureCountsForHosts(ctx, globalHosts) + require.NoError(t, err) + assert.Equal(t, 3, globalHosts[hostWin].FailingPoliciesCount) + assert.Equal(t, 3, globalHosts[hostMac].FailingPoliciesCount) + assert.Equal(t, 3, globalHosts[hostDeb].FailingPoliciesCount) + assert.Equal(t, 3, globalHosts[hostLin].FailingPoliciesCount) + + // Ensure policy passing and failing counts are correct + gPolicies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{}) + require.NoError(t, err) + require.Len(t, gPolicies, 3) + tPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}) + require.NoError(t, err) + require.Len(t, tPolicies, 3) + + for _, pol := range gPolicies { + polsByName[pol.Name] = pol + } + for _, pol := range tPolicies { + polsByName[pol.Name] = pol + } + assert.Equal(t, uint(8), polsByName[globalNames[0]].FailingHostCount) + assert.Equal(t, uint(8), polsByName[globalNames[1]].FailingHostCount) + assert.Equal(t, uint(8), polsByName[globalNames[2]].FailingHostCount) + assert.Equal(t, uint(4), polsByName[teamNames[0]].FailingHostCount) + assert.Equal(t, uint(4), polsByName[teamNames[1]].FailingHostCount) + assert.Equal(t, uint(4), polsByName[teamNames[2]].FailingHostCount) + + // Update policies + require.NoError( + t, ds.ApplyPolicySpecs( + ctx, user1.ID, []*fleet.PolicySpec{ + { + Name: globalNames[0], + Query: "select 1;", + Team: "", + Platform: "", + Description: "updated", // update description + }, + { + Name: globalNames[1], + Query: "select 2 updated;", // update query + Team: "", + Platform: "darwin", + }, + { + Name: globalNames[2], + Query: "select 3;", + Team: "", + Platform: "darwin", // update platform + }, + { + Name: "new global query", + Query: "select 4;", + Team: "", + Platform: "", + }, + { + Name: teamNames[0], + Query: "select 1;", + Team: "team1" + unicode, + Platform: "linux", // update platform + }, + { + Name: teamNames[1], + Query: "select 2;", + Team: "team1" + unicode, + Platform: "darwin", + CalendarEventsEnabled: true, // update calendar events + }, + { + Name: teamNames[2], + Query: "select 3 updated;", // update query + Team: "team1" + unicodeEq, + Platform: "darwin,linux", + }, + { + Name: "new team query", + Query: "select 4;", + Team: "team1" + unicode, + Platform: "", + }, + }, + ), + ) + + // Update host failure counts and ensure they are correct + teamHosts, err = ds.UpdatePolicyFailureCountsForHosts(ctx, teamHosts) + require.NoError(t, err) + assert.Equal(t, 1, teamHosts[hostWin].FailingPoliciesCount) // kept result from globalNames[0] + assert.Equal(t, 3, teamHosts[hostMac].FailingPoliciesCount) + assert.Equal(t, 2, teamHosts[hostDeb].FailingPoliciesCount) + assert.Equal(t, 2, teamHosts[hostLin].FailingPoliciesCount) + globalHosts, err = ds.UpdatePolicyFailureCountsForHosts(ctx, globalHosts) + require.NoError(t, err) + assert.Equal(t, 1, globalHosts[hostWin].FailingPoliciesCount) + assert.Equal(t, 2, globalHosts[hostMac].FailingPoliciesCount) + assert.Equal(t, 1, globalHosts[hostDeb].FailingPoliciesCount) + assert.Equal(t, 1, globalHosts[hostLin].FailingPoliciesCount) + + // Ensure policy passing and failing counts are correct + gPolicies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{}) + require.NoError(t, err) + require.Len(t, gPolicies, 4) + tPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}) + require.NoError(t, err) + require.Len(t, tPolicies, 4) + + for _, pol := range gPolicies { + polsByName[pol.Name] = pol + } + for _, pol := range tPolicies { + polsByName[pol.Name] = pol + } + assert.Equal(t, uint(8), polsByName[globalNames[0]].FailingHostCount) + assert.Equal(t, uint(0), polsByName[globalNames[1]].FailingHostCount) // updated query + assert.Equal(t, uint(0), polsByName[globalNames[2]].FailingHostCount) // updated platform + assert.Equal(t, uint(0), polsByName[teamNames[0]].FailingHostCount) // updated platform + assert.Equal(t, uint(4), polsByName[teamNames[1]].FailingHostCount) + assert.Equal(t, uint(0), polsByName[teamNames[2]].FailingHostCount) // updated query + + err = ds.UpdateHostPolicyCounts(ctx) + require.NoError(t, err) + gPolicies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{}) + require.NoError(t, err) + require.Len(t, gPolicies, 4) + tPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}) + require.NoError(t, err) + require.Len(t, tPolicies, 4) + + for _, pol := range gPolicies { + polsByName[pol.Name] = pol + } + for _, pol := range tPolicies { + polsByName[pol.Name] = pol + } + assert.Equal(t, uint(8), polsByName[globalNames[0]].FailingHostCount) // platform is "" -- no change + assert.Equal(t, uint(0), polsByName[globalNames[1]].FailingHostCount) // updated query + assert.Equal(t, uint(2), polsByName[globalNames[2]].FailingHostCount) // updated platform + assert.Equal(t, uint(2), polsByName[teamNames[0]].FailingHostCount) // updated platform + assert.Equal(t, uint(1), polsByName[teamNames[1]].FailingHostCount) // platform is "darwin" -- no change + assert.Equal(t, uint(0), polsByName[teamNames[2]].FailingHostCount) // updated query + +} + func testPoliciesSave(t *testing.T, ds *Datastore) { user1 := test.NewUser(t, ds, "User1", "user1@example.com", true) ctx := context.Background() @@ -1412,7 +1705,8 @@ func testPoliciesSave(t *testing.T, ds *Datastore) { Name: "non-existent query", Query: "select 1;", }, - }, false) + }, false, false, + ) require.Error(t, err) var nfe *notFoundError require.True(t, errors.As(err, &nfe)) @@ -1473,7 +1767,7 @@ func testPoliciesSave(t *testing.T, ds *Datastore) { gp2 := *gp gp2.Name = "global query updated" gp2.Critical = true - err = ds.SavePolicy(ctx, &gp2, false) + err = ds.SavePolicy(ctx, &gp2, false, false) require.NoError(t, err) gp, err = ds.Policy(ctx, gp.ID) require.NoError(t, err) @@ -1493,7 +1787,7 @@ func testPoliciesSave(t *testing.T, ds *Datastore) { tp2.Resolution = ptr.String("team1 query resolution updated") tp2.Critical = false tp2.CalendarEventsEnabled = false - err = ds.SavePolicy(ctx, &tp2, true) + err = ds.SavePolicy(ctx, &tp2, true, true) require.NoError(t, err) tp1, err = ds.Policy(ctx, tp1.ID) tp2.UpdateCreateTimestamps = tp1.UpdateCreateTimestamps @@ -1586,7 +1880,7 @@ func testCachedPolicyCountDeletesOnPolicyChange(t *testing.T, ds *Datastore) { assert.Equal(t, uint(1), inheritedPolicies[0].PassingHostCount) // Update the global policy sql to trigger a cache invalidation - err = ds.SavePolicy(ctx, globalPolicy, true) + err = ds.SavePolicy(ctx, globalPolicy, true, true) require.NoError(t, err) globalPolicy, err = ds.Policy(ctx, globalPolicy.ID) @@ -1599,8 +1893,8 @@ func testCachedPolicyCountDeletesOnPolicyChange(t *testing.T, ds *Datastore) { assert.Equal(t, uint(1), teamPolicies[0].PassingHostCount) assert.Equal(t, uint(0), inheritedPolicies[0].PassingHostCount) - // Update the team policy sql to trigger a cache invalidation - err = ds.SavePolicy(ctx, teamPolicy, true) + // Update the team policy platform to trigger a cache invalidation + err = ds.SavePolicy(ctx, teamPolicy, false, true) require.NoError(t, err) teamPolicies, inheritedPolicies, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}) @@ -1921,9 +2215,9 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) { } // updating without change works fine - err = ds.SavePolicy(ctx, polsByName["g1"], false) + err = ds.SavePolicy(ctx, polsByName["g1"], false, false) require.NoError(t, err) - err = ds.SavePolicy(ctx, polsByName["t2"], false) + err = ds.SavePolicy(ctx, polsByName["t2"], false, false) require.NoError(t, err) // apply specs that result in an update (without change) works fine err = ds.ApplyPolicySpecs(ctx, user.ID, []*fleet.PolicySpec{ @@ -1975,7 +2269,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) { g1 := polsByName["g1"] g1.Platform = "linux" polsByName["g1"] = g1 - err = ds.SavePolicy(ctx, g1, false) + err = ds.SavePolicy(ctx, g1, false, false) require.NoError(t, err) wantHostsByPol["g1"] = []uint{globalHosts[hostDeb].ID, globalHosts[hostLin].ID} assertPolicyMembership(t, ds, polsByName, wantHostsByPol) @@ -1984,7 +2278,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) { t1 := polsByName["t1"] t1.Platform = "windows,darwin" polsByName["t1"] = t1 - err = ds.SavePolicy(ctx, t1, false) + err = ds.SavePolicy(ctx, t1, false, false) require.NoError(t, err) wantHostsByPol["t1"] = []uint{teamHosts[hostWin].ID, teamHosts[hostMac].ID} assertPolicyMembership(t, ds, polsByName, wantHostsByPol) @@ -2723,7 +3017,7 @@ func testPoliciesNameUnicode(t *testing.T, ds *Datastore) { policyEmoji, err := ds.NewGlobalPolicy(context.Background(), nil, fleet.PolicyPayload{Name: "💻"}) require.NoError(t, err) err = ds.SavePolicy( - context.Background(), &fleet.Policy{PolicyData: fleet.PolicyData{ID: policyEmoji.ID, Name: equivalentNames[1]}}, false, + context.Background(), &fleet.Policy{PolicyData: fleet.PolicyData{ID: policyEmoji.ID, Name: equivalentNames[1]}}, false, false, ) assert.True(t, isDuplicate(err), err) @@ -3270,10 +3564,10 @@ func testGetTeamHostsPolicyMemberships(t *testing.T, ds *Datastore) { // team2Policy1.Platform = "darwin" - err = ds.SavePolicy(ctx, team1Policy1, false) + err = ds.SavePolicy(ctx, team1Policy1, false, true) require.NoError(t, err) team1Policy1.Platform = "darwin" - err = ds.SavePolicy(ctx, team2Policy1, false) + err = ds.SavePolicy(ctx, team2Policy1, false, true) require.NoError(t, err) // diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index 764552802c..f8cc5050ca 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -602,7 +602,7 @@ type Datastore interface { // SavePolicy updates some fields of the given policy on the datastore. // // It is also used to update team policies. - SavePolicy(ctx context.Context, p *Policy, shouldRemoveAllPolicyMemberships bool) error + SavePolicy(ctx context.Context, p *Policy, shouldRemoveAllPolicyMemberships bool, removePolicyStats bool) error ListGlobalPolicies(ctx context.Context, opts ListOptions) ([]*Policy, error) PoliciesByID(ctx context.Context, ids []uint) (map[uint]*Policy, error) diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index a1211c4f74..07260d7d03 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -435,7 +435,7 @@ type NewGlobalPolicyFunc func(ctx context.Context, authorID *uint, args fleet.Po type PolicyFunc func(ctx context.Context, id uint) (*fleet.Policy, error) -type SavePolicyFunc func(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool) error +type SavePolicyFunc func(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool, removePolicyStats bool) error type ListGlobalPoliciesFunc func(ctx context.Context, opts fleet.ListOptions) ([]*fleet.Policy, error) @@ -3729,11 +3729,11 @@ func (s *DataStore) Policy(ctx context.Context, id uint) (*fleet.Policy, error) return s.PolicyFunc(ctx, id) } -func (s *DataStore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool) error { +func (s *DataStore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool, removePolicyStats bool) error { s.mu.Lock() s.SavePolicyFuncInvoked = true s.mu.Unlock() - return s.SavePolicyFunc(ctx, p, shouldRemoveAllPolicyMemberships) + return s.SavePolicyFunc(ctx, p, shouldRemoveAllPolicyMemberships, removePolicyStats) } func (s *DataStore) ListGlobalPolicies(ctx context.Context, opts fleet.ListOptions) ([]*fleet.Policy, error) { diff --git a/server/service/global_policies_test.go b/server/service/global_policies_test.go index 615c0600f4..108de7c70b 100644 --- a/server/service/global_policies_test.go +++ b/server/service/global_policies_test.go @@ -68,7 +68,7 @@ func TestGlobalPoliciesAuth(t *testing.T) { ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails) error { return nil } - ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy, shouldDeleteAll bool) error { + ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy, shouldDeleteAll bool, removePolicyStats bool) error { return nil } ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index 88742e44df..b87b228ab5 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -2428,6 +2428,65 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() { assert.Equal(t, uint(0), policiesResponse.Policies[0].FailingHostCount) assert.Equal(t, uint(0), policiesResponse.Policies[0].PassingHostCount) + // Record query executions + require.NoError( + t, s.ds.RecordPolicyQueryExecutions( + context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now(), false, + ), + ) + require.NoError( + t, s.ds.RecordPolicyQueryExecutions( + context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false, + ), + ) + // Update policy stats + require.NoError(t, s.ds.UpdateHostPolicyCounts(context.Background())) + + // Fetch policy to make sure stats are updated + s.DoJSON("GET", "/api/latest/fleet/policies", nil, http.StatusOK, &policiesResponse) + require.Len(t, policiesResponse.Policies, 1) + assert.Equal(t, uint(0), policiesResponse.Policies[0].FailingHostCount) + assert.Equal(t, uint(1), policiesResponse.Policies[0].PassingHostCount) + + listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID) + listHostsResp = listHostsResponse{} + s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp) + require.Len(t, listHostsResp.Hosts, 1) + + // Modify the platform for the policy, which should clear the policy stats + mgpParams = modifyGlobalPolicyRequest{ + ModifyPolicyPayload: fleet.ModifyPolicyPayload{ + Platform: ptr.String("linux"), + }, + } + mgpResp = modifyGlobalPolicyResponse{} + s.DoJSON("PATCH", fmt.Sprintf("/api/latest/fleet/policies/%d", gpResp.Policy.ID), mgpParams, http.StatusOK, &mgpResp) + require.NotNil(t, gpResp.Policy) + assert.Equal(t, "TestQuery4", mgpResp.Policy.Name) + assert.Equal(t, "select * from users;", mgpResp.Policy.Query) + assert.Equal(t, "Some description updated", mgpResp.Policy.Description) + require.NotNil(t, mgpResp.Policy.Resolution) + assert.Equal(t, "some global resolution updated", *mgpResp.Policy.Resolution) + assert.Equal(t, "linux", mgpResp.Policy.Platform) + assert.Equal(t, uint(0), mgpResp.Policy.FailingHostCount) + assert.Equal(t, uint(0), mgpResp.Policy.PassingHostCount) + + // Fetch policy to make sure stats are updated + s.DoJSON("GET", "/api/latest/fleet/policies", nil, http.StatusOK, &policiesResponse) + require.Len(t, policiesResponse.Policies, 1) + assert.Equal(t, uint(0), policiesResponse.Policies[0].FailingHostCount) + assert.Equal(t, uint(0), policiesResponse.Policies[0].PassingHostCount) + + listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID) + listHostsResp = listHostsResponse{} + s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp) + require.Len(t, listHostsResp.Hosts, 0) + + listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=failing", policiesResponse.Policies[0].ID) + listHostsResp = listHostsResponse{} + s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp) + require.Len(t, listHostsResp.Hosts, 0) + deletePolicyParams := deleteGlobalPoliciesRequest{IDs: []uint{policiesResponse.Policies[0].ID}} deletePolicyResp := deleteGlobalPoliciesResponse{} s.DoJSON("POST", "/api/latest/fleet/policies/delete", deletePolicyParams, http.StatusOK, &deletePolicyResp) diff --git a/server/service/team_policies.go b/server/service/team_policies.go index 7786c7fe68..75cbe3ae96 100644 --- a/server/service/team_policies.go +++ b/server/service/team_policies.go @@ -368,7 +368,8 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f }) } - var shouldRemoveAll bool + var removeAllMemberships bool + var removeStats bool if p.Name != nil { policy.Name = *p.Name } @@ -377,9 +378,8 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f } if p.Query != nil { if policy.Query != *p.Query { - shouldRemoveAll = true - policy.FailingHostCount = 0 - policy.PassingHostCount = 0 + removeAllMemberships = true + removeStats = true } policy.Query = *p.Query } @@ -387,6 +387,9 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f policy.Resolution = p.Resolution } if p.Platform != nil { + if policy.Platform != *p.Platform { + removeStats = true + } policy.Platform = *p.Platform } if p.Critical != nil { @@ -395,9 +398,13 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f if p.CalendarEventsEnabled != nil { policy.CalendarEventsEnabled = *p.CalendarEventsEnabled } + if removeStats { + policy.FailingHostCount = 0 + policy.PassingHostCount = 0 + } logging.WithExtras(ctx, "name", policy.Name, "sql", policy.Query) - err = svc.ds.SavePolicy(ctx, policy, shouldRemoveAll) + err = svc.ds.SavePolicy(ctx, policy, removeAllMemberships, removeStats) if err != nil { return nil, ctxerr.Wrap(ctx, err, "saving policy") } diff --git a/server/service/team_policies_test.go b/server/service/team_policies_test.go index e6079f1101..9e1a502f67 100644 --- a/server/service/team_policies_test.go +++ b/server/service/team_policies_test.go @@ -44,7 +44,7 @@ func TestTeamPoliciesAuth(t *testing.T) { } return nil, nil } - ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy, shouldDeleteAll bool) error { + ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy, shouldDeleteAll bool, removePolicyStats bool) error { return nil } ds.DeleteTeamPoliciesFunc = func(ctx context.Context, teamID uint, ids []uint) ([]uint, error) {